class QATModule(Module):
    r"""Base class of quantized-float related :class:`~.Module`, basically for QAT and Calibration.

    Use :meth:`from_float_module` to generate an instance from a float :class:`~.Module`,
    or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.

    Can also be converted to :class:`~.QuantizedModule` for deployment using
    :func:`~.quantize.quantize` further.
    """

    # Subclasses set these to False when they have no weight / no activation to
    # quantize; set_qconfig and the enable/disable helpers then skip that side.
    with_weight = True
    with_act = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Filled in by set_qconfig(); None means "not configured".
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer
        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def __repr__(self):
        return "QAT." + super().__repr__()

    def set_qconfig(self, qconfig: QConfig):
        r"""Set quantization related configs with ``qconfig``, including
        observer and fake_quant for weight and activation.

        Each ``qconfig`` entry is called with no arguments to build a fresh
        instance. An entry that is ``None`` yields ``None``. Only the sides
        enabled by :attr:`with_act` / :attr:`with_weight` are assigned.
        """

        def safe_call(func):
            return func() if func is not None else None

        if self.with_act:
            self.act_observer = safe_call(qconfig.act_observer)
            self.act_fake_quant = safe_call(qconfig.act_fake_quant)
        if self.with_weight:
            self.weight_observer = safe_call(qconfig.weight_observer)
            self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)

    def _enable_exec(self, with_module, func, enable):
        # Call func.enable() / func.disable(). This is a no-op when this side
        # of the module is switched off (with_module falsy) or func is unset.
        if not with_module or not func:
            return
        if enable:
            func.enable()
        else:
            func.disable()

    def set_fake_quant(self, enable):
        r"""Enable (``True``) or disable (``False``) the activation and weight fake_quant."""
        self._enable_exec(self.with_act, self.act_fake_quant, enable)
        self._enable_exec(self.with_weight, self.weight_fake_quant, enable)

    def set_observer(self, enable):
        r"""Enable (``True``) or disable (``False``) the activation and weight observer."""
        self._enable_exec(self.with_act, self.act_observer, enable)
        self._enable_exec(self.with_weight, self.weight_observer, enable)

    def _apply_fakequant_with_observer(
        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
    ):
        # Run the observer (if any) on target and take its qparams.
        if observer is None:
            oup = target
            qparams = None
        else:
            oup = observer(target)
            qparams = observer.get_qparams()
        # Fake-quantize using the observer's qparams (which may be None).
        if fake_quant is not None:
            oup = fake_quant(oup, qparams)
            # Prefer fake_quant's own qparams when it exposes them.
            if hasattr(fake_quant, "get_qparams"):
                qparams = fake_quant.get_qparams()
        # Record the final qparams on the output tensor.
        if qparams is not None:
            oup.qparams.update(qparams)
        return oup

    def apply_quant_weight(self, target: Tensor):
        r"""Apply weight's observer and fake_quant from ``qconfig`` on ``target``."""
        return self._apply_fakequant_with_observer(
            target, self.weight_fake_quant, self.weight_observer
        )

    def apply_quant_activation(self, target: Tensor):
        r"""Apply activation's observer and fake_quant from ``qconfig`` on ``target``."""
        return self._apply_fakequant_with_observer(
            target, self.act_fake_quant, self.act_observer
        )

    def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
        r"""Use :func:`~.fake_quant_bias` to process ``target``.

        Only valid when ``act_fake_quant`` and ``weight_fake_quant`` are both
        set and enabled; otherwise ``target`` is returned unchanged.
        """
        # bias should have the same dtype as activation, so act_fake_quant can
        # also decide whether to do bias fakequant
        if (
            self.act_fake_quant
            and self.act_fake_quant.enabled
            and self.weight_fake_quant
            and self.weight_fake_quant.enabled
        ):
            b_qat = fake_quant_bias(target, inp, w_qat)
        else:
            b_qat = target
        return b_qat

    def _get_method_result(
        self, method: str, fake_quant: FakeQuantize, observer: Observer
    ):
        # Call the zero-argument ``method`` on fake_quant if it has one,
        # otherwise on observer; return None when neither provides it
        # (hasattr(None, ...) is False, so unset components are skipped).
        if hasattr(fake_quant, method):
            return getattr(fake_quant, method)()
        if hasattr(observer, method):
            return getattr(observer, method)()
        return None

    def get_weight_dtype(self):
        r"""Get weight's quantization dtype as the method from ``qconfig``."""
        return self._get_method_result(
            "get_quantized_dtype", self.weight_fake_quant, self.weight_observer
        )

    def get_activation_dtype(self):
        r"""Get activation's quantization dtype as the method from ``qconfig``."""
        return self._get_method_result(
            "get_quantized_dtype", self.act_fake_quant, self.act_observer
        )

    @classmethod
    @abstractmethod
    def from_float_module(cls, float_module: Module):
        r"""Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """