class autocast:
    r"""A class to control autocast mode for amp as a context manager or a decorator.

    The same instance may be entered several times at once (e.g. a decorated
    function that calls itself, or nested ``with`` blocks reusing one object):
    every exit restores exactly the settings that were active at the matching
    entry.

    Args:
        enabled: whether autocast mode is enabled. When ``False``, entering and
            exiting this object leaves the current amp settings untouched.
        low_prec_dtype: set amp autocast mode's lower precision dtype. It will
            change the target dtype in tensor casting for better speed and
            memory. Default: float16.
        high_prec_dtype: set amp autocast mode's higher precision dtype. It will
            change the target dtype in tensor casting for better precision.
            Default: float32.

    Examples:

        .. code-block::

            # used as decorator
            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    gm.backward(loss)
                opt.step().clear_grad()
                return loss

            # used as context manager
            def train_step(image, label):
                with autocast():
                    with gm:
                        logits = model(image)
                        loss = F.nn.cross_entropy(logits, label)
                        gm.backward(loss)
                opt.step().clear_grad()
                return loss
    """

    def __init__(
        self,
        enabled: bool = True,
        low_prec_dtype: str = "float16",
        high_prec_dtype: str = "float32",
    ):
        self.enabled = enabled
        self.high_prec_dtype = high_prec_dtype
        self.low_prec_dtype = low_prec_dtype
        # Settings saved by the most recent __enter__; kept for backward
        # compatibility / introspection only. Restoration uses _saved_states.
        self._origin_enabled = None
        self._origin_high = None
        self._origin_low = None
        self._origin_configs = None
        # One entry per active __enter__. A decorated function shares this
        # instance across all its calls, so single attributes would be
        # overwritten by nested/recursive calls; a stack keeps each entry's
        # saved state separate. ``None`` marks an entry made while disabled.
        self._saved_states = []

    @staticmethod
    def _restore_amp(origin_enabled, origin_high, origin_low):
        """Write previously saved amp flags and dtypes back to the ``amp`` module."""
        amp._enabled = origin_enabled
        amp._set_amp_dtype_autocast(origin_enabled)
        amp._set_amp_high_prec_dtype(origin_high)
        amp._set_amp_low_prec_dtype(origin_low)

    def __enter__(self):
        """Save the current amp/config settings and apply this object's settings.

        Does nothing (besides bookkeeping) when ``self.enabled`` is false.
        """
        if not self.enabled:
            # Record the decision so __exit__ stays symmetric even if
            # ``self.enabled`` is changed while inside the region.
            self._saved_states.append(None)
            return

        origin_enabled = amp._enabled
        origin_high = amp._get_amp_high_prec_dtype()
        origin_low = amp._get_amp_low_prec_dtype()
        try:
            amp._enabled = self.enabled
            amp._set_amp_dtype_autocast(self.enabled)
            amp._set_amp_high_prec_dtype(self.high_prec_dtype)
            amp._set_amp_low_prec_dtype(self.low_prec_dtype)
            origin_configs = _config._reset_execution_config(compute_mode="float32")
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # partial changes here instead of leaking them globally.
            self._restore_amp(origin_enabled, origin_high, origin_low)
            raise

        self._origin_enabled = origin_enabled
        self._origin_high = origin_high
        self._origin_low = origin_low
        self._origin_configs = origin_configs
        self._saved_states.append(
            (origin_enabled, origin_high, origin_low, origin_configs)
        )

    def __exit__(self, *args):
        """Restore the settings saved by the matching ``__enter__``.

        Exceptions raised inside the region are not suppressed.
        """
        state = self._saved_states.pop()
        if state is None:
            return
        origin_enabled, origin_high, origin_low, origin_configs = state
        try:
            self._restore_amp(origin_enabled, origin_high, origin_low)
        finally:
            # Always put the execution config back, even if an amp setter fails.
            _config._reset_execution_config(*origin_configs)

    def __call__(self, func):
        """Wrap ``func`` so that every call runs inside this autocast region."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper