# -*- coding: utf-8 -*-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")## Copyright (c) 2014-2021 Megvii Inc. All rights reserved.## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.importcollections.abcimportmathfromabcimportABC,abstractmethodfromtypingimportAny,Generator,Iterator,List,Unionimportnumpyasnpimportmegengine.distributedasdist
class Sampler(ABC):
    r"""Abstract root of the sampler hierarchy.

    ``__init__`` is declared abstract, so this class cannot be instantiated
    directly; a concrete sampler has to provide its own constructor.
    """

    @abstractmethod
    def __init__(self):
        # Intentionally empty: the base class holds no state of its own.
        ...
class MapSampler(Sampler):
    r"""Sampler for map dataset.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the
            last batch will be smaller. Default: False
        num_samples: total number of samples to split across ranks. Defaults
            to ``len(dataset)``. The value stored in ``self.num_samples`` is
            the per-rank share, ``ceil(num_samples / world_size)``.
        world_size: number of ranks. When ``None`` it is taken from
            ``dist.get_world_size()`` if ``dist.is_distributed()``, else 1.
        rank: rank id, expected to be a non-negative integer within 0 and
            ``world_size``. When ``None`` it is taken from ``dist.get_rank()``
            if ``dist.is_distributed()``, else 0.
        seed: seed for random operators. When ``None`` and ``world_size > 1``
            the seed is forced to 0.

    Raises:
        ValueError: if ``batch_size`` or ``num_samples`` is not a positive
            ``int`` (``bool`` is rejected), or ``drop_last`` is not a ``bool``.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        # bool is a subclass of int, so it has to be rejected explicitly.
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None and (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples)
            )

        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size
        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        # Per-rank share, rounded up so every rank gets the same count.
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Make sure seeds are the same at each rank
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        """Return the number of batches derived from ``self.num_samples``."""
        if self.drop_last:
            return self.num_samples // self.batch_size
        return int(math.ceil(self.num_samples / self.batch_size))

    def sample(self):
        r"""Return a list contains all sample indices.

        Must be overridden; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def scatter(self, indices) -> List:
        r"""Scatter method is used for splitting indices into subset, each
        subset will be assigned to a rank. Indices are evenly split by
        default. If customized indices assignment method is needed, please
        rewrite this method.

        Args:
            indices: the full sequence of indices to distribute. It is
                copied, never modified.

        Returns:
            The ``self.num_samples`` indices assigned to ``self.rank``.
        """
        total_size = self.num_samples * self.world_size
        # Copy first: ``+=`` on the argument would extend the caller's list.
        padded = list(indices)
        # add extra indices to make it evenly divisible
        padded += padded[: (total_size - len(padded))]
        assert len(padded) == total_size
        # subsample: rank r takes every world_size-th element starting at r
        subset = padded[self.rank : total_size : self.world_size]
        assert len(subset) == self.num_samples
        return subset

    def batch(self) -> Iterator[List[Any]]:
        r"""Batch method provides a batch indices generator.

        Returns:
            An iterator over lists of at most ``batch_size`` indices. With
            ``drop_last`` a trailing short batch is removed. An empty index
            list yields an empty iterator.
        """
        indices = list(self.sample())

        # user might pass the world_size parameter without dist,
        # so dist.is_distributed() should not be used
        if self.world_size > 1:
            indices = self.scatter(indices)

        step = self.batch_size
        batch_index = [indices[i : i + step] for i in range(0, len(indices), step)]
        # ``batch_index`` is empty when there are no indices at all; check it
        # before looking at the last batch to avoid an IndexError.
        if self.drop_last and batch_index and len(batch_index[-1]) < step:
            batch_index.pop()

        return iter(batch_index)
class StreamSampler(Sampler):
    r"""Sampler for stream dataset.

    Iterating this sampler never terminates on its own: the object is its
    own iterator and every ``next()`` call hands back a fresh iterator over
    ``range(batch_size)``.

    Warning:
        In the case of multiple machines, sampler should ensure that each
        worker gets different data. But this class cannot do it yet, please
        build your own dataset and sampler to achieve this goal.

    Usually, :meth:`~.StreamDataset.__iter__` can return different iterator
    by ``rank = dist.get_rank()``. So that they will get different data.

    Args:
        batch_size: number of positions produced per batch. It is stored
            as given, without validation.
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size

    def __iter__(self):
        # The sampler acts as its own iterator.
        return self

    def __next__(self):
        positions = range(self.batch_size)
        return iter(positions)
class SequentialSampler(MapSampler):
    r"""Sample elements sequentially.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the
            last batch will be smaller. Default: False
        indices: indices of samples. When given, they are emitted in order
            instead of ``range(len(dataset))``.
        world_size: number of ranks.
        rank: rank id, non-negative integer within 0 and ``world_size``.

    Raises:
        ValueError: if ``indices`` is neither ``None`` nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # The base class sized ``num_samples`` from ``len(dataset)``.
            # With explicit indices the per-rank count must follow
            # ``len(indices)``, otherwise ``len(self)`` is wrong and the
            # length checks in ``scatter`` fail when ``world_size > 1``.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> Iterator[Any]:
        r"""Return the indices to visit, in order.

        Returns:
            An iterator over ``range(len(dataset))`` when no ``indices``
            were supplied, otherwise the ``indices`` sequence itself.
        """
        if self.indices is None:
            return iter(range(len(self.dataset)))
        return self.indices
class RandomSampler(MapSampler):
    r"""Sample elements randomly without replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the
            last batch will be smaller. Default: False
        indices: indices of samples.
        world_size: number of ranks.
        rank: rank id, non-negative integer within 0 and ``world_size``.
        seed: seed for random operators.

    Raises:
        ValueError: if ``indices`` is neither ``None`` nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        # NOTE(review): num_samples is passed as None, so the base class
        # sizes it from len(dataset) even when indices is given -- confirm
        # this is intended.
        super().__init__(
            dataset,
            batch_size,
            drop_last,
            num_samples=None,
            world_size=world_size,
            rank=rank,
            seed=seed,
        )
        indices_ok = indices is None or isinstance(indices, collections.abc.Sequence)
        if not indices_ok:
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
class ReplacementSampler(MapSampler):
    r"""Sample elements randomly with replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the
            last batch will be smaller. Default: False
        num_samples: number of samples, forwarded to :class:`MapSampler`.
        weights: weights for sampling indices, it could be unnormalized
            weights. Must have one entry per dataset element. They are
            stored normalized (divided by their sum) as a NumPy array.
        world_size: number of ranks.
        rank: rank id, non-negative integer within 0 and ``world_size``.
        seed: seed for random operators.

    Raises:
        ValueError: if ``weights`` is neither ``None`` nor a sequence, or
            its length differs from ``len(dataset)``.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        if weights is not None:
            if not isinstance(weights, collections.abc.Sequence):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                # The original adjacent literals lacked a separating space
                # and rendered as "...equal tolen(weights)=...".
                raise ValueError(
                    "len(dataset)={} should be equal to "
                    "len(weights)={}".format(len(dataset), len(weights))
                )
            # Normalize so the stored weights sum to one.
            weights = np.array(weights) / sum(weights)
        self.weights = weights