megengine.distributed.helper.param_pack_split

param_pack_split(inp, offsets, shapes)[源代码]

按照 offsetsshapes 的描述拆分输入 Tensor,并返回拆分后的 Tensor 列表,仅用于 parampack

参数
  • inp (Tensor) – 输入张量。

  • offsets (list) – 每个输出所在对应输入 Tensor 的偏移值,长度为 2 * nn 是预期输出的个数,格式为 [begin0, end0, begin1, end1]

  • shapes (list) – 拆分后输出 Tensor 的 shape

返回

拆分后的 Tensor。

实际案例

>>> a = F.ones(10)
>>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
>>> b
Tensor([1.], device=xpux:0)
>>> c
Tensor([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], device=xpux:0)