megengine.functional.where¶
-
where
(mask, x, y)[源代码]¶ 根据mask选出张量x或张量y中的元素。
\[\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i\]- 参数
mask (
Tensor
) – 用于选择x或y的 mask。x (
Tensor
) – 第一个选择。y (
Tensor
) – 第二个选择。
- 返回类型
Tensor
- 返回
输出张量。
例如:
from megengine import tensor import megengine.functional as F mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool)) x = tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) out = F.where(mask, x, y) print(out.numpy())
输出:
[[1. 6.] [7. 4.]]