# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        compute_mode="default",
        param_dim="dim_1c11",
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # remember the initial setting: it may later be flipped from True to
        # False, but never from False to True (checked in forward)
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.compute_mode = compute_mode
        self.param_dim = param_dim
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if not self._track_running_stats_saved:
            assert (
                not self.track_running_stats
            ), "track_running_stats cannot be initialized to False and changed to True later"

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # fast-path execution for freeze: fold the normalization and the
            # affine parameters into a single scale and bias
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            return inp * scale + bias

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused when stats are not updated

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            compute_mode=self.compute_mode,
            param_dim=self.param_dim,
        )

        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)
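# The freeze fast path above folds the normalization and the affine parameters
# into one affine map: scale = weight / sqrt(running_var + eps) and
# bias = bias - running_mean * scale. A minimal numpy sketch of that algebra
# (illustration only, with made-up values; not part of the module):
#
#     import numpy as np
#
#     eps = 1e-5
#     running_mean, running_var = np.float32(0.5), np.float32(2.0)
#     weight, bias = np.float32(1.5), np.float32(0.1)
#     x = np.random.rand(8).astype(np.float32)
#
#     # folded form used by the fast path
#     scale = weight * (running_var + eps) ** -0.5
#     y_fast = x * scale + (bias - running_mean * scale)
#
#     # textbook form: (x - mean) / sqrt(var + eps) * gamma + beta
#     y_ref = (x - running_mean) / np.sqrt(running_var + eps) * weight + bias
#
#     assert np.allclose(y_fast, y_ref, atol=1e-6)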
class SyncBatchNorm(_BatchNorm):
    r"""Applies Synchronized Batch Normalization for distributed training.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False
        group: communication group over which the mean and variance are calculated.
            Default: :obj:`~.distributed.WORLD`
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            # pad the shape with trailing ones so the input can be normalized
            # as a 4D (N, C, H, W) tensor; restored after normalization
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused when stats are not updated

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
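# A minimal usage sketch for SyncBatchNorm (illustration only, not part of the
# module). It assumes a machine with two GPUs and uses the
# ``megengine.distributed.launcher`` helper to spawn one worker per device;
# mean and variance are then reduced across the default ``WORLD`` group:
#
#     import numpy as np
#     import megengine as mge
#     import megengine.distributed as dist
#     import megengine.module as M
#
#     @dist.launcher(n_gpus=2)
#     def worker():
#         m = M.SyncBatchNorm(4)  # statistics synchronized across dist.WORLD
#         inp = mge.tensor(np.random.rand(2, 4, 8, 8).astype("float32"))
#         out = m(inp)
#         print(dist.get_rank(), out.shape)
#
#     worker()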
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )
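# A minimal usage sketch (illustration only): per the docstring above,
# BatchNorm1d accepts ``(N, C)`` or ``(N, C, L)`` inputs and normalizes over
# the ``C`` dimension:
#
#     import numpy as np
#     import megengine as mge
#     import megengine.module as M
#
#     m = M.BatchNorm1d(4)
#     out2d = m(mge.tensor(np.random.rand(2, 4).astype("float32")))      # (N, C)
#     out3d = m(mge.tensor(np.random.rand(2, 4, 16).astype("float32")))  # (N, C, L)
#     print(out2d.shape, out3d.shape)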
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization
    during evaluation. The running estimates are kept with a default
    :attr:`momentum` of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are used during
    evaluation time instead.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    .. note::

        The update formula for ``running_mean`` and ``running_var`` (taking
        ``running_mean`` as an example) is

        .. math::

            \textrm{running\_mean} = \textrm{momentum} \times \textrm{running\_mean}
            + (1 - \textrm{momentum}) \times \textrm{batch\_mean}

        which could be defined differently in other frameworks. Most notably,
        ``momentum`` of 0.1 in PyTorch is equivalent to ``momentum`` of 0.9 here.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.module as M

            # With learnable parameters
            m = M.BatchNorm2d(4)
            inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
            oup = m(inp)
            print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
            # Without learnable parameters
            m = M.BatchNorm2d(4, affine=False)
            oup = m(inp)
            print(m.weight, m.bias)

        Outputs:

        .. testoutput::

            [1. 1. 1. 1.] [0. 0. 0. 0.]
            None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError(
                "expected 4D input (got {}D input)".format(len(inp.shape))
            )
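# A quick numerical check of the running-stats update documented in the note
# above (illustration only; shapes and tolerance are arbitrary): starting from
# the initial running_mean of 0, one training step should give
# running_mean == momentum * 0 + (1 - momentum) * batch_mean.
#
#     import numpy as np
#     import megengine as mge
#     import megengine.module as M
#
#     m = M.BatchNorm2d(3, momentum=0.9)
#     m.train()
#     inp = mge.tensor(np.random.rand(2, 3, 4, 4).astype("float32"))
#     batch_mean = inp.numpy().mean(axis=(0, 2, 3))
#     m(inp)
#     assert np.allclose(
#         m.running_mean.numpy().flatten(), 0.1 * batch_mean, atol=1e-5
#     )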