Deep Neural Network Library (DNNL)
1.2.0

Performance library for Deep Learning

Batch Normalization

The batch normalization primitive performs a forward or backward batch normalization operation on 0D, 2D, or 3D spatial data.

The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data; they are straightforward to generalize to higher and lower dimensions. Variable names follow the standard Naming Conventions.

\[ dst(n, c, h, w) = \gamma(c) \cdot \frac{src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c), \]

where

- \(\gamma(c), \beta(c)\) are optional scale and shift for a channel (see dnnl_use_scaleshift flag),
- \(\mu(c), \sigma^2(c)\) are the mean and variance for a channel, either computed at run-time or provided by a user (see dnnl_use_global_stats flag), and
- \(\varepsilon\) is a constant to improve numerical stability.
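The formula above can be sketched as a plain C++ reference loop. This is a minimal illustration that assumes a dense NCHW `float` buffer; the function name `bnorm_forward_ref` and its signature are hypothetical and are not part of the DNNL API. Passing empty `gamma` and `beta` models the case where dnnl_use_scaleshift is not set.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference forward batch normalization over a dense NCHW buffer.
// Hypothetical helper for illustration only, not a DNNL function.
std::vector<float> bnorm_forward_ref(const std::vector<float> &src,
        std::size_t N, std::size_t C, std::size_t H, std::size_t W,
        const std::vector<float> &mean, const std::vector<float> &variance,
        const std::vector<float> &gamma, const std::vector<float> &beta,
        float eps) {
    assert(src.size() == N * C * H * W);
    assert(mean.size() == C && variance.size() == C);
    // Empty gamma means "no scale and shift" (gamma = 1, beta = 0).
    const bool use_scaleshift = !gamma.empty();
    assert(!use_scaleshift || (gamma.size() == C && beta.size() == C));

    std::vector<float> dst(src.size());
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            const float inv_std = 1.0f / std::sqrt(variance[c] + eps);
            const float g = use_scaleshift ? gamma[c] : 1.0f;
            const float b = use_scaleshift ? beta[c] : 0.0f;
            for (std::size_t hw = 0; hw < H * W; ++hw) {
                // In NCHW the H*W elements of one channel are contiguous.
                const std::size_t off = (n * C + c) * H * W + hw;
                dst[off] = g * (src[off] - mean[c]) * inv_std + b;
            }
        }
    }
    return dst;
}
```

The per-channel factor `1 / sqrt(variance + eps)` is hoisted out of the inner loop because it does not depend on `n`, `h`, or `w`.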

When mean and variance are computed at run-time, the following formulas are used:

- \(\mu(c) = \frac{1}{NHW} \sum\limits_{nhw} src(n, c, h, w)\),
- \(\sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} (src(n, c, h, w) - \mu(c))^2\).
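As a sketch of these two formulas, the per-channel statistics can be computed with a two-pass loop. The example assumes a dense NCHW `float` buffer; `ChannelStats` and `bnorm_stats_ref` are hypothetical names used only for illustration. Note the division by \(NHW\) in both cases, which yields the population variance.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-channel statistics; hypothetical type for illustration only.
struct ChannelStats {
    std::vector<float> mean;
    std::vector<float> variance;
};

// Computes population mean and variance per channel of a dense NCHW buffer.
ChannelStats bnorm_stats_ref(const std::vector<float> &src, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W) {
    assert(src.size() == N * C * H * W);
    assert(N * H * W > 0);
    const std::size_t SP = H * W;
    const double count = static_cast<double>(N * SP);

    ChannelStats stats;
    stats.mean.resize(C);
    stats.variance.resize(C);
    for (std::size_t c = 0; c < C; ++c) {
        // First pass: mean over the N, H, and W dimensions.
        double sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t sp = 0; sp < SP; ++sp)
                sum += src[(n * C + c) * SP + sp];
        const double mu = sum / count;

        // Second pass: mean of squared deviations (divide by NHW, not NHW - 1).
        double sq = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t sp = 0; sp < SP; ++sp) {
                const double d = src[(n * C + c) * SP + sp] - mu;
                sq += d * d;
            }
        stats.mean[c] = static_cast<float>(mu);
        stats.variance[c] = static_cast<float>(sq / count);
    }
    return stats;
}
```

Accumulating in `double` is a design choice of this sketch to limit rounding error; it says nothing about how the library accumulates internally.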

The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.

In training mode the primitive also optionally supports fusion with ReLU activation with zero negative slope applied to the result (see dnnl_fuse_norm_relu flag).

- Note
- The batch normalization primitive computes population mean and variance and not their sample or unbiased versions that are typically used to compute running mean and variance.
- Using the mean and variance computed by the batch normalization primitive, running mean and variance \(\hat\mu\) and \(\hat\sigma^2\) can be computed as
\[ \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\ \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2. \]
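The running-statistics update in the note can be sketched as follows. Both helpers are hypothetical and live outside the primitive. The conversion to an unbiased variance uses Bessel's correction (multiplying by \(\frac{NHW}{NHW - 1}\)), which is a standard statistical identity rather than something the primitive does for you.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Bessel's correction: converts a population variance over `count` samples
// into the unbiased estimate. For 2D spatial data, count = N * H * W.
inline float to_unbiased_variance(float population_variance,
        std::size_t count) {
    assert(count > 1);
    return population_variance * static_cast<float>(count)
            / static_cast<float>(count - 1);
}

// Exponential moving average, applied per channel:
// running := alpha * running + (1 - alpha) * batch.
void update_running(std::vector<float> &running,
        const std::vector<float> &batch, float alpha) {
    assert(running.size() == batch.size());
    for (std::size_t c = 0; c < running.size(); ++c)
        running[c] = alpha * running[c] + (1.0f - alpha) * batch[c];
}
```

A caller that wants unbiased running variance would apply `to_unbiased_variance` to each channel of the variance output before calling `update_running`; whether to do so depends on the framework convention being matched.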

- If mean and variance are computed at run-time (i.e., dnnl_use_global_stats is not set), they become outputs for the propagation kind dnnl_forward_training (since they would be required during the backward propagation) and are not exposed for the propagation kind dnnl_forward_inference.
- If batch normalization is created with ReLU fusion (i.e., dnnl_fuse_norm_relu is set), for the propagation kind dnnl_forward_training the primitive would produce a `workspace` memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind dnnl_forward_inference, the workspace is not produced. Behavior would be the same as creating a batch normalization primitive with ReLU as a post-op (see section below).
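To illustrate why an extra output helps the backward pass, the sketch below applies ReLU with zero negative slope to already-normalized values and records which elements stayed positive. The contents and layout of the real `workspace` are not described here, so the one-byte-per-element mask is purely an assumption, as are the names `FusedReluResult`, `relu_with_mask`, and `relu_backward_with_mask`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: the activated values plus a mask that stands in
// for the `workspace` (the real workspace format is an implementation detail).
struct FusedReluResult {
    std::vector<float> dst;
    std::vector<std::uint8_t> mask;
};

// Forward: ReLU with zero negative slope applied to normalized values.
FusedReluResult relu_with_mask(const std::vector<float> &normalized) {
    FusedReluResult r;
    r.dst.resize(normalized.size());
    r.mask.resize(normalized.size());
    for (std::size_t i = 0; i < normalized.size(); ++i) {
        const bool keep = normalized[i] > 0.0f;
        r.mask[i] = keep ? 1 : 0;
        r.dst[i] = keep ? normalized[i] : 0.0f;
    }
    return r;
}

// Backward: gradients flow only through elements that were kept.
std::vector<float> relu_backward_with_mask(const std::vector<float> &diff_dst,
        const std::vector<std::uint8_t> &mask) {
    assert(diff_dst.size() == mask.size());
    std::vector<float> diff(diff_dst.size());
    for (std::size_t i = 0; i < diff_dst.size(); ++i)
        diff[i] = mask[i] ? diff_dst[i] : 0.0f;
    return diff;
}
```

With a mask like this, the backward pass can zero the clipped gradients without recomputing the normalized values from `src`.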

The backward propagation computes \(diff\_src(n, c, h, w)\), \(diff\_\gamma(c)^*\), and \(diff\_\beta(c)^*\) based on \(diff\_dst(n, c, h, w)\), \(src(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c) ^*\), and \(\beta(c) ^*\).

The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., dnnl_use_scaleshift is set).
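For reference, the standard batch normalization gradients can be written in the notation above. This is the textbook derivation for the case where \(\mu(c)\) and \(\sigma^2(c)\) are computed from \(src\) at run-time; it ignores ReLU fusion and is not taken from the library's implementation. The shorthand \(\hat{x}\) is introduced here only for illustration, with \(\hat{x}(n, c, h, w) = \frac{src(n, c, h, w) - \mu(c)}{\sqrt{\sigma^2(c) + \varepsilon}}\):

\[ diff\_\beta(c) = \sum\limits_{nhw} diff\_dst(n, c, h, w), \\ diff\_\gamma(c) = \sum\limits_{nhw} diff\_dst(n, c, h, w) \cdot \hat{x}(n, c, h, w), \]

\[ diff\_src(n, c, h, w) = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \varepsilon}} \left( diff\_dst(n, c, h, w) - \frac{diff\_\beta(c)}{NHW} - \hat{x}(n, c, h, w) \cdot \frac{diff\_\gamma(c)}{NHW} \right). \]

If the mean and variance are provided by a user instead (dnnl_use_global_stats is set), they do not depend on \(src\), so the two correction terms in the parentheses vanish and \(diff\_src(n, c, h, w) = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \varepsilon}} \cdot diff\_dst(n, c, h, w)\). When dnnl_use_scaleshift is not set, \(\gamma(c) = 1\) and the two sums serve only as intermediate quantities.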

Depending on the flags and propagation kind, the batch normalization primitive requires different inputs and outputs.

- The different flavors of the primitive are partially controlled by the `flags` parameter that is passed to the operation descriptor initialization function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple flags can be set using the bitwise OR operator (`|`).
- For forward propagation, the mean and variance might be either computed at run-time (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the dnnl_use_global_stats flag. For the backward propagation, the mean and variance are always input parameters.
- The memory format and data type for `src` and `dst` are assumed to be the same, and in the API are typically referred to as `data` (e.g., see `data_desc` in dnnl::batch_normalization_forward::desc::desc()). The same holds for `diff_src` and `diff_dst`. The corresponding memory descriptors are referred to as `diff_data_desc`.
- Both forward and backward propagation support in-place operations, meaning that `src` can be used as input and output for forward propagation, and `diff_dst` can be used as input and output for backward propagation. In case of in-place operation, the original data will be overwritten.
- As mentioned above, the batch normalization primitive can be fused with ReLU activation even in the training mode. In this case, on the forward propagation the primitive has one additional output, `workspace`, that should be passed during the backward propagation.
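The rules stated so far for flags and propagation kinds can be summarized in a small sketch. The enum values, the `Prop` and `IoSummary` types, the function names, and the tensor labels (`scaleshift`, `diff_scaleshift`) are hypothetical stand-ins chosen for illustration; they are not the library's definitions, and the summary only encodes what the text above states.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the library flags; numeric values are
// illustrative. Flags combine with bitwise OR.
enum Flags : unsigned {
    flag_use_global_stats = 1u << 0,
    flag_use_scaleshift = 1u << 1,
    flag_fuse_norm_relu = 1u << 2
};

enum class Prop { forward_inference, forward_training, backward };

struct IoSummary {
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Lists the inputs and outputs implied by the notes above.
IoSummary bnorm_io(Prop prop, unsigned flags) {
    const bool global_stats = (flags & flag_use_global_stats) != 0;
    const bool scaleshift = (flags & flag_use_scaleshift) != 0;
    const bool fused_relu = (flags & flag_fuse_norm_relu) != 0;

    IoSummary io;
    if (prop == Prop::backward) {
        // Mean and variance are always inputs on backward propagation.
        io.inputs = {"diff_dst", "src", "mean", "variance"};
        io.outputs = {"diff_src"};
        if (scaleshift) {
            io.inputs.push_back("scaleshift");
            io.outputs.push_back("diff_scaleshift");
        }
        if (fused_relu) io.inputs.push_back("workspace");
        return io;
    }

    io.inputs = {"src"};
    io.outputs = {"dst"};
    if (scaleshift) io.inputs.push_back("scaleshift");
    if (global_stats) {
        // User-provided statistics are inputs.
        io.inputs.push_back("mean");
        io.inputs.push_back("variance");
    } else if (prop == Prop::forward_training) {
        // Computed statistics are exposed only for training.
        io.outputs.push_back("mean");
        io.outputs.push_back("variance");
    }
    if (fused_relu && prop == Prop::forward_training)
        io.outputs.push_back("workspace");
    return io;
}

// Small lookup helper used when inspecting a summary.
bool contains(const std::vector<std::string> &v, const std::string &name) {
    return std::find(v.begin(), v.end(), name) != v.end();
}
```

For example, `bnorm_io(Prop::forward_training, flag_use_scaleshift | flag_fuse_norm_relu)` lists `mean`, `variance`, and `workspace` among the outputs, while the same flags with `Prop::forward_inference` list only `dst`.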

The operation supports the following combinations of data types:

Propagation | Source / Destination | Mean / Variance / ScaleShift |
---|---|---|
forward / backward | f32, bf16 | f32 |
forward | f16 | f32 |
forward | s8 | f32 |

- Warning
- There might be hardware and/or implementation specific restrictions. Check the Implementation Limitations section below.

The mean ( \(\mu\)) and variance ( \(\sigma^2\)) are separate 1D tensors of size \(C\).

The format of the corresponding memory object must be dnnl_x (dnnl_a).

If used, the scale ( \(\gamma\)) and shift ( \(\beta\)) are combined in a single 2D tensor of shape \(2 \times C\).

The format of the corresponding memory object must be dnnl_nc (dnnl_ab).
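A minimal sketch of how such a combined buffer could be filled, assuming a dense row-major \(2 \times C\) layout in which the scale occupies the first row and the shift the second. The helper name `pack_scaleshift` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Packs gamma (row 0) and beta (row 1) into one dense row-major 2 x C buffer.
// Element (r, c) lives at offset r * C + c.
std::vector<float> pack_scaleshift(const std::vector<float> &gamma,
        const std::vector<float> &beta) {
    assert(gamma.size() == beta.size());
    std::vector<float> packed;
    packed.reserve(2 * gamma.size());
    packed.insert(packed.end(), gamma.begin(), gamma.end());
    packed.insert(packed.end(), beta.begin(), beta.end());
    return packed;
}
```

With \(C = 3\), `packed[0 * 3 + c]` holds \(\gamma(c)\) and `packed[1 * 3 + c]` holds \(\beta(c)\).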

Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.

The batch normalization primitive is optimized for the following memory formats:

Spatial | Logical tensor | Implementations optimized for memory formats |
---|---|---|
0D | NC | dnnl_nc (dnnl_ab) |
2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
3D | NCDHW | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |

Here *optimized^* means the format that comes out of any preceding compute-intensive primitive.
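To make the difference between the two plain 2D formats concrete, the sketch below computes the linear offset of the logical element \((n, c, h, w)\) in dense, unpadded buffers. The function names are hypothetical; the formulas follow directly from the dimension order that the format names spell out.

```cpp
#include <cassert>
#include <cstddef>

// Offset of logical element (n, c, h, w) in a dense nchw (abcd) buffer:
// w varies fastest, then h, then c, then n.
inline std::size_t offset_nchw(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    assert(c < C && h < H && w < W);
    return ((n * C + c) * H + h) * W + w;
}

// Offset of the same logical element in a dense nhwc (acdb) buffer:
// c varies fastest, then w, then h, then n.
inline std::size_t offset_nhwc(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    assert(c < C && h < H && w < W);
    return ((n * H + h) * W + w) * C + c;
}
```

In nchw all \(H \cdot W\) values of one channel are contiguous, whereas in nhwc consecutive values of one channel are \(C\) elements apart, which changes how the per-channel reductions traverse memory.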

Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation. The following post-ops are supported by batch normalization primitives:

Propagation | Type | Operation | Description |
---|---|---|---|
forward | post-op | eltwise | Applies an Eltwise operation to the result (currently only dnnl_eltwise_relu algorithm is supported) |

- Note
- As mentioned in Primitive Attributes, the post-ops should be used for inference only. For instance, using ReLU as a post-op would not produce the additional output `workspace` that is required to compute backward propagation correctly. Hence, for training one should use the dnnl_fuse_norm_relu flag directly.

- Refer to Data Types for limitations related to data types support.
- For the data types that have forward propagation support only, mean and variance must be provided by a user (i.e., dnnl_use_global_stats is set).

- For backward propagation, use the same memory format for `src`, `diff_dst`, and `diff_src` (the formats of `diff_dst` and `diff_src` are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.
- Use in-place operations whenever possible.