# -*- coding: utf-8 -*-
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
import shutil

from tqdm import tqdm

from ....distributed.group import is_distributed
from ....logger import get_logger
from ....serialization import load, save
from .folder import ImageFolder
from .utils import _default_dataset_root, calculate_md5, untar, untargz

# Module-level logger obtained from the project's ``get_logger`` helper, keyed by this module's name.
logger = get_logger(__name__)

class ImageNet(ImageFolder):
    r"""Load ImageNet from raw files or folder.

    Expected folder looks like:

    .. code-block:: shell

        ${root}/
        |       [REQUIRED TAR FILES]
        |-  ILSVRC2012_img_train.tar
        |-  ILSVRC2012_img_val.tar
        |-  ILSVRC2012_devkit_t12.tar.gz
        |       [OPTIONAL IMAGE FOLDERS]
        |-  train/cls/xxx.${img_ext}
        |-  val/cls/xxx.${img_ext}
        |-  ILSVRC2012_devkit_t12/data/meta.mat
        |-  ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the image folders don't exist, raw tar files are required to get extracted and processed.

    * if ``root`` contains ``self.target_folder`` depending on ``train``:

      * initialize ImageFolder with target_folder.

    * else:

      * if all raw files are in ``root``:

        * parse ``self.target_folder`` from raw files.
        * initialize ImageFolder with ``self.target_folder``.

      * else:

        * raise error.

    Args:
        root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
        train: if ``True``, load the train split, otherwise load the validation split.
        **kwargs: forwarded unchanged to ``ImageFolder.__init__``.

    Raises:
        FileNotFoundError: if ``root`` does not exist, or if the image folder is
            missing and the raw tar files are not all present in ``root``.
        RuntimeError: if the raw files would have to be extracted while
            ``is_distributed()`` is true.
        AssertionError: if a raw file's MD5 does not match the expected checksum.
    """

    # split name -> (raw file name, expected MD5 hex digest)
    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }  # ImageNet raw files
    default_train_dir = "train"
    default_val_dir = "val"
    default_devkit_dir = "ILSVRC2012_devkit_t12"

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        # process the root path
        if root is None:
            self.root = self._default_root
        else:
            self.root = root

        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exists", self.devkit_dir)
            self._prepare_devkit()

        self.train = train

        if train:
            self.target_folder = os.path.join(self.root, self.default_train_dir)
        else:
            self.target_folder = os.path.join(self.root, self.default_val_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file",
                self.target_folder,
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            elif is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            elif train:
                self._prepare_train()
            else:
                self._prepare_val()

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        """Default dataset directory: ``_default_dataset_root()`` joined with the class name."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        """Validation labels parsed from the devkit ground-truth file.

        Returns:
            A list with one integer per non-blank line of
            ``<devkit_dir>/data/ILSVRC2012_validation_ground_truth.txt``.

        Raises:
            FileNotFoundError: if the ground-truth file does not exist.
        """
        ground_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if not os.path.exists(ground_truth_path):
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % ground_truth_path
            )
        with open(ground_truth_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in int().
            return [int(line) for line in f if line.strip()]

    @property
    def meta(self):
        """Class metadata of the devkit, cached as ``<devkit_dir>/meta.pkl``.

        The cached file is returned when it can be loaded. Otherwise the
        ``synsets`` table of ``<devkit_dir>/data/meta.mat`` is parsed, the
        result is saved to the cache file and returned.

        Returns:
            ``(idx_to_wnid, wnid_to_classes)``: a dict mapping the synset index
            to its wnid, and a dict mapping each wnid to a tuple of class names.
            On a cache hit, whatever ``load`` returns is passed through
            (presumably the same tuple written by ``save``).

        Raises:
            FileNotFoundError: if neither the cache nor ``meta.mat`` exists.
        """
        cache_path = os.path.join(self.devkit_dir, "meta.pkl")
        try:
            return load(cache_path)
        except FileNotFoundError:
            # Cache miss: rebuild below, outside the handler, so that a missing
            # meta.mat is not reported as a chained exception.
            pass

        # Imported lazily so scipy is only required when the cache must be built.
        import scipy.io

        meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
        if not os.path.exists(meta_path):
            raise FileNotFoundError("meta file %s does not exist" % meta_path)

        meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
        # Column 4 holds the number of children; keep only synsets without children.
        nums_children = list(zip(*meta))[4]
        meta = [
            meta[idx]
            for idx, num_children in enumerate(nums_children)
            if num_children == 0
        ]
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))

        logger.info("saving cached meta file to %s", cache_path)
        save((idx_to_wnid, wnid_to_classes), cache_path)
        return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        """Return ``True`` if every file listed in ``raw_file_meta`` exists under ``root``."""
        return all(
            os.path.exists(os.path.join(self.root, filename))
            for filename, _ in self.raw_file_meta.values()
        )

    def _verified_raw_file(self, key: str, description: str) -> str:
        """Return the path of raw file ``key`` after verifying its MD5 checksum.

        Args:
            key: entry of ``raw_file_meta`` (``"train"``, ``"val"`` or ``"devkit"``).
            description: word used for this file in the log message.

        Raises:
            AssertionError: if the computed MD5 differs from the expected one.
        """
        raw_filename, checksum = self.raw_file_meta[key]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum %s tar file %s ...", description, raw_file)
        if calculate_md5(raw_file) != checksum:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O`` while keeping the exception type callers may expect.
            raise AssertionError(
                "checksum mismatch, {} may be damaged".format(raw_file)
            )
        return raw_file

    def _organize_val_data(self):
        """Move the flat validation images into one sub-directory per wnid.

        Images in ``target_folder`` are sorted by path and paired, in order,
        with the labels from ``valid_ground_truth``.
        """
        idx_to_wnid = self.meta[0]
        val_wnids = [idx_to_wnid[idx] for idx in self.valid_ground_truth]
        val_images = sorted(
            os.path.join(self.target_folder, image)
            for image in os.listdir(self.target_folder)
        )
        if len(val_images) != len(val_wnids):
            # zip() below stops at the shorter sequence; make that visible.
            logger.warning(
                "found %d validation images but %d labels",
                len(val_images),
                len(val_wnids),
            )

        val_dir = os.path.join(self.root, self.default_val_dir)

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            os.makedirs(os.path.join(val_dir, wnid))

        logger.debug("mv val images into wnids dir")
        for wnid, img_file in tqdm(zip(val_wnids, val_images)):
            shutil.move(
                img_file, os.path.join(val_dir, wnid, os.path.basename(img_file))
            )

    def _prepare_val(self):
        """Verify and extract the validation tar, then sort images into wnid folders."""
        assert not self.train
        raw_file = self._verified_raw_file("val", "valid")
        logger.info("extract valid tar file... this may take 10-20 minutes")
        # raw_file already contains self.root; joining again breaks relative roots.
        untar(raw_file, self.target_folder)
        self._organize_val_data()

    def _prepare_train(self):
        """Verify and extract the train tar, then extract every archive it contained."""
        assert self.train
        raw_file = self._verified_raw_file("train", "train")
        logger.info("extract train tar file.. this may take several hours")
        # raw_file already contains self.root; joining again breaks relative roots.
        untar(raw_file, self.target_folder)
        paths = [
            os.path.join(self.target_folder, child_dir)
            for child_dir in os.listdir(self.target_folder)
        ]
        for path in tqdm(paths):
            # Each extracted entry is itself an archive: unpack it into a
            # directory named after it (extension dropped) and remove it.
            untar(path, os.path.splitext(path)[0], remove=True)

    def _prepare_devkit(self):
        """Verify and extract the devkit archive."""
        raw_file = self._verified_raw_file("devkit", "devkit")
        logger.info("extract devkit file..")
        untargz(raw_file)