megengine.data.sampler 源代码

# -*- coding: utf-8 -*-
import collections.abc
import math
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterator, List, Union

import numpy as np

from .. import distributed as dist


class Sampler(ABC):
    r"""Abstract base class shared by every sampler in this module."""

    @abstractmethod
    def __init__(self):
        # Intentionally empty: the abstract marker only forces concrete
        # subclasses to define their own constructor.
        ...
class MapSampler(Sampler):
    r"""Sampler for map dataset.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        num_samples: total number of samples across all ranks (defaults to
            ``len(dataset)``). Each rank is assigned
            ``ceil(num_samples / world_size)`` of them, which is what is stored
            in ``self.num_samples``.
        world_size: number of ranks.
        rank: rank id, a non-negative integer in ``[0, world_size)``.
        seed: seed for random operators.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None and (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples)
            )

        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size

        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        # ``num_samples`` is the global total; keep only this rank's share.
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Make sure seeds are the same at each rank
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        """Return the number of batches this rank yields, based on ``self.num_samples``."""
        if self.drop_last:
            return self.num_samples // self.batch_size
        return int(math.ceil(self.num_samples / self.batch_size))

    def sample(self):
        r"""Return a list containing all sample indices.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def scatter(self, indices) -> List:
        r"""Split ``indices`` among ranks and return the subset for ``self.rank``.

        The indices are first padded, by cycling through them, up to
        ``num_samples * world_size`` entries, so that every rank receives
        exactly ``num_samples`` indices. Each rank then takes a strided slice
        starting at its rank id. The caller's sequence is not modified.
        If a customized assignment of indices to ranks is needed, override
        this method.

        Args:
            indices: sequence of sample indices covering all ranks.

        Returns:
            The list of ``num_samples`` indices assigned to this rank.
        """
        total_size = self.num_samples * self.world_size
        # Copy so the caller's list is never mutated in place.
        indices = list(indices)
        # Add extra indices to make it evenly divisible. The shortfall can be
        # larger than len(indices) (e.g. fewer indices than ranks), so repeat
        # the whole list as often as needed rather than appending one prefix.
        if indices and len(indices) < total_size:
            repeats = -(-total_size // len(indices))  # ceil division
            indices = (indices * repeats)[:total_size]
        assert len(indices) == total_size

        # subsample
        indices = indices[self.rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        return indices

    def batch(self) -> Iterator[List[Any]]:
        r"""Return an iterator over batches (lists) of sample indices."""
        indices = list(self.sample())
        # user might pass the world_size parameter without dist,
        # so dist.is_distributed() should not be used
        if self.world_size > 1:
            indices = self.scatter(indices)

        step, length = self.batch_size, len(indices)
        batch_index = [indices[i : i + step] for i in range(0, length, step)]

        # ``batch_index`` is empty when there are no indices; guard before
        # peeking at its last element.
        if (
            self.drop_last
            and batch_index
            and len(batch_index[-1]) < self.batch_size
        ):
            batch_index.pop()
        return iter(batch_index)
class StreamSampler(Sampler):
    r"""Sampler for stream dataset.

    Warning:
        In the case of multiple machines, a sampler should ensure that each
        worker gets different data. This class cannot do that yet; build your
        own dataset and sampler to achieve it. Usually,
        :meth:`~.StreamDataset.__iter__` can return a different iterator per
        ``rank = dist.get_rank()``, so that each rank gets different data.

    Args:
        batch_size: number of positions yielded per step.
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size

    def __iter__(self):
        # The sampler acts as its own iterator.
        return self

    def __next__(self):
        # Each call hands out a fresh iterator over 0 .. batch_size - 1.
        # StopIteration is never raised, so iteration is unbounded.
        positions = range(self.batch_size)
        return iter(positions)
class SequentialSampler(MapSampler):
    r"""Sample elements sequentially.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        indices: optional sequence of sample indices. If given, exactly these
            indices are yielded in the given order, and the sampler length is
            derived from ``len(indices)`` instead of ``len(dataset)``.
        world_size: number of ranks.
        rank: rank id, a non-negative integer in ``[0, world_size)``.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # The base class sized each rank's share from len(dataset). When an
            # explicit subset is given, size it from that subset so __len__
            # and scatter() agree with what sample() actually returns.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> Iterator[Any]:
        r"""Return the indices to sample, in order.

        Returns an iterator over ``range(len(dataset))`` when no ``indices``
        were given, otherwise the user-provided ``indices`` sequence itself.
        """
        if self.indices is None:
            return iter(range(len(self.dataset)))
        return self.indices
class RandomSampler(MapSampler):
    r"""Sample elements randomly without replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        indices: optional sequence of sample indices to permute. If given, the
            sampler length is derived from ``len(indices)`` instead of
            ``len(dataset)``.
        world_size: number of ranks.
        rank: rank id, a non-negative integer in ``[0, world_size)``.
        seed: seed for random operators.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # The base class sized each rank's share from len(dataset). When an
            # explicit subset is given, size it from that subset so __len__
            # and scatter() agree with what sample() actually returns.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> List:
        r"""Return a random permutation of all dataset indices, or of ``indices`` if given."""
        if self.indices is None:
            return self.rng.permutation(len(self.dataset)).tolist()
        return self.rng.permutation(self.indices).tolist()
class ReplacementSampler(MapSampler):
    r"""Sample elements randomly with replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        num_samples: total number of samples across all ranks (defaults to
            ``len(dataset)``). Each rank receives
            ``ceil(num_samples / world_size)`` of them.
        weights: per-item sampling weights, one per dataset element. They may
            be unnormalized; they are divided by their sum.
        world_size: number of ranks.
        rank: rank id, a non-negative integer in ``[0, world_size)``.
        seed: seed for random operators.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        if weights is not None:
            if not isinstance(weights, collections.abc.Sequence):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                raise ValueError(
                    "len(dataset)={} should be equal to "
                    "len(weights)={}".format(len(dataset), len(weights))
                )
        self.weights = weights
        if self.weights is not None:
            # Normalize so the weights form a probability distribution.
            self.weights = np.array(weights) / sum(weights)

    def sample(self) -> List:
        r"""Draw dataset indices with replacement.

        Draws ``num_samples * world_size`` indices, i.e. the share of every
        rank. ``batch()`` then hands each rank its own slice through
        ``scatter()``. When ``world_size == 1`` this is exactly ``num_samples``
        draws.
        """
        n = len(self.dataset)
        # All ranks draw the same global sample (MapSampler falls back to a
        # shared seed when world_size > 1), so each rank's slice is disjoint
        # instead of being padded copies of a single per-rank draw.
        size = self.num_samples * self.world_size
        if self.weights is None:
            return self.rng.randint(n, size=size).tolist()
        # RandomState.multinomial returns per-category *counts*, not indices;
        # choice() with ``p`` draws indices according to the weights.
        return self.rng.choice(n, size=size, replace=True, p=self.weights).tolist()
class Infinite(MapSampler):
    r"""Infinite sampler wrapper around a basic sampler.

    Whenever the wrapped sampler's iterator is exhausted, a fresh one is
    created from the sampler, so iteration continues indefinitely. If the
    wrapped sampler yields nothing at all, the second ``next`` call raises
    ``StopIteration`` and iteration ends.

    Args:
        sampler: the sampler to repeat.
    """

    def __init__(self, sampler):
        # MapSampler.__init__ is not called, so attributes such as batch_size
        # or world_size are not set on this wrapper.
        self.sampler = sampler
        self.sampler_iter = iter(self.sampler)

    def sample(self):
        raise NotImplementedError("sample method not supported in Infinite")

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.sampler_iter)
        except StopIteration:
            # The current pass is over: restart from the wrapped sampler.
            self.sampler_iter = iter(self.sampler)
        return next(self.sampler_iter)

    def __len__(self):
        # Reported length: the largest int64 value.
        return np.iinfo(np.int64).max