SyncBatchNorm
class dragon.vm.torch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None)

Apply the sync batch normalization over input. [Ioffe & Szegedy, 2015].

The normalization is defined as:

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

The running average of statistics is computed as:

\[x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{batch}}\]

If process_group is None, use the value of dragon.distributed.get_group(...).
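The two formulas above can be sketched in plain NumPy (a minimal single-device illustration; the real operator additionally synchronizes the batch statistics across the process group and runs fused on device):

```python
import numpy as np

def batch_norm(x, running_mean, running_var, gamma, beta,
               momentum=0.1, eps=1e-5):
    """Normalize a (N, C) batch and update running statistics in place."""
    mean = x.mean(axis=0)  # E[x] per channel
    var = x.var(axis=0)    # Var[x] per channel
    # y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
    y = (x - mean) / np.sqrt(var + eps) * gamma + beta
    # x_running = (1 - momentum) * x_running + momentum * x_batch
    running_mean[:] = (1 - momentum) * running_mean + momentum * mean
    running_var[:] = (1 - momentum) * running_var + momentum * var
    return y
```

With the default momentum of 0.1, each training step moves the running statistics 10% of the way toward the current batch statistics.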
__init__

SyncBatchNorm.__init__(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None)

Create a SyncBatchNorm module.

- Parameters:
  - num_features (int) – The number of channels.
  - eps (float, optional, default=1e-5) – The value of \(\epsilon\).
  - momentum (float, optional, default=0.1) – The value of \(\text{momentum}\).
  - affine (bool, optional, default=True) – True to apply an affine transformation.
  - track_running_stats (bool, optional, default=True) – True to use the tracked running stats when switching to eval mode.
  - process_group (ProcessGroup, optional) – The group for communication.
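A typical construction looks like the sketch below (hedged: it assumes a working dragon install; the channel count 64 is an arbitrary example value):

```python
from dragon.vm import torch

bn = torch.nn.SyncBatchNorm(
    num_features=64,           # number of channels
    eps=1e-5,
    momentum=0.1,
    affine=True,               # learn gamma and beta
    track_running_stats=True,  # keep running mean/var for eval
    process_group=None,        # fall back to dragon.distributed.get_group(...)
)
```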
Methods

convert_sync_batchnorm

classmethod SyncBatchNorm.convert_sync_batchnorm(module, process_group=None)

Convert the batch normalization layers in a module to sync batch normalization recursively.

- Parameters:
  - module (dragon.vm.torch.nn.Module) – The module containing batch normalization.
  - process_group (ProcessGroup, optional) – The group for communication.
- Returns:
  dragon.vm.torch.nn.Module – The output module.
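The recursive-replacement pattern behind this classmethod can be sketched with stand-in classes (a hedged illustration of the technique, not Dragon's actual implementation; `Module`, `BatchNorm2d`, and `_modules` here are simplified stand-ins):

```python
class Module:
    """Minimal stand-in for a module that tracks named children."""
    def __init__(self):
        self._modules = {}

    def add(self, name, child):
        self._modules[name] = child
        return self

class BatchNorm2d(Module):
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features

class SyncBatchNorm(Module):
    def __init__(self, num_features, process_group=None):
        super().__init__()
        self.num_features = num_features
        self.process_group = process_group

def convert_sync_batchnorm(module, process_group=None):
    """Replace every BatchNorm2d with a SyncBatchNorm, recursively."""
    if isinstance(module, BatchNorm2d):
        return SyncBatchNorm(module.num_features, process_group)
    for name, child in module._modules.items():
        module._modules[name] = convert_sync_batchnorm(child, process_group)
    return module
```

Because the walk descends into every child, a conversion call on the root module is enough to cover batch normalization layers at any nesting depth.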