Dataset

class Dataset[源代码]

所有映射式(map-style)数据集的抽象基类

抽象方法

所有子类都应该重写这两个方法:

  • __getitem__(): 获取给定索引对应的数据样本

  • __len__(): 返回数据集的大小

他们是在数据队列中发挥作用,详见下面的说明。

Dataset in the Data Pipline

通常加载数据集会使用 DataLoaderSamplerCollator 以及其他组件

例如,采样器会根据数据集的大小(调用 __len__) 提前生成batch数据的**索引**。当dataloader需要返回一个batch的数据时,其会将索引传递给``__getitem__`` 方法,最后将数据处理为一个batch。

警告

默认情况下,数据集中所有的元素都是 numpy.ndarray。这耶意味着如果你想做张量运算,最好是显示的进行转换,例如:

dataset = MyCustomDataset()  # A subclass of Dataset
data, label = MyCustomDataset[0]  # equals to MyCustomDataset.__getitem__[0]
data = Tensor(data, dtype="float32")  # convert to MegEngine Tensor explicitly

megengine.functional.ops(data)

在ndarray上进行张量运算是未定义的行为