from ...functional import ones, relu, sqrt, sum, zeros
from .. import conv_bn as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    """Shared QAT implementation of a fused Conv2d + BatchNorm2d block.

    Provides BN-statistics computation, weight/bias folding and the
    quantization-aware forward path used by :class:`ConvBn2d` and
    :class:`ConvBnRelu2d`.
    """

    def get_batch_mean_var(self, inp):
        """Compute per-channel batch mean and (biased) variance of *inp*.

        Reduces over axes ``(0, 2, 3)``, i.e. assumes an NCHW layout where
        axis 1 is the channel axis.

        :param inp: activation tensor in NCHW layout.
        :return: ``(batch_mean, batch_var)``, each broadcastable as
            ``(1, C, 1, 1)``.
        """

        def _sum_channel(inp, axis=0, keepdims=True):
            # Reduce over a single axis (int) or several axes (tuple),
            # one axis at a time so ``keepdims`` keeps the result
            # broadcastable against the input.
            if isinstance(axis, int):
                return sum(inp, axis=axis, keepdims=keepdims)
            if isinstance(axis, tuple):
                out = inp
                for elem in axis:
                    out = sum(out, axis=elem, keepdims=keepdims)
                return out
            # Previously an unsupported axis type fell through and raised
            # UnboundLocalError on ``return out``; fail explicitly instead.
            raise TypeError(
                "axis must be an int or a tuple of ints, got {!r}".format(axis)
            )

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        # Number of elements contributing to each channel's statistics.
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        # var = E[x^2] - E[x]^2, computed from the two running sums.
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        """Fold BatchNorm parameters into the convolution weight and bias.

        Uses the standard BN-folding identities:

            bn_istd = 1 / sqrt(bn_var + eps)
            w_fold  = gamma * bn_istd * W
            b_fold  = gamma * (b - bn_mean) * bn_istd + beta

        Missing affine parameters / statistics / conv bias default to the
        identity values (gamma=1, beta=0, mean=0, var=1, bias=0).

        :param bn_mean: BN mean, shape-broadcastable to ``(1, C, 1, 1)``,
            or ``None`` for zeros.
        :param bn_var: BN variance, same shape convention, or ``None`` for
            ones.
        :return: ``(w_fold, b_fold)`` where ``w_fold`` has already been
            passed through :meth:`apply_quant_weight`.
        """
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            # Grouped conv weight is (groups, O, I, KH, KW); scale per group.
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        w_fold = self.apply_quant_weight(w_fold)
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        """Update BN running statistics from one batch's statistics.

        Detaches the batch statistics (no gradient flows into the running
        buffers) and applies Bessel's correction ``n / (n - 1)`` so the
        running variance is unbiased.

        :param bn_mean: batch mean from :meth:`get_batch_mean_var`.
        :param bn_var: biased batch variance from :meth:`get_batch_mean_var`.
        :param num_elements_per_channel: element count per channel used for
            the unbiased-variance correction; must be > 1.
        """
        # no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        # running = momentum * running + (1 - momentum) * batch
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_conv_bn_qat(self, inp, approx=True):
        """Run the fused conv+BN forward with fake-quantized parameters.

        Two training strategies:

        * ``approx=True`` (default): fold with the *running* statistics,
          run the quantized conv, then un-scale the result and apply the
          real BN layer — an approximation that avoids a second conv pass.
        * ``approx=False``: run a float conv first to get the exact batch
          statistics, update the running buffers, and fold with them.

        In evaluation mode the running statistics are always used and the
        folded result is returned directly.

        :param inp: input activation tensor.
        :param approx: whether to use the approximate single-conv scheme
            during training.
        :return: fused conv+BN output (pre activation-quantization).
        """
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.size / conv.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        conv = self.conv.calc_conv(inp, w_qat, b_qat)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        conv = self.bn(orig_conv)
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        """Build a QAT module from a float fused conv+BN module.

        The conv hyper-parameters are copied into a new instance; the conv
        weight/bias tensors and the BN submodule are shared by reference,
        not copied.

        :param float_module: source float module.
        :return: the converted QAT module.
        """
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            float_module.conv.bias is not None,
            float_module.conv.conv_mode,
            float_module.conv.compute_mode,
            padding_mode=float_module.conv.padding_mode,
            name=float_module.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module
class ConvBn2d(_ConvBnActivation2d):
    r"""A fused :class:`~.QATModule` combining :class:`~.module.Conv2d` with
    :class:`~.module.BatchNorm2d`, ready for quantization-aware training.
    Works together with :class:`~.Observer` and
    :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        # Fused conv+BN first, then fake-quantize the activation.
        fused = self.calc_conv_bn_qat(inp)
        return self.apply_quant_activation(fused)
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""A fused :class:`~.QATModule` combining :class:`~.module.Conv2d`,
    :class:`~.module.BatchNorm2d` and :func:`~.relu`, ready for
    quantization-aware training. Works together with :class:`~.Observer`
    and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        # Fused conv+BN, ReLU, then fake-quantize the activation.
        fused = self.calc_conv_bn_qat(inp)
        return self.apply_quant_activation(relu(fused))