# -*- coding: utf-8 -*-# BSD 3-Clause License## Copyright (c) Soumith Chintala 2016,# All rights reserved.# ---------------------------------------------------------------------# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")## Copyright (c) 2014-2021 Megvii Inc. All rights reserved.## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.## This file has been modified by Megvii ("Megvii Modifications").# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.# ---------------------------------------------------------------------importosimportshutilfromtqdmimporttqdmfrom....distributed.groupimportis_distributedfrom....loggerimportget_loggerfrom....serializationimportload,savefrom.folderimportImageFolderfrom.utilsimport_default_dataset_root,calculate_md5,untar,untargzlogger=get_logger(__name__)
class ImageNet(ImageFolder):
    r"""Load ImageNet from raw files or folder. Expected folder looks like:

    .. code-block:: shell

        ${root}/
        |       [REQUIRED TAR FILES]
        |-  ILSVRC2012_img_train.tar
        |-  ILSVRC2012_img_val.tar
        |-  ILSVRC2012_devkit_t12.tar.gz
        |       [OPTIONAL IMAGE FOLDERS]
        |-  train/cls/xxx.${img_ext}
        |-  val/cls/xxx.${img_ext}
        |-  ILSVRC2012_devkit_t12/data/meta.mat
        |-  ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the image folders don't exist, raw tar files are required to get extracted and processed.

        * if ``root`` contains ``self.target_folder`` depending on ``train``:

          * initialize ImageFolder with target_folder.

        * else:

          * if all raw files are in ``root``:

            * parse ``self.target_folder`` from raw files.
            * initialize ImageFolder with ``self.target_folder``.

          * else:

            * raise error.

    Args:
        root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
        train: if ``True``, load the train split, otherwise load the validation split.

    Raises:
        FileNotFoundError: if ``root`` does not exist, or if the image folder is
            missing and the raw tar files are not all present in ``root``.
        RuntimeError: if raw files would have to be extracted while running in
            distributed mode.
        AssertionError: if a raw archive's MD5 does not match the expected checksum.
    """

    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }  # ImageNet raw files: split name -> (archive file name, expected md5)
    default_train_dir = "train"
    default_val_dir = "val"
    default_devkit_dir = "ILSVRC2012_devkit_t12"

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        # process the root path
        self.root = self._default_root if root is None else root
        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        # the devkit holds the class metadata and the validation labels
        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exists", self.devkit_dir)
            self._prepare_devkit()

        self.train = train
        split_dir = self.default_train_dir if train else self.default_val_dir
        self.target_folder = os.path.join(self.root, split_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file",
                self.target_folder,
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            if is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            if train:
                self._prepare_train()
            else:
                self._prepare_val()

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        """Default dataset directory: ``<default dataset root>/<class name>``."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        """Validation labels read from the devkit ground-truth file.

        Returns:
            a list of ints, one per non-blank line of
            ``ILSVRC2012_validation_ground_truth.txt``, in file order.

        Raises:
            FileNotFoundError: if the ground-truth file does not exist.
        """
        groud_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if not os.path.exists(groud_truth_path):
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % groud_truth_path
            )
        with open(groud_truth_path, "r") as f:
            # skip blank lines (e.g. a trailing newline) instead of failing in int("")
            return [int(line) for line in f if line.strip()]

    @property
    def meta(self):
        """Class metadata as a ``(idx_to_wnid, wnid_to_classes)`` tuple.

        The cached ``meta.pkl`` in the devkit directory is used when present.
        Otherwise the metadata is parsed from ``data/meta.mat`` (requires
        ``scipy``) and written to ``meta.pkl`` for later calls.

        Raises:
            FileNotFoundError: if neither ``meta.pkl`` nor ``data/meta.mat`` exists.
        """
        cache_path = os.path.join(self.devkit_dir, "meta.pkl")
        try:
            return load(cache_path)
        except FileNotFoundError:
            # scipy is only needed on this slow path, so import it lazily
            import scipy.io

            meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
            if not os.path.exists(meta_path):
                raise FileNotFoundError("meta file %s does not exist" % meta_path)
            synsets = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
            # keep only synsets whose fifth field (children count) is zero
            nums_children = list(zip(*synsets))[4]
            leaves = [
                synsets[idx]
                for idx, num_children in enumerate(nums_children)
                if num_children == 0
            ]
            idcs, wnids, classes = list(zip(*leaves))[:3]
            classes = [tuple(clss.split(", ")) for clss in classes]
            idx_to_wnid = dict(zip(idcs, wnids))
            wnid_to_classes = dict(zip(wnids, classes))
            logger.info("saving cached meta file to %s", cache_path)
            save((idx_to_wnid, wnid_to_classes), cache_path)
            return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        """Return ``True`` if every archive listed in ``raw_file_meta`` exists in ``root``."""
        return all(
            os.path.exists(os.path.join(self.root, filename))
            for filename, _ in self.raw_file_meta.values()
        )

    @staticmethod
    def _verify_checksum(raw_file, checksum):
        """Raise ``AssertionError`` if the MD5 of ``raw_file`` differs from ``checksum``.

        An explicit raise is used instead of an ``assert`` statement so the
        integrity check still runs under ``python -O``; the exception type is
        kept for backward compatibility.
        """
        if calculate_md5(raw_file) != checksum:
            raise AssertionError(
                "checksum mismatch, {} may be damaged".format(raw_file)
            )

    def _organize_val_data(self):
        """Move the flat validation images into one sub-directory per wnid.

        The i-th image in sorted file-name order is paired with the i-th
        label of ``valid_ground_truth``.
        """
        id2wnid = self.meta[0]
        val_wnids = [id2wnid[idx] for idx in self.valid_ground_truth]
        val_dir = os.path.join(self.root, self.default_val_dir)

        # only regular files are images; class directories left over from an
        # interrupted earlier run must not be paired with labels
        candidates = (
            os.path.join(self.target_folder, name)
            for name in os.listdir(self.target_folder)
        )
        val_images = sorted(path for path in candidates if os.path.isfile(path))
        if len(val_images) != len(val_wnids):
            # zip() below stops at the shorter sequence, so make this visible
            logger.warning(
                "found %d validation images but %d ground truth labels",
                len(val_images),
                len(val_wnids),
            )

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            os.makedirs(os.path.join(val_dir, wnid), exist_ok=True)

        logger.debug("mv val images into wnids dir")
        for wnid, img_file in tqdm(zip(val_wnids, val_images)):
            shutil.move(
                img_file, os.path.join(val_dir, wnid, os.path.basename(img_file)),
            )

    def _prepare_val(self):
        """Verify and extract the validation archive, then sort images by class."""
        assert not self.train
        raw_filename, checksum = self.raw_file_meta["val"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum valid tar file %s ...", raw_file)
        self._verify_checksum(raw_file, checksum)
        logger.info("extract valid tar file... this may take 10-20 minutes")
        # raw_file already contains self.root; joining it again breaks relative roots
        untar(raw_file, self.target_folder)
        self._organize_val_data()

    def _prepare_train(self):
        """Verify and extract the train archive, then unpack each nested archive."""
        assert self.train
        raw_filename, checksum = self.raw_file_meta["train"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum train tar file %s ...", raw_file)
        self._verify_checksum(raw_file, checksum)
        logger.info("extract train tar file.. this may take several hours")
        # raw_file already contains self.root; joining it again breaks relative roots
        untar(raw_file, self.target_folder)
        paths = [
            os.path.join(self.target_folder, child_dir)
            for child_dir in os.listdir(self.target_folder)
        ]
        # each extracted entry is itself an archive: unpack it into a directory
        # named after the file without its extension, then delete the archive
        for path in tqdm(paths):
            untar(path, os.path.splitext(path)[0], remove=True)

    def _prepare_devkit(self):
        """Verify and extract the devkit archive found in ``root``."""
        raw_filename, checksum = self.raw_file_meta["devkit"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum devkit tar file %s ...", raw_file)
        self._verify_checksum(raw_file, checksum)
        logger.info("extract devkit file..")
        untargz(raw_file)