使用 Collator 定义合并策略

注解

在使用 DataLoader 获取批数据的整个流程中, Collator 负责合并样本,最终得到批数据。

>>> dataloader = DataLoader(dataset, collator=...)

通常用户不用实现自己的 Collator, 使用默认合并策略就可以处理大部分批数据合并的情况。 但遇到一些默认合并策略难以处理的情景时,用户可以使用自己实现的 Collator. 参考 自定义 Collator.

警告

Collator 仅适用于 Map-style 的数据集,因为 Iterable-style 数据集的批数据必然是逐个合并的。

默认合并策略

经过前面的处理流程后, Collator 通常会接收到一个列表:

  • 如果你的 Dataset 子类的 __getitem__ 方法返回的是单个元素,则 Collator 得到一个普通列表;

  • 如果你的 Dataset 子类的 __getitem__ 方法返回的是一个元组,则 Collator 得到一个元组列表。

MegEngine 中使用 Collator 作为默认实现,通过调用 apply 方法来将列表数据合并成批数据:

>>> from megengine.data import Collator
>>> collator = Collator()

其实现逻辑中使用 numpy.stack 函数来将列表中包含的所有样例在第一个维度( axis=0 )合并。

参见

MegEngine 中也提供了类似的 stack 函数,不过它仅适用于 Tensor 数据。

警告

默认的 Collator 支持 NumPy ndarrays, Numbers, Unicode strings, bytes, dicts 或 lists 数据类型。 要求输入必须包含至少一种上述数据类型,否则用户需要使用自己定义的 Collator.

Collator 效果示范

如果此时每个样本是形状为 \((C, H, W)\) 的图片 image, 且在 Sampler 中指定了 batch_size\(N\). 那么 Collator 的主要目的就是将获得的该样本列表合并成一个形状为 \((N, C, H, W)\) 的批样本结构。

我们可以模拟得到这样一个 image_list 数据,并借助 Collator 得到 batch_image:

>>> N, C, H, W = 5, 3, 32, 32
>>> image_list = []
>>> for i in range(N):
...     image_list.append(np.random.random((C, H, W)))
>>> print(len(image_list), image_list[0].shape)
5 (3, 32, 32)
>>> batch_image = collator.apply(image_list)
>>> batch_image.shape
(5, 3, 32, 32)

如果样本带有标签,则 Collator 就需要将由 (image, label) 元组构成的列表合并, 形成一个大的 (batch_image, bacth_label) 元组。这也是我们对 DataLoader 进行迭代时通常会获得的东西。

在下面的示例代码中,sample_list 中每个元素都是一个元组(假设所有的标签都用整型 1 来表示):

>>> sample_list = []
>>> for i in range(N):
...     sample_list.append((np.random.random((C, H, W)), 1))
>>> type(sample_list[0])
tuple
>>> print(sample_list[0][0].shape, type(sample_list[0][1]))
(3, 32, 32) <class 'int'>

MegEngine 提供的默认 Collator 也能够很好地处理这种情况:

>>> batch_image, batch_label = collator.apply(sample_list)
>>> print(batch_image.shape, batch_label.shape)
(5, 3, 32, 32) (5,)

警告

需要注意的是,此时 batch_label 已经被转换成了 ndarray 数据结构。

自定义 Collator

当默认的 stack 合并策略无法满足我们的需求时,我们则需要考虑自定义 Collator:

  • 需要继承 Collator 类,并在子类中实现 apply 方法;

  • 我们实现的 apply 方法将被 DataLoader 调用。