autocast

class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')

A class that controls AMP (automatic mixed precision) autocast mode. It can be used as a context manager or as a decorator.

Parameters
  • enabled (bool) – whether autocast mode is enabled. Default: True.

  • low_prec_dtype (str) – the lower-precision dtype used in AMP autocast mode. It is the target dtype for tensor casts made to improve speed and reduce memory use. Default: float16.

  • high_prec_dtype (str) – the higher-precision dtype used in AMP autocast mode. It is the target dtype for tensor casts made to preserve precision. Default: float32.

Returns

None

Examples

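import megengine.functional as F
from megengine.amp import autocast

# model, gm (a GradManager), and opt (an optimizer) are assumed to be defined elsewhere.

# Used as a decorator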
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# Used as a context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
    opt.step().clear_grad()
    return loss
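
The two usages above rely on one object acting as both a context manager and a decorator. The following is a minimal standard-library sketch of that pattern, not MegEngine's implementation: `toy_autocast`, `amp_is_enabled`, and the module-level `_amp_enabled` flag are hypothetical names introduced only for illustration, and no dtype casting is performed.

```python
import functools

# Hypothetical global switch standing in for a library's internal AMP state.
_amp_enabled = False


def amp_is_enabled():
    """Report the current value of the toy switch."""
    return _amp_enabled


class toy_autocast:
    """Illustrative stand-in: toggles a flag on entry and restores it on exit."""

    def __init__(self, enabled=True):
        self.enabled = enabled
        self._saved = []  # stack of previous states, so nested entry is safe

    def __enter__(self):
        global _amp_enabled
        self._saved.append(_amp_enabled)
        _amp_enabled = self.enabled

    def __exit__(self, exc_type, exc_value, traceback):
        global _amp_enabled
        _amp_enabled = self._saved.pop()
        return False  # never swallow exceptions

    def __call__(self, func):
        # Decorator form: run the whole function body inside the context.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper


@toy_autocast()
def decorated_step():
    return amp_is_enabled()


def managed_step():
    with toy_autocast():
        inside = amp_is_enabled()
    return inside, amp_is_enabled()
```

In this sketch the flag is restored to its previous value on exit, even when the body raises, so an inner `toy_autocast(enabled=False)` can temporarily switch the mode off inside an enabled region. Python's `contextlib.ContextDecorator` offers the same decorator behavior as a ready-made mixin.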