megengine.functional.nn.batch_norm

batch_norm(inp, running_mean=None, running_var=None, weight=None, bias=None, *, training=False, momentum=0.9, eps=1e-05, inplace=True)[源代码]

对输入进行批标准化。

更多信息参见 BatchNorm2dBatchNorm1d

参数
  • inp (Tensor) – 输入张量。

  • running_mean (Optional[Tensor]) – 存储运行中的均值的张量。

  • running_var (Optional[Tensor]) – 存储运行中的方差的张量。

  • weight (Optional[Tensor]) – 可学习仿射参数中的放缩张量。可参阅 BatchNorm2d 中的 \(\gamma\)

  • bias (Optional[Tensor]) – 可学习仿射参数中的偏置张量。可参阅 BatchNorm2d 中的 \(eta\)

  • training (bool) – 一个布尔值,它表示是否执行训练模式下的批归一化,即对当前批数据进行统计并更新统计量。 默认: False

  • momentum (float) – 用于计算 running_meanrunning_var 的值。 默认: 0.9

  • eps (float) – 添加到分母的单个值,增加数值稳定性。默认:1e-5

  • inplace (bool) – 是否更新原始 tensors``running_mean`` 和 running_var tensor, 默认: True, 如果设置成 Flase, 则不更新原始 tensors, 而是返回新的 tensors.

  • compute_mode – When set to ‘default’, no special requirements will be placed on the precision of intermediate results. When set to ‘float32’, float32 would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype.

  • param_dim – a value indicating in which format the parameters are. Default: ‘dim_1c11’, which means NCHW format. And ‘dim_111c’ means NHWC format.