Source code for megengine.core.tensor.array_method

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from functools import lru_cache
from typing import Union

import numpy as np

from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import (
    SymbolVar,
    Tensor,
    apply,
    astype_cpp,
    broadcast_cpp,
    dtype_promotion,
    getitem_cpp,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph

_ElwMod = builtin.Elemwise.Mode


def _elwise_apply(args, mode):
    op = builtin.Elemwise(mode)
    (result,) = apply(op, *args)
    return result


def _elwise(*args, mode):
    return _elwise_apply(args, mode)


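# Builds (and caches, via ``lru_cache``) a fused subgraph that generalizes
# MatrixMul to inputs of arbitrary rank: 1-D inputs are temporarily promoted
# to 2-D, inputs with more than two dimensions are flattened to 2-D for the
# multiply, and the result is reshaped back (with any promoted axis removed).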
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(builtin.GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp


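# Builds (and caches) a fused subgraph around BatchedMatrixMul: 1-D inputs are
# promoted to 2-D, the lower-rank operand is broadcast against the other's
# leading batch dimensions, both operands are compressed to 3-D for the
# batched multiply, and the result is reshaped back afterwards.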
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)

        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)

        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(builtin.GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp


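# Wraps an arbitrary value so it can be used as a cache key for the
# ``lru_cache``-decorated factories above: hashing and equality are based on
# ``str(value)``, and the wrapped value is recovered through ``.value``.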
class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value


def _matmul(
    inp1,
    inp2,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
):
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)

    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

    Strategy = builtin.ops.MatrixMul.Strategy
    strategy = Strategy(0)
    if _config._benchmark_kernel:
        strategy |= Strategy.PROFILE
    else:
        strategy |= Strategy.HEURISTIC
    if _config._deterministic_kernel:
        strategy |= Strategy.REPRODUCIBLE

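    # dispatch summary (illustrative shapes):
    #   (3,)      @ (3,)       -> builtin.Dot
    #   (4, 3, 2) @ (2, 5)     -> extentedMatrixMulOp   (maxdim <= 2 or dim2 <= 2)
    #   (4, 3, 2) @ (4, 2, 5)  -> extentedBatchedMatrixMulOp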
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        (result,) = apply(builtin.Dot(), inp1, inp2)
        return result
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result


def _unary_elwise(mode):
    def f(self):
        return _elwise(self, mode=mode)

    return f


def _binary_elwise(mode, rev=False):
    if not rev:

        def f(self, value):
            return _elwise(self, value, mode=mode)

    else:

        def f(self, value):
            return _elwise(value, self, mode=mode)

    return f


def _logical_unary_elwise(mode, rev=False):
    def f(self):
        if self.dtype != np.bool_:
            raise TypeError("{} requires a bool tensor".format(mode))
        return _elwise(self, mode=mode)

    return f


def _logical_binary_elwise(mode, rev=False):
    if not rev:

        def f(self, value):
            if self.dtype != np.bool_ or value.dtype != np.bool_:
                raise TypeError("{} requires 2 bool tensors".format(mode))
            return _elwise(self, value, mode=mode)

    else:

        def f(self, value):
            if self.dtype != np.bool_ or value.dtype != np.bool_:
                raise TypeError("{} requires 2 bool tensors".format(mode))
            return _elwise(value, self, mode=mode)

    return f


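# Factory for the reduction methods (sum/prod/min/max/mean below). A sequence
# of axes is handled by reducing one axis at a time; the axes are normalized
# with ``reverse=True`` so that reducing an axis does not shift the indices of
# the axes still to be reduced when ``keepdims`` is ``False``.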
def _reduce(mode):
    def f(self, axis=None, keepdims: bool = False):
        data = self
        if axis is None:
            assert not keepdims, "cannot set axis=None and keepdims=True"
            (result,) = apply(builtin.Reduce(mode=mode), data)
        elif isinstance(axis, collections.abc.Iterable):
            axis = _normalize_axis(self.ndim, axis, reverse=True)
            for ai in axis:
                op = builtin.Reduce(mode=mode, axis=ai, keepdim=keepdims)
                (data,) = apply(op, data)
            result = data
        else:
            # builtin.Reduce already accepts negative axis
            op = builtin.Reduce(mode=mode, axis=axis, keepdim=keepdims)
            (result,) = apply(op, data)

        return result

    return f


def _inplace(f):
    def g(self, value):
        result = f(self, value)
        if result is NotImplemented:
            raise NotImplementedError
        self._reset(result)
        return self

    return g


def _todo(*_):
    raise NotImplementedError


def _expand_args(args):
    if len(args) == 1:
        if isinstance(
            args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
        ):
            args = args[0]
    return args


class ArrayMethodMixin(abc.ABC):
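    """Mixin implementing the numpy-style array interface shared by tensor
    types: numpy interop, Python operator overloads, indexing, iteration and
    the common reductions. Subclasses must provide ``_reset``, ``dtype``,
    ``shape``, ``_tuple_shape`` and ``numpy``.
    """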

    # numpy interop: __array_priority__ gives Tensor precedence in mixed
    # numpy/Tensor operations, and __array__ enables conversion to numpy arrays
    __array_priority__ = 1001

    def __array__(self, dtype=None):
        if dtype is None:
            return self.numpy()
        return self.numpy().astype(dtype)

    def __array_wrap__(self, array):
        Wrapper = type(self)
        return Wrapper(array, dtype=array.dtype, device=self.device)

    @abc.abstractmethod
    def _reset(self, other):
        pass

    @abc.abstractproperty
    def dtype(self) -> np.dtype:
        pass

    @abc.abstractproperty
    def shape(self) -> Union[tuple, Tensor]:
        pass

    @abc.abstractproperty
    def _tuple_shape(self) -> tuple:
        pass

    @abc.abstractmethod
    def numpy(self) -> np.ndarray:
        pass

    __hash__ = None  # disabled because __eq__ deviates from the Python convention (it is elementwise)

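    # comparisons are elementwise and yield bool tensors; __gt__/__ge__ reuse
    # the LT/LEQ modes with swapped operands, and __ne__ is the logical NOT of __eq__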
    __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool")
    __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool")
    __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool")
    __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool")
    __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool")
    __ne__ = lambda self, value: _elwise(
        _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT,
    )

    __neg__ = _unary_elwise(_ElwMod.NEGATE)
    __pos__ = lambda self: self
    __abs__ = _unary_elwise(_ElwMod.ABS)
    __invert__ = _logical_unary_elwise(_ElwMod.NOT)
    __round__ = _unary_elwise(_ElwMod.ROUND)
    __trunc__ = _todo
    __floor__ = _unary_elwise(_ElwMod.FLOOR)
    __ceil__ = _unary_elwise(_ElwMod.CEIL)

    __add__ = _binary_elwise(_ElwMod.ADD)
    __sub__ = _binary_elwise(_ElwMod.SUB)
    __mul__ = _binary_elwise(_ElwMod.MUL)
    __matmul__ = lambda self, other: _matmul(self, other)
    __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV)
    __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV)
    __mod__ = _binary_elwise(_ElwMod.MOD)
    # __divmod__
    __pow__ = _binary_elwise(_ElwMod.POW)
    __lshift__ = _binary_elwise(_ElwMod.SHL)
    __rshift__ = _binary_elwise(_ElwMod.SHR)
    __and__ = _logical_binary_elwise(_ElwMod.AND)
    __or__ = _logical_binary_elwise(_ElwMod.OR)
    __xor__ = _logical_binary_elwise(_ElwMod.XOR)

    __radd__ = _binary_elwise(_ElwMod.ADD, rev=1)
    __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1)
    __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1)
    __rmatmul__ = lambda self, other: _matmul(other, self)
    __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1)
    __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1)
    __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1)
    # __rdivmod__
    __rpow__ = _binary_elwise(_ElwMod.POW, rev=1)
    __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1)
    __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1)
    __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1)
    __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1)
    __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1)

    __iadd__ = _inplace(__add__)
    __isub__ = _inplace(__sub__)
    __imul__ = _inplace(__mul__)
    __imatmul__ = _inplace(__matmul__)
    __itruediv__ = _inplace(__truediv__)
    __ifloordiv__ = _inplace(__floordiv__)
    __imod__ = _inplace(__mod__)
    __ipow__ = _inplace(__pow__)
    __ilshift__ = _inplace(__lshift__)
    __irshift__ = _inplace(__rshift__)
    __iand__ = _inplace(__and__)
    __ior__ = _inplace(__or__)
    __ixor__ = _inplace(__xor__)

    __index__ = lambda self: self.item().__index__()
    __bool__ = lambda self: bool(self.item())
    __int__ = lambda self: int(self.item())
    __float__ = lambda self: float(self.item())
    __complex__ = lambda self: complex(self.item())

    def __len__(self):
        shape = self._tuple_shape
        if shape:
            return int(shape[0])
        raise TypeError("ndim is 0")

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return getitem_cpp(self, index)

    def __setitem__(self, index, value):
        if index is not Ellipsis:
            value = setitem_cpp(self, index, value)
        self._reset(value)

    __contains__ = _todo

    @property
    def ndim(self):
        r"""Returns the number of dimensions of self :class:`~.Tensor`."""
        shape = self._tuple_shape
        if shape is None:
            raise ValueError("unkown ndim")
        return len(shape)

    @property
    def size(self):
        r"""Returns the size of the self :class:`~.Tensor`.
        The returned value is a subclass of :class:`tuple`.
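
        Examples:
            For example, for a ``(2, 3)`` tensor:

            .. testcode::

               from megengine import tensor
               x = tensor([[1, 2, 3], [4, 5, 6]])
               print(x.size)

            Outputs:

            .. testoutput::

               6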
        """
        shape = self.shape
        if shape.__class__ is tuple:
            return np.prod(self.shape).item()
        return shape.prod()

    @property
    def T(self):
        r"""alias of :attr:`~.Tensor.transpose`."""
        return self.transpose()

    def item(self, *args):
        r"""Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
        This only works for tensors with one element. For other cases, see :meth:`~.tolist`.
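
        Examples:
            For example, for a tensor holding a single element:

            .. testcode::

               from megengine import tensor
               x = tensor([[5.0]])
               print(x.item())

            Outputs:

            .. testoutput::

               5.0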
        """
        if not args:
            if isinstance(self.size, int):
                assert self.size == 1
            return self.numpy().item()
        return self[args].item()

    def tolist(self):
        r"""Returns the tensor as a (nested) list.
        For scalars, a standard Python number is returned, just like with :meth:`~.item`.
        Tensors are automatically moved to the CPU first if necessary.

        This operation is not differentiable.
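
        Examples:
            For example:

            .. testcode::

               from megengine import tensor
               x = tensor([[1, 2], [3, 4]])
               print(x.tolist())

            Outputs:

            .. testoutput::

               [[1, 2], [3, 4]]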
        """
        return self.numpy().tolist()

    def astype(self, dtype):
        r"""Returns a :class:`Tensor` with the same data and number of elements
        with the specified :attr:`~.Tensor.dtype`.
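
        Examples:
            For example:

            .. testcode::

               from megengine import tensor
               x = tensor([1, 2, 3])
               print(x.astype("float32").dtype)

            Outputs:

            .. testoutput::

               float32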
        """
        return astype_cpp(self, dtype)

    def reshape(self, *args):
        r"""See :func:`~.reshape`."""
        return reshape_cpp(self, args)

    # FIXME: remove this method
    def _broadcast(self, *args):
        return broadcast_cpp(self, args)

    def transpose(self, *args):
        r"""See :func:`~.transpose`."""
        return transpose_cpp(self, args)

    def flatten(self):
        r"""See :func:`~.flatten`."""
        return reshape_cpp(self, (-1,))

    def sum(self, axis=None, keepdims: bool = False):
        r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.

        If ``axis`` is a list of axes, reduce over all of them.
        If ``keepdims`` is ``True``, the shape of the output tensor is the same as that of the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.sum().numpy())
               print(b.sum().numpy())

            Outputs:

            .. testoutput::

               2
               10.0
        """
        return _reduce("sum")(self, axis, keepdims)

    def prod(self, axis=None, keepdims: bool = False):
        r"""Returns the product of each row of the input tensor in the given dimension ``axis``.

        If ``axis`` is a list of axes, reduce over all of them.
        If ``keepdims`` is ``True``, the shape of the output tensor is the same as that of the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.prod().numpy())
               print(b.prod().numpy())

            Outputs:

            .. testoutput::

               0
               24.0
        """
        return _reduce("product")(self, axis, keepdims)

    def min(self, axis=None, keepdims: bool = False):
        r"""Returns the min value of each row of the input tensor in the given dimension ``axis``.

        If ``axis`` is a list of axes, reduce over all of them.
        If ``keepdims`` is ``True``, the shape of the output tensor is the same as that of the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.min().numpy())
               print(b.min().numpy())

            Outputs:

            .. testoutput::

               False
               1.0
        """
        return _reduce("min")(self, axis, keepdims)

    def max(self, axis=None, keepdims: bool = False):
        r"""Returns the max value of each row of the input tensor in the given dimension ``axis``.

        If ``axis`` is a list of axes, reduce over all of them.
        If ``keepdims`` is ``True``, the shape of the output tensor is the same as that of the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.max().numpy())
               print(b.max().numpy())

            Outputs:

            .. testoutput::

               True
               4.0
        """
        return _reduce("max")(self, axis, keepdims)

    def mean(self, axis=None, keepdims: bool = False):
        r"""Returns the mean value of each row of the input tensor in the given dimension ``axis``.

        If ``axis`` is a list of axes, reduce over all of them.
        If ``keepdims`` is ``True``, the shape of the output tensor is the same as that of the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.mean().numpy())
               print(b.mean().numpy())

            Outputs:

            .. testoutput::

               0.5
               2.5
        """
        return _reduce("mean")(self, axis, keepdims)