autocast
- class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')
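A class that controls the autocast mode of AMP (automatic mixed precision). It can be used either as a context manager or as a decorator.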
- Parameters:
  - enabled (bool) – Whether autocast mode is enabled. Default: True.
  - low_prec_dtype (str) – The lower-precision dtype used by autocast mode. It sets the target dtype when tensors are cast down, trading precision for speed and lower memory usage. Default: float16.
  - high_prec_dtype (str) – The higher-precision dtype used by autocast mode. It sets the target dtype when tensors are cast up for better numerical precision. Default: float32.
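To make the precision versus memory trade-off behind these two parameters concrete, here is a small sketch that uses NumPy rather than MegEngine to compare float16 and float32, the default values of low_prec_dtype and high_prec_dtype. It only illustrates the dtypes themselves; it does not show which operators MegEngine's autocast casts to which dtype, and any speed gain depends on hardware and is not demonstrated here.

```python
import numpy as np

# A float32 array standing in for an activation tensor (1M elements).
x32 = np.ones(1_000_000, dtype=np.float32)
# Casting to float16, the default low_prec_dtype, halves the storage.
x16 = x32.astype(np.float16)
print(x32.nbytes, x16.nbytes)  # 4000000 2000000

# float16 has a 10-bit mantissa, so neighbouring values near 1.0 are
# 2**-10 (about 0.000977) apart and a small increment is rounded away.
lost_in_fp16 = np.float16(1.0001) == np.float16(1.0)
kept_in_fp32 = np.float32(1.0001) != np.float32(1.0)
print(lost_in_fp16, kept_in_fp32)  # True True
```

This is why numerically sensitive computations are typically kept in the higher-precision dtype while bulk tensor work uses the lower-precision one.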
Examples
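```python
import megengine.functional as F
from megengine.amp import autocast

# `model`, `gm` (a megengine.autodiff.GradManager attached to the model's
# parameters) and `opt` (an optimizer from megengine.optimizer) are assumed
# to be created elsewhere. The two train_step definitions are alternatives.

# used as decorator
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# used as context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
    opt.step().clear_grad()
    return loss
```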