megengine.functional.distributed.gather

gather(inp, group=WORLD, device=None, axis=0)[源代码]

在指定的组中收集张量。只有根进程才会收到最终结果。

参数
  • inp (Tensor) – 输入张量

  • group (Optional[Group]) – 需要处理的组,默认为包含所有进程的 WORLD 组。你可以使用进程序号来创建新的组并使用,例如 [1,3,5] 。

  • device (Optional[str]) – 执行此操作的设备。默认为输入张量所在的设备。可以通过指定设备为 ”gpu0:1“ 以在不同的 cuda 流上执行此操作,其中1是 cuda 流的编号,默认 cuda 流编号为0。

  • axis – 集合通信结果的拼接维度

实际案例

input = Tensor([rank])
# Rank 0 # input: Tensor([0])
# Rank 1 # input: Tensor([1])
output = gather(input)
# Rank 0 # output: Tensor([0 1])
# Rank 1 # output: None

input = Tensor([rank])
group = Group([1, 0]) # first rank is root
output = gather(input, group)
# Rank 0 # output: None
# Rank 1 # output: Tensor([1 0])
返回类型

Tensor