sync_batch_norm

dragon.nn.sync_batch_norm(
  inputs,
  axis=-1,
  momentum=0.9,
  eps=1e-05,
  use_stats=-1,
  process_group=None,
  **kwargs
)

Apply batch normalization with statistics synchronized across processes [Ioffe & Szegedy, 2015].

The normalization is defined as:

\[\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta \]
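Here "synced" means that \(\mathrm{E}[x]\) and \(\mathrm{Var}[x]\) are computed over the union of the local batches of all processes, not per process. The NumPy sketch below simulates this in a single process. `synced_moments` and `normalize` are hypothetical helpers, and each array in `shards` stands in for one process's local batch. Exchanging per-channel sums, sums of squares and counts is one common way to combine the statistics. It is not necessarily how Dragon's kernel does it.

```python
import numpy as np

def synced_moments(shards, axis=-1):
    """Per-channel mean and (biased) variance over all shards together.

    Each array in `shards` stands in for the local batch of one process.
    A real implementation would all-reduce the partial sums across the
    process group; here they are simply added up in a loop.
    """
    channels = shards[0].shape[axis]
    total = np.zeros(channels)
    total_sq = np.zeros(channels)
    count = 0
    for x in shards:
        # Reduce over every axis except the channel axis.
        reduce_axes = tuple(i for i in range(x.ndim) if i != axis % x.ndim)
        total += x.sum(axis=reduce_axes)
        total_sq += (x * x).sum(axis=reduce_axes)
        count += x.size // channels  # elements per channel in this shard
    mean = total / count
    var = total_sq / count - mean * mean  # Var[x] = E[x^2] - E[x]^2
    return mean, var

def normalize(x, mean, var, gamma, beta, eps=1e-5):
    # Plain broadcasting is correct here because channels are on the last axis.
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
shards = [rng.normal(size=(4, 3)) for _ in range(2)]  # 2 "processes", 3 channels
gamma, beta = np.ones(3), np.zeros(3)

mean, var = synced_moments(shards)
# Every process normalizes its own shard with the shared statistics.
outs = [normalize(s, mean, var, gamma, beta) for s in shards]

full = np.concatenate(shards, axis=0)
print(np.allclose(mean, full.mean(axis=0)), np.allclose(var, full.var(axis=0)))
# True True
```

Only per-channel sums, sums of squares and counts need to be exchanged, so the communication volume does not grow with the batch size. The \(\mathrm{E}[x^2] - \mathrm{E}[x]^2\) form can lose precision when the mean is large relative to the spread. A Welford-style merge of per-shard means and centered sums is more robust.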

The running averages of the statistics are calculated as:

\[x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}} \]
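In this formula `momentum` weights the old running value. With the default of 0.9, each new batch therefore contributes only 10% of the update. The plain-Python sketch below uses a hypothetical helper, `update_running`. In the synced case, \(x_{\text{stat}}\) would presumably be the synchronized batch statistic. Whether the running variance uses a biased or an unbiased estimate is an open assumption here.

```python
def update_running(running, stat, momentum=0.9):
    """Apply x_running = momentum * x_running + (1 - momentum) * x_stat per channel."""
    return [momentum * r + (1.0 - momentum) * s for r, s in zip(running, stat)]

running_mean = [0.0, 0.0]  # two channels, initialized to zero
for batch_mean in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    running_mean = update_running(running_mean, batch_mean)

# After n identical updates from zero: (1 - 0.9**n) * stat
print([round(v, 4) for v in running_mean])  # [0.271, 0.542]
```

The slow convergence from a zero initialization (27.1% of the true value after three steps) is why running statistics are usually only trusted after many training iterations.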

Note that exactly 5 inputs are required, because this operator is implemented in a fused form.

However, you can still keep gamma and beta fixed by disabling their gradients directly.

Parameters:
  • inputs (Sequence[dragon.Tensor]) – The tensors x, gamma, beta, mean and var.
  • axis (int, optional, default=-1) – The channel axis.
  • momentum (float, optional, default=0.9) – The momentum for the running average.
  • eps (float, optional, default=1e-5) – The value of \(\epsilon\).
  • use_stats (int, optional, default=-1) – Whether to use the estimated (running) statistics or not.
  • process_group (ProcessGroup, optional) – The process group used to synchronize the statistics.
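The NumPy function below is a hypothetical single-process reference, not a Dragon function. It illustrates the fused 5-input contract and the role of `axis`. It omits synchronization entirely and simplifies `use_stats` to a plain boolean. That simplification is an assumption: the exact behavior selected by the default of -1 is not described on this page. It also returns the updated running averages instead of modifying any input in place, purely for clarity.

```python
import numpy as np

def batch_norm_reference(inputs, axis=-1, momentum=0.9, eps=1e-5, use_stats=False):
    """Hypothetical single-process reference for the 5-input contract.

    inputs = [x, gamma, beta, mean, var], where gamma, beta, mean and var hold
    one value per channel along `axis`. Returns (out, new_mean, new_var).
    `use_stats` is a plain bool here, a simplification of the int flag.
    """
    if len(inputs) != 5:
        raise ValueError("expected [x, gamma, beta, mean, var]")
    x, gamma, beta, run_mean, run_var = (np.asarray(t, dtype=float) for t in inputs)
    axis = axis % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    # Shape like (1, C, 1, 1) so per-channel vectors broadcast along `axis`.
    bshape = [1] * x.ndim
    bshape[axis] = x.shape[axis]
    if use_stats:
        # Normalize with the stored (estimated) statistics; leave them unchanged.
        mean, var = run_mean, run_var
        new_mean, new_var = run_mean, run_var
    else:
        # Normalize with the batch statistics and update the running averages.
        mean, var = x.mean(axis=reduce_axes), x.var(axis=reduce_axes)
        new_mean = momentum * run_mean + (1 - momentum) * mean
        new_var = momentum * run_var + (1 - momentum) * var
    out = (x - mean.reshape(bshape)) / np.sqrt(var.reshape(bshape) + eps)
    out = out * gamma.reshape(bshape) + beta.reshape(bshape)
    return out, new_mean, new_var

# NCHW input, so the channel axis is 1 rather than the default -1.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))
inputs = [x, np.ones(3), np.zeros(3), np.zeros(3), np.ones(3)]
out, new_mean, new_var = batch_norm_reference(inputs, axis=1)
print(out.shape, np.allclose(out.mean(axis=(0, 2, 3)), 0.0))
# (2, 3, 4, 4) True
```

With the default `axis=-1` the same call would treat the last dimension as channels (NHWC layout), so gamma, beta, mean and var would need one entry per element of `x.shape[-1]`.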
Returns:

dragon.Tensor – The output tensor.