megengine.functional.gather

gather(inp, axis, index)[源代码]

根据给定的索引从输入 Tensor 中收集数据。

对于一个3维张量,输出按照如下规则指定:

out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

如果 inp 是一个尺寸为 \((x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})\) 且 axis=i 的 n 维 Tensor 则 index 必须是一个尺寸为 \((x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})\) 的 n 维 Tensor,这里的 \(y\ge 1\) 和输出的尺寸都必须必须与 index 的尺寸相同。

参数
  • inp (Tensor) – 输入张量。

  • axis (int) – 将要进行索引的轴。

  • index (Tensor) – 将要进行聚合的元素的索引

返回类型

Tensor

返回

输出张量。

实际案例

>>> inp = Tensor([
...     [1,2], [3,4], [5,6],
... ])
>>> index = Tensor([[0,2], [1,0]])
>>> F.gather(inp, 0, index)
Tensor([[1 6]
 [3 2]], dtype=int32, device=xpux:0)