megengine.functional.cond_take¶
-
cond_take
(mask, x)[源代码]¶ 如果在mask上满足了特定条件,则从数据中取出元素。此算子有两个输出:第一个是取出的元素,第二个是这些元素对应的索引;两个输出都是一维的。高维数据输入时将首先被展平。
- 参数
mask (
Tensor
) – 条件参数;必须与数据的形状相同x (
Tensor
) – 将从其中取出元素的输入张量
- 返回类型
megengine.tensor.Tensor
例如:
import numpy as np from megengine import tensor import megengine.functional as F mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) x = tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) v, index = F.cond_take(mask, x) print(v.numpy(), index.numpy())
输出:
[1. 4.] [0 3]
- 返回类型
Tensor
- 参数
mask (megengine.tensor.Tensor) –
x (megengine.tensor.Tensor) –