autocast

class autocast(enabled=True, low_prec_dtype='float16', high_prec_dtype='float32')[源代码]

作为上下文管理器或装饰器来控制amp的自动转换模式的类。

参数
  • enabled (bool) – 是否启用自动转换模式。

  • low_prec_dtype (str) – 设置 amp 自动转换模式的低精度 dtype 。它将改变张量转换中的目标 dtype 以获得更好的速度和内存。默认值: float16。

  • high_prec_dtype (str) – 设置 amp 自动转换模式的高精度 dtype 。它将改变张量转换中的目标 dtype 以获得更好的速度和内存。默认值: float16。

实际案例

# used as decorator
@autocast()
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        gm.backward(loss)
    opt.step().clear_grad()
    return loss

# used as context manager
def train_step(image, label):
    with autocast():
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            gm.backward(loss)
    opt.step().clear_grad()
    return loss