megengine.functional.nn.batch_norm¶
- batch_norm(inp, running_mean=None, running_var=None, weight=None, bias=None, *, training=False, momentum=0.9, eps=1e-05, inplace=True)[源代码]¶
对输入进行批标准化。
更多信息参见
BatchNorm2d
和BatchNorm1d
。- 参数
inp (
Tensor
) – 输入张量。weight (
Optional
[Tensor
]) – 可学习仿射参数中的放缩张量。可参阅BatchNorm2d
中的 \(\gamma\)bias (
Optional
[Tensor
]) – 可学习仿射参数中的偏置张量。可参阅BatchNorm2d
中的 \(eta\)training (
bool
) – 一个布尔值,它表示是否执行训练模式下的批归一化,即对当前批数据进行统计并更新统计量。 默认:False
momentum (
float
) – 用于计算running_mean
和running_var
的值。 默认: 0.9eps (
float
) – 添加到分母的单个值,增加数值稳定性。默认:1e-5inplace (
bool
) – 是否更新原始 tensors``running_mean`` 和running_var
tensor, 默认:True
, 如果设置成Flase
, 则不更新原始 tensors, 而是返回新的 tensors.compute_mode – When set to ‘default’, no special requirements will be placed on the precision of intermediate results. When set to ‘float32’, float32 would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype.
param_dim – a value indicating in which format the parameters are. Default: ‘dim_1c11’, which means NCHW format. And ‘dim_111c’ means NHWC format.