megengine.data.dataset.meta_dataset 源代码

# -*- coding: utf-8 -*-
import bisect
from abc import ABC, abstractmethod
from typing import Tuple


[文档]class Dataset(ABC): r"""An abstract base class for all map-style datasets. .. admonition:: Abstract methods All subclasses should overwrite these two methods: * ``__getitem__()``: fetch a data sample for a given key. * ``__len__()``: return the size of the dataset. They play roles in the data pipeline, see the description below. .. admonition:: Dataset in the Data Pipline Usually a dataset works with :class:`~.DataLoader`, :class:`~.Sampler`, :class:`~.Collator` and other components. For example, the sampler generates **indexes** of batches in advance according to the size of the dataset (calling ``__len__``), When dataloader need to yield a batch of data, pass indexes into the ``__getitem__`` method, then collate them to a batch. * Highly recommended reading :ref:`dataset-guide` for more details; * It might helpful to read the implementation of :class:`~.MNIST`, :class:`~.CIFAR10` and other existed subclass. .. warning:: By default, all elements in a dataset would be :class:`numpy.ndarray`. It means that if you want to do Tensor operations, it's better to do the conversion explicitly, such as: .. code-block:: python dataset = MyCustomDataset() # A subclass of Dataset data, label = MyCustomDataset[0] # equals to MyCustomDataset.__getitem__[0] data = Tensor(data, dtype="float32") # convert to MegEngine Tensor explicitly megengine.functional.ops(data) Tensor ops on ndarray directly are undefined behaviors. """ @abstractmethod def __init__(self): pass @abstractmethod def __getitem__(self, index): pass @abstractmethod def __len__(self): pass
[文档]class StreamDataset(Dataset): r"""All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream. All subclasses should overwrite __iter__(), which would return an iterator of samples in this dataset. Returns: Dataset: An iterable Dataset. Examples: .. code-block:: python from megengine.data.dataset import StreamDataset from megengine.data.dataloader import DataLoader, get_worker_info from megengine.data.sampler import StreamSampler class MyStream(StreamDataset): def __init__(self): self.data = [iter([1, 2, 3]), iter([4, 5, 6]), iter([7, 8, 9])] def __iter__(self): worker_info = get_worker_info() data_iter = self.data[worker_info.idx] while True: yield next(data_iter) dataloader = DataLoader( dataset = MyStream(), sampler = StreamSampler(batch_size=2), num_workers=3, parallel_stream = True, ) for step, data in enumerate(dataloader): print(data) """ @abstractmethod def __init__(self): pass @abstractmethod def __iter__(self): pass def __getitem__(self, idx): raise AssertionError("can not get item from StreamDataset by index") def __len__(self): raise AssertionError("StreamDataset does not have length")
[文档]class ArrayDataset(Dataset): r"""ArrayDataset is a dataset for numpy array data. One or more numpy arrays are needed to initiate the dataset. And the dimensions represented sample number are expected to be the same. Args: Arrays(dataset and labels): the datas and labels to be returned iteratively. Returns: Tuple: A set of raw data and corresponding label. Examples: .. code-block:: python from megengine.data.dataset import ArrayDataset from megengine.data.dataloader import DataLoader from megengine.data.sampler import SequentialSampler rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8) label = np.random.randint(0, 10, size=(sample_num,), dtype=int) dataset = ArrayDataset(rand_data, label) seque_sampler = SequentialSampler(dataset, batch_size=2) dataloader = DataLoader( dataset, sampler = seque_sampler, num_workers=3, ) for step, data in enumerate(dataloader): print(data) """ def __init__(self, *arrays): super().__init__() if not all(len(arrays[0]) == len(array) for array in arrays): raise ValueError("lengths of input arrays are inconsistent") self.arrays = arrays def __getitem__(self, index: int) -> Tuple: return tuple(array[index] for array in self.arrays) def __len__(self) -> int: return len(self.arrays[0])
[文档]class ConcatDataset(Dataset): r"""ConcatDataset is a concatenation of multiple datasets. This dataset is used for assembleing multiple map-style datasets. Args: datasets(list of Dataset): list of datasets to be composed. Returns: Dataset: A Dataset which composes fields of multiple datasets. Examples: .. code-block:: python from megengine.data.dataset import ArrayDataset, ConcatDataset data1 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8) data2 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8) label = np.random.randint(0, 10, size=(2,), dtype=int) labe2 = np.random.randint(0, 10, size=(2,), dtype=int) dataset1 = ArrayDataset(data1, label1) dataset2 = ArrayDataset(data2, label2) dataset = ConcatDataset([dataset1, dataset2]) seque_sampler = SequentialSampler(dataset, batch_size=2) dataloader = DataLoader( dataset, sampler = seque_sampler, num_workers=3, ) for step, data in enumerate(dataloader): print(data) """ def __init__(self, datasets): super(ConcatDataset, self).__init__() self.datasets = datasets def cumsum(datasets): r, s = [], 0 for e in datasets: l = len(e) r.append(l + s) s += l return r assert len(self.datasets) > 0, "datasets should not be an empty iterable" for d in self.datasets: assert not isinstance( d, StreamDataset ), "ConcatDataset does not support StreamDataset" self.datasets = list(datasets) self.cumulative_sizes = cumsum(self.datasets) def __getitem__(self, idx): if idx < 0: if -idx > len(self): raise ValueError( "absolute value of index should not exceed dataset length" ) idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.datasets[dataset_idx][sample_idx] def __len__(self): return self.cumulative_sizes[-1]