sync_batch_norm

dragon.nn.sync_batch_norm(
  inputs,
  axis=-1,
  momentum=0.9,
  epsilon=1e-05,
  use_stats=-1,
  process_group=None,
  **kwargs
)

Apply batch normalization with statistics synchronized across the processes of a group [Ioffe & Szegedy, 2015].

The normalization is defined as:

\[y = \frac{x - \mathrm{E}[x]} {\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta \]
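The synchronization of \(\mathrm{E}[x]\) and \(\mathrm{Var}[x]\) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Dragon's implementation: it assumes a 2-D input of shape [N, C] with the channel on the last axis (axis=-1), simulates the processes of a group as a list of batch shards, and replaces the real collective communication with a plain sum. Each shard contributes its row count, per-channel sum and per-channel sum of squares; once these are summed, every shard normalizes with the same statistics. The variance is computed here as \(\mathrm{E}[x^2] - \mathrm{E}[x]^2\), a choice made for the sketch; real kernels may reduce the statistics differently. The helper name `sync_batch_norm_1d` is hypothetical.

```python
import math


def sync_batch_norm_1d(shards, gamma, beta, epsilon=1e-5):
    """Hypothetical sketch: batch norm over shards with shared statistics.

    shards: list of per-process batches, each a list of rows of C floats
    (a [N, C] layout, channel on the last axis).
    Returns (normalized shards, synced mean, synced variance).
    """
    num_channels = len(gamma)
    # Step 1: each simulated process computes local count, sum, sum of squares.
    local = []
    for rows in shards:
        s = [0.0] * num_channels
        sq = [0.0] * num_channels
        for row in rows:
            for c, v in enumerate(row):
                s[c] += v
                sq[c] += v * v
        local.append((len(rows), s, sq))
    # Step 2: stand-in for an all-reduce: sum the partial statistics.
    count = sum(n for n, _, _ in local)
    total = [sum(s[c] for _, s, _ in local) for c in range(num_channels)]
    total_sq = [sum(sq[c] for _, _, sq in local) for c in range(num_channels)]
    # Step 3: global E[x] and Var[x] = E[x^2] - E[x]^2 (biased, divides by N).
    mean = [t / count for t in total]
    var = [total_sq[c] / count - mean[c] ** 2 for c in range(num_channels)]
    # Step 4: every shard normalizes its own rows with the shared statistics.
    normalized = [
        [
            [(v - mean[c]) / math.sqrt(var[c] + epsilon) * gamma[c] + beta[c]
             for c, v in enumerate(row)]
            for row in rows
        ]
        for rows in shards
    ]
    return normalized, mean, var


# Two simulated processes, two channels; full batch is
# [[1, 10], [3, 30], [5, 50], [7, 70]].
shards = [[[1.0, 10.0], [3.0, 30.0]], [[5.0, 50.0], [7.0, 70.0]]]
out, mean, var = sync_batch_norm_1d(shards, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(mean)  # [4.0, 40.0]
print(var)   # [5.0, 500.0]
```

Without synchronization, the first shard alone would use a mean of 2.0 for channel 0 instead of 4.0, so its outputs would differ from normalizing the full batch at once.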

The running averages of the statistics are calculated as:

\[x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{batch}} \]
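The update rule can be traced with a small sketch. The helper `update_running` is hypothetical and works on plain Python lists of per-channel values; it only mirrors the formula above. In this convention momentum weights the previous running value, so with momentum=0.9 each update keeps 90% of it, and after k identical updates starting from zero the running value equals \(x_{\text{batch}}(1 - 0.9^k)\). Because the batch statistics are synchronized, every process in the group would feed the same \(x_{\text{batch}}\) into this update.

```python
def update_running(running, batch, momentum=0.9):
    """Hypothetical helper: momentum * running + (1 - momentum) * batch."""
    return [momentum * r + (1.0 - momentum) * b for r, b in zip(running, batch)]


# Start from zero and feed the same batch mean three times.
running_mean = [0.0, 0.0]
batch_mean = [4.0, 40.0]
for step in range(3):
    running_mean = update_running(running_mean, batch_mean)
    print(step, running_mean)
```

For channel 0 the running mean grows to approximately 0.4, 0.76 and 1.084, approaching the batch mean of 4.0 only slowly.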

Parameters:
  • inputs (Sequence[dragon.Tensor]) The tensors x, gamma, beta, mean and var.
  • axis (int, optional, default=-1) The channel axis.
  • momentum (Union[float, dragon.Tensor], optional, default=0.9) The value of \(\text{momentum}\).
  • epsilon (float, optional, default=1e-5) The value of \(\epsilon\).
  • use_stats (int, optional, default=-1) Whether to use estimated statistics or not.
  • process_group (ProcessGroup, optional) The group for communication.
Returns:
  • (dragon.Tensor) The output tensor.