# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
import copy
import inspect
import re
from typing import Callable, Dict, List
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import get_opdef_state, load_opdef_from_state
def rstrip(s: str, __chars: str):
    r"""Remove every trailing repetition of the substring ``__chars`` from ``s``.

    Unlike :meth:`str.rstrip`, which strips a *set* of characters, this strips a
    whole substring, repeatedly: ``rstrip("conv_out_out", "_out")`` is ``"conv"``.

    Args:
        s: the string to strip.
        __chars: the substring to remove; it is matched literally (regex-escaped).

    Returns:
        ``s`` without its trailing run of ``__chars``.
    """
    __chars = re.escape(__chars)
    # The replacement template must be a raw string: "\g" is not a valid escape
    # sequence in a regular string literal and triggers a Deprecation/SyntaxWarning.
    s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, r"\g<left>", s)
    return s
class Expr:
    r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
    ``GetAttr``, ``Input``, ``Constant``) on ``Node``.

    Every Expr receives an ``_id`` taken from a counter shared by all subclasses;
    the counter is incremented on each construction (see :meth:`_get_next_id` and
    :meth:`_set_next_id`).
    """

    inputs = None  # type: List[Node]
    r"""The input Nodes of this Expr."""
    outputs = None  # type: List[Node]
    r"""The output Nodes of this Expr."""
    const_val = None  # type: List[Any]
    r"""The non-tensor object in the input of the operation."""
    arg_def = None  # type: TreeDef
    r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
    out_def = None  # type: TreeDef
    r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
    _top_graph = None  # type: weakref.ReferenceType
    # Id counter shared by every Expr subclass; always accessed through ``Expr``.
    __total_id = 0

    def __init__(self) -> None:
        self._id = Expr.__total_id
        Expr.__total_id += 1
        self._disable_remove = False

    def enable_remove(self):
        r"""Clear the ``_disable_remove`` flag of this Expr."""
        # NOTE(review): presumably consulted by the graph before removing this Expr.
        self._disable_remove = False

    def disable_remove(self):
        r"""Set the ``_disable_remove`` flag of this Expr."""
        self._disable_remove = True

    def add_inputs(self, vals):
        r"""Register ``vals`` as inputs of this Expr.

        Values bound to a :class:`TensorNode` or :class:`ModuleNode` are appended to
        :attr:`inputs`, and this Expr is added to that node's ``users``. Any other
        value must be a constant leaf; it is stored in :attr:`const_val` together
        with its position among all arguments.

        Args:
            vals: a single value or a sequence of values.
        """
        if not isinstance(vals, collections.abc.Sequence):
            vals = (vals,)
        for val in vals:
            node = NodeMixin.get(val, None)
            if isinstance(node, (TensorNode, ModuleNode)):
                self.inputs.append(node)
                node.users.append(self)
            else:
                assert node is None
                assert _is_leaf(val) and _is_const_leaf(val)
                # Position among *all* arguments, so unflatten_args can re-insert it.
                idx = len(self.inputs) + len(self.const_val)
                self.const_val.append((idx, val))

    def add_outputs(self, outputs):
        r"""Create output Nodes for ``outputs`` and bind each tensor to its Node.

        Each output gets a unique name derived from the operation: the called
        module's name (for ``__call__``), the method name, the function name or the
        opdef name, suffixed with ``_out``.

        Args:
            outputs: ``None``, a single tensor, or a sequence of tensors.
        """
        self.outputs = []
        if outputs is None:
            return
        # ``collections.Sequence`` was removed in Python 3.10; use the abc module.
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)
        name = None
        orig_name = None
        if isinstance(self, CallMethod):
            name = self.inputs[0]._name
            orig_name = self.inputs[0]._orig_name
            assert isinstance(name, str), "The name of ({}) must be a str".format(
                self.inputs[0]
            )
            assert isinstance(
                orig_name, str
            ), "The orig_name of ({}) must be a str".format(self.inputs[0])
            # Avoid names like "conv_out_out" when a module output is called again.
            name = rstrip(name, "_out")
            if self.method == "__call__":
                name += "_out"
                orig_name += "_out"
            else:
                strip_method = self.method.strip("_")
                name = "%s_out" % strip_method
                orig_name = name
        elif isinstance(self, CallFunction):
            name = self.func.__name__ + "_out"
        elif isinstance(self, Apply):
            name = str(self.opdef).lower() + "_out"
        for i in outputs:
            assert isinstance(i, RawTensor), "The output must be a Tensor"
            o_name = active_module_tracer().current_scope()._create_unique_name(name)
            self.outputs.append(
                NodeMixin.get_wrapped_type(i)(
                    expr=self,
                    name=o_name,
                    orig_name=orig_name if orig_name else o_name,
                )
            )
        for i, node in zip(outputs, self.outputs):
            NodeMixin.wrap_safe(i, node)

    def unflatten_args(self, inputs):
        r"""Rebuild ``(args, kwargs)`` of the operation from the flat ``inputs``.

        The constants in :attr:`const_val` are re-inserted at their recorded
        positions before :attr:`arg_def` unflattens the list. Without an
        ``arg_def``, ``inputs`` is returned unchanged as ``args`` with empty kwargs.
        """
        if self.arg_def is not None:
            inputs = list(inputs)
            for idx, val in self.const_val:
                inputs.insert(idx, val)
            args, kwargs = self.arg_def.unflatten(inputs)
            return args, kwargs
        else:
            return inputs, {}

    def replace_inputs(self, repl_dict: "Dict[Node, Node]"):
        r"""Replace the input Nodes of this Expr.

        Args:
            repl_dict: the map {old_Node: new_Node} that specifies how to replace the
                input Nodes. The mapping itself is not modified.
        """
        graph = self.top_graph
        # Process pairs in reverse insertion order (the order ``popitem`` would
        # yield) without consuming the caller's dict.
        for node, repl_node in reversed(list(repl_dict.items())):
            assert type(node) == type(repl_node)
            assert node in self.inputs, "({}) is not in the ({})".format(node, self)
            assert (
                repl_node.top_graph == node.top_graph
            ), "({}) and ({}) are not in the same graph".format(node, repl_node)
            repl_expr_idx = graph._exprs.index(repl_node.expr)
            self_idx = graph._exprs.index(self)
            # The replacement must already exist when this Expr is executed.
            assert (
                repl_expr_idx < self_idx
            ), "({}) must be generated before ({})".format(repl_node, self)
            idx = self.inputs.index(node)
            self.inputs[idx] = repl_node
            node.users.remove(self)
            repl_node.users.append(self)

    @property
    def kwargs(self):
        r"""Get the keyword arguments of the operation corresponding to this Expr."""
        _, kwargs = self.unflatten_args(self.inputs)
        return kwargs

    @property
    def args(self):
        r"""Get the positional arguments of the operation corresponding to this Expr."""
        args, _ = self.unflatten_args(self.inputs)
        return args

    @property
    def top_graph(self):
        r"""Get the parent graph of this Expr, or ``None`` if it has none."""
        if self._top_graph:
            return self._top_graph()
        return None

    def __getstate__(self):
        # The parent-graph reference is not part of the serialized state.
        state = self.__dict__.copy()
        if "_top_graph" in state:
            state.pop("_top_graph")
        return state

    @classmethod
    def _get_next_id(cls):
        r"""Return the id the next constructed Expr will receive."""
        # Read from ``Expr`` (as ``__init__`` does), never a subclass shadow.
        return Expr.__total_id

    @classmethod
    def _set_next_id(cls, id: int = 0):
        r"""Set the id the next constructed Expr will receive."""
        assert isinstance(id, int)
        # Write to ``Expr`` so calling this on a subclass does not create a
        # subclass attribute that ``__init__`` would never read.
        Expr.__total_id = id
# expr: None (i.e. fake expression which is used to mark input)
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    r"""``GetAttr`` represents the fetch of an attribute from the ``Module`` hierarchy.

    Args:
        module: the :class:`ModuleNode` whose attribute is fetched.
        name: the attribute name.
        type: the Node class of the output; :class:`Node` when not given.
        orig_name: the original name recorded on the output Node.
    """

    name = None
    r"""name: the qualified name of the attribute to be retrieved."""

    def __init__(self, module, name, type=None, orig_name=None):
        super().__init__()
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]
        module.users.append(self)
        self.name = name
        node_cls = type if type else Node
        self.outputs = [
            node_cls(self, name=name, orig_name=orig_name),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``GetAttr`` Expr, insert it into the current tracing scope and
        return its output Node.

        The output Node is renamed to a unique name made of the names of the
        enclosing module nodes (up to the one named ``"self"``) and the attribute
        name, joined with ``_``.
        """
        expr = cls(*args, **kwargs)
        module = expr.inputs[0]
        oup_name = expr.name
        # Walk up through ``module.expr.inputs[0]`` until the root module node,
        # prefixing each module's name.
        while module._name != "self":
            oup_name = module._name + "_" + oup_name
            module = module.expr.inputs[0]
        scope = active_module_tracer().current_scope()
        expr.outputs[0]._name = scope._create_unique_name(oup_name)
        scope._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        r"""Return a 1-tuple holding ``getattr(inputs[0], self.name)``."""
        return (getattr(inputs[0], self.name),)

    def __repr__(self):
        out_type = "Tensor"
        if isinstance(self.outputs[0], ModuleNode):
            out_type = self.outputs[0].module_type.__name__
        return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
            self._id, self.outputs[0], self.inputs[0], self.name, out_type
        )
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
    r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.

    Args:
        node: the Node to be called, or a ``Tensor``/``Parameter`` class (then the
            class itself is recorded as the constant first argument).
        method: the method name.
            Default: "__call__"
    """

    def __init__(self, node, method="__call__"):
        super().__init__()
        if isinstance(node, type):
            assert issubclass(node, Tensor)
            cls = Parameter if issubclass(node, Parameter) else Tensor
            self.inputs = []
            # The class is not a Node, so it is stored as constant argument 0.
            self.const_val = [(0, cls)]
        else:
            assert isinstance(node, (TensorNode, ModuleNode))
            node.users.append(self)
            self.inputs = [
                node,
            ]
            self.const_val = []
        self.method = method

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``CallMethod`` Expr and insert it into the current tracing scope."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    @property
    def graph(self):
        r"""The graph recorded for this call on the called ModuleNode's owner, if any.

        Returns ``None`` when the callee is not a :class:`ModuleNode` or its owner
        has no (or an empty) ``argdef_graph_map``.
        """
        if isinstance(self.inputs[0], ModuleNode):
            m_node = self.inputs[0]
            if (
                hasattr(m_node.owner, "argdef_graph_map")
                and m_node.owner.argdef_graph_map
            ):
                assert self.arg_def in m_node.owner.argdef_graph_map
                return m_node.owner.argdef_graph_map[self.arg_def]
        return None

    def interpret(self, *inputs):
        r"""Execute the method call on concrete ``inputs``.

        Returns:
            ``None`` if the call returned ``None``; otherwise the flattened list of
            results (for ``__setitem__``, the flattened mutated object).
        """
        args, kwargs = self.unflatten_args(inputs)
        obj = args[0]
        meth = getattr(obj, self.method)
        # A bound method already carries ``obj``; otherwise ``obj`` stays the
        # first positional argument.
        if inspect.ismethod(meth):
            args = args[1:]
        outputs = meth(*args, **kwargs)
        if self.method == "__setitem__":
            outputs = obj
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def __repr__(self):
        call_args = self.args
        args = ", ".join(str(i) for i in call_args[1:])
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        method = ".%s" % self.method
        if method == ".__call__":
            method = ""
        # Join only non-empty parts so no dangling ", " appears in the call.
        arg_str = ", ".join(part for part in (args, kwargs) if part)
        return "%{}:\t{}{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            call_args[0],
            method,
            arg_str,
        )
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    r"""``Apply`` represents a call to :func:`apply`.

    Args:
        opdef: the applied :class:`OpDef`.
    """

    opdef = None

    def __init__(self, opdef):
        super().__init__()
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create an ``Apply`` Expr and insert it into the current tracing scope."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Run ``apply(self.opdef, *inputs)`` on concrete inputs."""
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "%{}:\t{} = {}({})".format(
            self._id,
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    def __getstate__(self):
        # Replace the OpDef with its serializable state.
        state = super().__getstate__()
        state["opdef"] = get_opdef_state(state["opdef"])
        return state

    def __setstate__(self, state):
        # Rebuild the OpDef from the state produced by __getstate__.
        state["opdef"] = load_opdef_from_state(state["opdef"])
        for k, v in state.items():
            setattr(self, k, v)

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        r"""Record an ``Apply`` Expr for ``apply(opdef, *inputs)`` while tracing.

        Inputs that are not yet bound to a Node are first captured as
        :class:`Constant` Exprs. The op is then executed with module tracing
        temporarily disabled, and its outputs are bound to new Nodes.

        Returns:
            the list of output tensors produced by ``apply``.
        """
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        if isinstance(opdef, FakeQuant):
            # Only the first input comes from its existing Node; each remaining
            # input gets a freshly created Constant Node.
            inp_nodes = [NodeMixin.get(inputs[0])]
            for i in inputs[1:]:
                node = Constant.make(i)
                inp_nodes.append(node)
            apply_node = cls.make(opdef)
            for n in inp_nodes:
                n.users.append(apply_node)
            apply_node.inputs = inp_nodes
        else:
            apply_node = cls.make(opdef)
            apply_node.add_inputs(inputs)
        assert not apply_node.const_val
        # Execute the real op without recording it again, and always restore
        # module tracing even if ``apply`` raises.
        unset_module_tracing()
        try:
            outputs = apply(opdef, *inputs)
        finally:
            set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
class CallFunction(Expr):
    r"""``CallFunction`` represents a call to a built-in function.

    Args:
        func: a built-in function.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, Callable)
        self.func = func
        self.const_val = []
        self.inputs = []

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``CallFunction`` Expr and insert it into the current tracing scope."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope()._insert(expr)
        return expr

    def interpret(self, *inputs):
        r"""Call :attr:`func` on concrete ``inputs``.

        Returns:
            ``None`` if the call returned ``None``; otherwise the flattened list of
            results.
        """
        args, kwargs = self.unflatten_args(inputs)
        outputs = self.func(*args, **kwargs)
        if outputs is None:
            return outputs
        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
        return outputs

    def __repr__(self):
        args = ", ".join(str(i) for i in self.args)
        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
        outputs = self.outputs
        if self.out_def:
            outputs = self.out_def.unflatten(outputs)
        # Show "<last module component>.<name>"; some callables have no
        # (or a None) ``__module__``, in which case only the name is shown.
        module = getattr(self.func, "__module__", None)
        func_name = self.func.__name__
        if module:
            func_name = module.rsplit(".", 1)[-1] + "." + func_name
        # Join only non-empty parts so no dangling ", " appears in the call.
        arg_str = ", ".join(part for part in (args, kwargs) if part)
        return "%{}:\t{}{}({})".format(
            self._id,
            str(outputs) + " = " if outputs else "",
            func_name,
            arg_str,
        )
# expr outputs = self.value
class Constant(Expr):
    r"""``Constant`` represents a ``Tensor`` or ``Module`` which is not the attribute of a Module.

    Args:
        c: a const Tensor or Module.
        name: the name of output Node.
    """

    value = None
    r"""The const Tensor or Module"""
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c, name=None):
        super().__init__()
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            assert module_tracer.is_builtin(c) or c.is_qat
        self.value = c
        self.name = name
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self, name=name, orig_name=name),
        ]
        self.outputs[0]._name = name if name else "const_" + str(self._id)

    @classmethod
    def make(cls, *args, **kwargs):
        r"""Create a ``Constant`` Expr, insert it into the current tracing scope and
        return its output Node.

        The output is named ``const_module`` / ``const_tensor``. A tensor that the
        tracer has a recorded name for (``id2name``), while inside a named module
        scope, is named after that recorded name instead. The final name is made
        unique within the scope.
        """
        expr = cls(*args, **kwargs)
        tracer = active_module_tracer()
        scope = tracer.current_scope()
        full_name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
        if isinstance(expr.value, RawTensor) and id(expr.value) in tracer.id2name:
            attr_name = tracer.id2name[id(expr.value)]
            scope_name = scope._module_name
            if attr_name and scope_name:
                # Drop the "<scope_name>." prefix from "self.<attr_name>".
                # NOTE(review): assumes scope_name is a prefix of that path.
                full_name = ("self." + attr_name)[len(scope_name) + 1 :]
        name = scope._create_unique_name(full_name)
        expr.outputs[0]._name = name
        expr.outputs[0]._orig_name = full_name
        scope._insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        r"""Return the constant value; tensors are re-created through a ``Const`` op."""
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        name = self.name
        if name is None:
            name = type(self.value)
        node_type = "Module"
        if isinstance(self.outputs[0], TensorNode):
            node_type = "Tensor"
        return "%{}:\t{} = Constant({}) -> ({})".format(
            self._id, self.outputs[0], name, node_type
        )

    def __getstate__(self):
        state = super().__getstate__()
        if isinstance(self.value, RawTensor):
            # Store the value wrapped as a ``Tensor`` rather than the raw core
            # tensor (presumably so it can be serialized).
            state["value"] = Tensor(self.value)
        return state