ArrayDataset#

class ArrayDataset(*arrays)[源代码]#

适用于 NumPy ndarray 数据的 Dataset 类。

需要一个或多个 NumPy 数组来初始化数据集,且各数组中表示样本数的维度大小应当一致。

实际案例

import numpy as np

from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import SequentialSampler

# Number of samples in the toy dataset; every array passed to ArrayDataset
# below uses this as the size of its sample dimension.
sample_num = 100

# Random uint8 "images" of shape (sample_num, 1, 32, 32) and one integer
# label in [0, 10) per sample.
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)

# Wrap the arrays in a dataset and read it in order, two samples per batch.
dataset = ArrayDataset(rand_data, label)
seque_sampler = SequentialSampler(dataset, batch_size=2)

dataloader = DataLoader(
    dataset,
    sampler=seque_sampler,
    num_workers=3,
)

# Each iteration yields one batch assembled from the dataset.
for data in dataloader:
    print(data)