# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module
class BatchMatMulActivation(Module):
    r"""Batched :func:`~.matmul` with activation (only :func:`~.relu` supported), no transpose anywhere.

    Computes ``matmul(weight, x)`` where ``weight`` has shape
    ``(batch, out_features, in_features)``, optionally adds a learnable bias,
    and applies :func:`~.relu` when ``nonlinear_mode`` is ``"relu"``.

    Args:
        batch: size of the leading dimension of ``weight`` (number of weight
            matrices).
        in_features: last dimension of ``weight``.
        out_features: middle dimension of ``weight`` and length of ``bias``.
        bias: if True, create a learnable bias of shape ``(out_features,)``
            initialized to zeros. Default: True
        nonlinear_mode: activation applied to the result, compared
            case-insensitively. ``"relu"`` applies :func:`~.relu`; any other
            value (including the default ``"identity"``) applies no activation.
        kwargs: forwarded to :class:`Module`.
    """

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode: str = "identity",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features
        # Allocated as zeros here; the actual initial values come from
        # reset_parameters() at the end of __init__.
        w_shape = (batch, out_features, in_features)
        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        # Lower-cased once so the mode check in _calc_linear is case-insensitive.
        self.nonlinear_mode = nonlinear_mode.lower()
        self.reset_parameters()

    def _get_fanin(self):
        """Return the fan-in used to scale the weight initialization."""
        return self.in_features

    def reset_parameters(self) -> None:
        """Re-initialize the parameters in place.

        ``weight`` is drawn from a normal distribution with mean 0 and standard
        deviation ``sqrt(1 / fan_in)`` (``fan_in`` from :meth:`_get_fanin`);
        ``bias``, if present, is set to zeros.
        """
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        """Return ``act(matmul(weight, x) [+ bias])`` for the given tensors.

        Args:
            x: right-hand operand of the batched matmul.
            weight: left-hand operand of the batched matmul.
            bias: tensor added to the matmul result, or None to skip the bias.

        Returns:
            The matmul result, with ``bias`` added when given, passed through
            :func:`~.relu` when ``self.nonlinear_mode == "relu"``.
        """
        res = matmul(weight, x)
        # Test the argument rather than ``self.bias``: the caller chooses which
        # bias tensor (if any) is applied, and the two must not disagree.
        if bias is not None:
            # NOTE(review): ``self.bias`` has shape (out_features,) and is added
            # with plain broadcasting, i.e. aligned with the LAST axis of ``res``.
            # If matmul(weight, x) yields (batch, out_features, N), that axis is
            # N rather than out_features — confirm the intended layout.
            res = res + bias
        if self.nonlinear_mode == "relu":
            res = relu(res)
        return res

    def forward(self, x):
        """Apply the batched matmul, optional bias and activation to ``x``."""
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )