# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Union
from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter
from .utils import (
QParams,
QParamsModuleMixin,
QuantMode,
create_qparams,
fake_quant_tensor,
tqt_forward,
)
# Module-level logger; used by _FakeQuantize to warn about ignored constructor arguments.
logger = get_logger(__name__)
class _FakeQuantize(Module):
    r"""
    Base class for fake-quantization modules.

    Resolves ``dtype`` to a :class:`~.QuantDtypeMeta`, records its ``qmin`` /
    ``qmax``, and makes :meth:`forward` dispatch to :meth:`fake_quant_forward`
    when enabled or to :meth:`normal_forward` (identity) when disabled.

    :param dtype: a string key of the builtin quant dtypes or a
        :class:`~.QuantDtypeMeta` indicating the target quantization dtype.
    :param enable: whether ``forward`` does ``fake_quant_forward`` (``True``)
        or ``normal_forward`` (``False``).
    :param kwargs: extra keyword arguments. ``narrow_range`` is accepted for
        backward compatibility and ignored with a warning; any other key is
        ignored silently.
    :raises ValueError: if ``dtype`` is a string that is not one of the
        builtin quant dtypes.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            # NOTE: no ``exc_info`` here. With stdlib logging, passing an
            # exception *class* makes the logger fall back to sys.exc_info(),
            # which outside an ``except`` block appends a spurious
            # "NoneType: None" traceback to the message.
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = enable

    def enable(self):
        r"""Make :meth:`forward` use :meth:`fake_quant_forward`."""
        self.enabled = True

    def disable(self):
        r"""Make :meth:`forward` use :meth:`normal_forward` (pass-through)."""
        self.enabled = False

    def fake_quant_forward(self, inp, qparams: QParams = None):
        r"""Fake-quantize ``inp``; must be implemented by subclasses."""
        raise NotImplementedError

    def normal_forward(self, inp, qparams: QParams = None):
        r"""Pass-through used when disabled; returns ``inp`` unchanged."""
        return inp

    def normal_foward(self, inp, qparams: QParams = None):
        r"""Legacy misspelled alias of :meth:`normal_forward`, kept for compatibility."""
        # forward() still dispatches through this name so that subclasses
        # overriding either spelling keep working.
        return self.normal_forward(inp, qparams=qparams)

    def forward(self, inp, qparams: QParams = None):
        r"""Dispatch on ``self.enabled`` to the fake-quant or pass-through path."""
        if self.enabled:
            return self.fake_quant_forward(inp, qparams=qparams)
        return self.normal_foward(inp, qparams=qparams)
class TQT(_FakeQuantize, QParamsModuleMixin):
    r"""
    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.

    The quantization scale is stored as its base-2 logarithm in the parameter
    :attr:`scale`, so the effective scale is ``2 ** self.scale``.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether ``forward`` does ``fake_quant_forward`` or
        ``normal_forward``.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__(dtype, enable, **kwargs)
        # log2 of the quantization scale; written by set_qparams, read by
        # get_qparams and tqt_forward.
        self.scale = Parameter(0.0, dtype="float32")

    def fake_quant_forward(self, inp, qparams: QParams = None):
        r"""
        Fake-quantize ``inp`` with ``tqt_forward`` using ``self.scale``.

        ``qparams`` is ignored: when enabled, TQT does fake-quant forward and
        fine-tunes its own scale.
        """
        return tqt_forward(self.qmin, self.qmax, inp, self.scale)

    def set_qparams(self, qparams: QParams):
        r"""
        Initialize ``self.scale`` to ``log2(qparams.scale)``.

        :raises AssertionError: if ``qparams`` is not symmetric or its scale
            is ``None``.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if qparams.mode != QuantMode.SYMMERTIC:
            raise AssertionError("only symmetric quantization is supported by TQT")
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        self.scale[...] = F.log(qparams.scale) / math.log(2)

    def get_qparams(self):
        r"""Return symmetric qparams for ``self.dtype`` with scale ``2 ** self.scale``."""
        return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)
class FakeQuantize(_FakeQuantize):
    r"""
    A module to do quant and dequant according to observer's scale and zero_point.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether ``forward`` does ``fake_quant_forward`` or
        ``normal_forward``.
    """

    def fake_quant_forward(self, inp, qparams: QParams = None):
        r"""
        Quantize then dequantize ``inp`` via ``fake_quant_tensor``.

        :param qparams: required despite the ``None`` default; its
            ``dtype_meta`` must be the very same object as ``self.dtype``.
        :raises AssertionError: if ``qparams`` is ``None`` or its
            ``dtype_meta`` is not ``self.dtype``.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if qparams is None:
            raise AssertionError("FakeQuantize requires qparams when enabled")
        if qparams.dtype_meta is not self.dtype:
            raise AssertionError(
                "input qparams' dtype is not equal to self.dtype.\n"
                "qparams.dtype_meta={}\nself.dtype={}".format(
                    qparams.dtype_meta, self.dtype
                )
            )
        return fake_quant_tensor(inp, qparams)