# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread


logger = get_logger(__name__)


GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    raise RuntimeError("dataloader timeout")


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with :class:`~.Sampler`, :class:`~.Transform` and
    :class:`~.Collator`, making it flexible to continually get minibatches from a
    dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes used to load, transform and collate
            the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (in seconds) for collecting a batch
            from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to raising a
            runtime error.
        divide: defines the parallelization strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means different
            sub-processes process different batches. Default: False
        preload: whether to enable the preloading strategy of the dataloader.
            When enabled, the dataloader preloads one batch to device memory to speed
            up the whole training process. All values in maps, lists and tuples will
            be converted to :class:`~.Tensor` by preloading, so you will get
            :class:`~.Tensor` instead of the original NumPy array or Python number.

    .. note::

        By enabling preload, tensors' host2device copy and device kernel execution
        will be overlapped, which will improve the training speed at the cost of
        higher device memory usage (due to one more batch of data in device memory).
        This feature saves more time when your NN training time is short or your
        machine's host PCIe bandwidth for each device is low.
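
    Examples:
        A minimal usage sketch. ``ArrayDataset`` is assumed to come from
        :mod:`megengine.data.dataset`; any map-style :class:`~.Dataset` can be used
        in the same way.

        .. code-block:: python

            import numpy as np
            from megengine.data import DataLoader
            from megengine.data.dataset import ArrayDataset
            from megengine.data.sampler import SequentialSampler

            # a toy map-style dataset: 100 images with integer labels
            data = np.random.random((100, 3, 32, 32)).astype("float32")
            label = np.random.randint(0, 10, size=(100,)).astype("int32")
            dataset = ArrayDataset(data, label)

            sampler = SequentialSampler(dataset, batch_size=8, drop_last=True)
            dataloader = DataLoader(dataset, sampler=sampler, num_workers=0)

            for batch_data, batch_label in dataloader:
                # batch_data: (8, 3, 32, 32), batch_label: (8,)
                pass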
"""__initialized=Falsedef__init__(self,dataset:Dataset,sampler:Sampler=None,transform:Transform=None,collator:Collator=None,num_workers:int=0,timeout:int=0,timeout_event:Callable=raise_timeout_error,divide:bool=False,preload:bool=False,):ifnum_workers<0:raiseValueError("num_workers should not be negative")iftimeout<0:raiseValueError("timeout should not be negative")ifdivideandnum_workers<=1:raiseValueError("divide should not be set to True when num_workers <= 1")self.dataset=datasetself.num_workers=num_workersself.timeout=timeoutself.timeout_event=timeout_eventself.divide=divideself.preload=preloadifisinstance(dataset,StreamDataset):self.sampler=samplerifsamplerelseStreamSampler(batch_size=1)assertisinstance(self.sampler,StreamSampler),"types of dataset and sampler do not match"else:assertisinstance(dataset,Dataset),"Can not recognize this kind of dataset: %s"%type(dataset)self.sampler=(samplerifsamplerelseSequentialSampler(dataset,batch_size=1,drop_last=False))assertisinstance(self.sampler,MapSampler),"types of dataset and sampler do not match"ifdivide:ifself.sampler.batch_size<=self.num_workers:raiseValueError("batch size must not smaller than num_workers in divide mode.")elifself.sampler.batch_size%self.num_workers:logger.warning("batch size is not divisible by num_workers, may lose performance in divide mode.")iftransformisNone:self.transform=PseudoTransform()else:self.transform=transformifcollatorisNone:self.collator=Collator()else:self.collator=collatorself.__initialized=Truedef__iter__(self):ifplatform.system()=="Windows"andself.num_workers>0:print("pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero")self.num_workers=0ifos.getenv("TERMUX_VERSION"):# FIXME: termux install pyarrow will build error now# remove this logic after pyarrow fix this issueprint("pyarrow do not support on termux env now, changing num_workers to be zero")self.num_workers=0ifisinstance(self.dataset,StreamDataset):ifnotself.num_workers:return_SerialStreamDataLoaderIter(self,self.preload)else:return_ParallelStreamDataLoaderIter(self,self.preload)else:assertisinstance(self.dataset,Dataset),"Can not recognize this kind of dataset: %s"%type(self.dataset)ifnotself.num_workers:return_SerialMapDataLoaderIter(self,self.preload)else:return_ParallelMapDataLoaderIter(self,self.preload)def__len__(self):returnlen(self.sampler)


class PreLoader:
    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    """
    strategy one: load from numpy data and generate tensors of the matching dtype
    """

    def _load_tensor(self, batch, cached=True):
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    """
    strategy two: load from a cache that already holds tensors; just do a d2d copy
    """

    def _load_cache(self, data):
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out


class _BaseMapDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = _random_seed_generator().__next__()
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            cached = self.pre_load_device_cache
            if cached is None:  # first and last
                if self.num_processed >= len(self):  # last
                    raise StopIteration
                elif self.num_processed == 0:  # first
                    self._try_load_tensor(cached=False)  # first do the h2d
            out = self._swap_out_cache()
            self._try_load_tensor()
            return out
        else:
            if self.num_processed >= len(self):
                raise StopIteration
            minibatch = self._get_next_batch()
            self.num_processed += 1
            return minibatch

    def _try_load_tensor(self, cached=True):
        if self.num_processed >= len(self):
            return
        else:
            self.num_processed += 1
            batch = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader, preload):
        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        indices = next(self.indices_iter)
        items = [self.dataset[idx] for idx in indices]
        trans_items = self.transform.apply_batch(items)
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store
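        # (The plasma store holds the collated batches in shared memory, so
        # workers hand them to the main process without serializing through a
        # regular multiprocessing pipe.)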
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            gc.collect()
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()


class _BaseStreamDataLoaderIter(PreLoader):
    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])
        else:
            data = raw_data[1]
        ret = []
        for idx in range(len(data[0])):
            ret.append(tuple(e[idx] for e in data))
        return ret

    def __iter__(self):
        return self

    def __next__(self):
        if self.preload:
            if self.pre_load_device_cache is None:
                self._try_load_tensor(cached=False)  # load in current
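            # Double-buffering: hand out the batch staged on the previous
            # iteration (a d2d copy onto the default device), then immediately
            # stage the next batch so its h2d copy overlaps with training.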
            out = self._swap_out_cache()
            self._try_load_tensor()  # load in cached
            return out
        else:
            return self._get_next_batch()

    def _try_load_tensor(self, cached=True):
        batch = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        self.unused = []

    def _try_get_raw_data(self, start_time):
        raw_data = None
        while not raw_data:
            try:
                if self.timeout > 0:
                    timer = threading.Timer(self.timeout, thread.interrupt_main)
                    timer.start()
                raw_data = next(self.dataset_iter)
                if self.timeout > 0:
                    timer.cancel()
            except KeyboardInterrupt:
                raw_data = self.timeout_event()
            except:
                if self.timeout > 0:
                    timer.cancel()
                    waited_time = time.time() - start_time
                    if waited_time > self.timeout:
                        raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)

            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data
        return self.collator.apply(ret)


class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # shared-memory queue implemented by pyarrow plasma store
        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    self.raw_data_queues[qidx].put(data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    self.trans_data_queues[worker_id].put(trans_data)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("trans data queue is full")
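
    # The collect worker below drains the per-worker ``trans_data_queues`` in
    # round-robin order, accumulates ``batch_size`` transformed items, collates
    # them and pushes the result to the shared-memory ``batch_queue``.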
full")def_worker_to_batch_queue(self):cnt=-1trans_items=[]whileTrue:ifself.shutdown_flag.value==1:breakcnt+=1queue_id=cnt%self.num_workerstry:trans_item=self.trans_data_queues[queue_id].get(timeout=GLOBAL_TIMEOUT)exceptqueue.Empty:continuetrans_items.append(trans_item)iflen(trans_items)==self.sampler.batch_size:batch_data=self.collator.apply(trans_items)whileTrue:try:self.batch_queue.put(batch_data,timeout=1)breakexceptqueue.Full:ifself.shutdown_flag.value==1:breaklogger.debug("batch queue is full")trans_items=[]def_check_workers(self):ifnotself.collect_worker.is_alive():exitcode=self.collect_worker.exitcodeifexitcode!=0:raiseRuntimeError("collator worker died. {}".format(exitcode))forworker_id,workerinenumerate(self.transform_workers):ifnotworker.is_alive():exitcode=worker.exitcodeifexitcode!=0:raiseRuntimeError("worker: {} died. {}".format(worker_id,exitcode))def_get_next_batch(self):start_time=time.time()whileTrue:self._check_workers()try:returnself.batch_queue.get(timeout=1)exceptqueue.Empty:logger.debug("batch queue empty!")waited_time=time.time()-start_timeifself.timeout>0andwaited_time>self.timeout:self._put_raw_data_queues(self.timeout_event(),0)def_shutdown(self):withself.shutdown_flag.get_lock():self.shutdown_flag.value=1ifself.recieve_worker.is_alive():self.recieve_worker.terminate()self.recieve_worker.join()ifself.collect_worker.is_alive():self.collect_worker.terminate()self.collect_worker.join()forworkerinself.transform_workers:ifworker.is_alive():worker.terminate()worker.join()forqinself.raw_data_queues:q.cancel_join_thread()q.close()forqinself.trans_data_queues:q.cancel_join_thread()q.close()self.batch_queue.cancel_join_thread()self.batch_queue.close()def__del__(self):ifself.__initialized:self._shutdown()def_task_feeding_loop(indices_iter,task_queues,num_workers,divide,shutdown_flag,feed_batch_idx):# Feed the indices into the task queueswhileTrue:ifshutdown_flag.value==1:breakbatch_idx=feed_batch_idx.valuetry:indices=next(indices_iter)exceptStopIteration:breakifdivide:# make sure all task_queues is ready for putwhileany([q.full()forqintask_queues]):ifshutdown_flag.value==1:return# divide into small pieces, feed to different workers.sub_num=math.ceil(len(indices)/num_workers)forworker_idinrange(num_workers):sub_indices=indices[worker_id*sub_num:(worker_id+1)*sub_num]task_queues[worker_id].put((batch_idx,sub_indices))else:# distribute tasks to different workers uniformly.target_id=batch_idx%num_workerswhiletask_queues[target_id].full():ifshutdown_flag.value==1:returntask_queues[target_id].put((batch_idx,indices))withfeed_batch_idx.get_lock():feed_batch_idx.value+=1def_worker_loop(dataset,task_queue,trans_data_queue,transform,seed,shutdown_flag):# Get dataset items and do the transformrandom.seed(seed)np.random.seed(seed)whileTrue:ifshutdown_flag.value==1:breaktry:batch_idx,indices=task_queue.get(timeout=GLOBAL_TIMEOUT)exceptqueue.Empty:continueiflen(indices)>0:items=[dataset[idx]foridxinindices]trans_items=transform.apply_batch(items)else:# in case of incomplete last batchtrans_items=()whileTrue:try:trans_data_queue.put((batch_idx,trans_items),timeout=1)breakexceptqueue.Full:ifshutdown_flag.value==1:breaklogger.debug("batch part queue is full!")def_data_gathering_loop(trans_data_queues,batch_queue,collator,length,num_workers,shutdown_flag,target_idx,):# Gathering the small pieces of batch data into full batch 
    while True:
        if shutdown_flag.value == 1:
            break
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break
        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()


def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batches are generated in exactly the same order as the generated indices
    while True:
        if shutdown_flag.value == 1:
            break
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break
        target_worker_id = target_batch_idx % num_workers
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )
        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} does not match the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()