megengine.distributed.helper.param_pack_concat#

param_pack_concat(inps, offsets, offsets_val)[源代码]#

返回拼接后的 Tensor,仅用于 ParamPack。

参数:
  • inps (list) – 由输入张量组成的列表。

  • offsets (Tensor) – 每个输入的 Tensor 在输出 Tensor 中的偏移量,长度为 2 * n,其中 n 为输入 Tensor 的个数,格式为 [begin0, end0, begin1, end1],且该值需要是 Tensor 类型。

  • offsets_val (list) – 每个输入的 Tensor 在输出 Tensor 中的偏移量,长度为 2 * n,其中 n 为输入 Tensor 的个数,格式为 [begin0, end0, begin1, end1]

返回:

拼接后的 Tensor。

实际案例

>>> a = F.ones(1)
>>> b = F.ones((3, 3))
>>> offsets_val = [0, 1, 1, 10]
>>> offsets = Tensor(offsets_val)
>>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val)  
Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)