BatchNorm1d¶
The BatchNorm1d layer implements batch normalization, a technique designed to stabilize and accelerate the training of deep neural networks by reducing internal covariate shift. This is achieved by normalizing intermediate activations across the batch dimension and subsequently applying a learnable affine transformation.
Mathematical definition¶
Let $ \mathbf{X} \in \mathbb{R}^{N \times d} $ be an input tensor representing a batch of $N$ samples, where each sample has $d$ features. Batch normalization operates feature-wise, normalizing each feature independently across the batch.
During training, the batch-wise mean and variance are computed as
$$ \boldsymbol{\mu}_B = \frac{1}{N} \sum_{i=1}^{N} \mathbf{x}_i \;\in\; \mathbb{R}^{1 \times d}, $$
$$ \boldsymbol{\sigma}_B^2 = \frac{1}{N} \sum_{i=1}^{N} (\mathbf{x}_i - \boldsymbol{\mu}_B)^2 \;\in\; \mathbb{R}^{1 \times d}, $$
where $\mathbf{x}_i \in \mathbb{R}^{1 \times d}$ denotes the $i$-th sample in the batch.
Normalization step¶
Each input sample is normalized using the batch statistics:
$$ \widehat{\mathbf{X}} = \frac{\mathbf{X} - \boldsymbol{\mu}_B} {\sqrt{\boldsymbol{\sigma}_B^2 + \varepsilon}}, \quad \widehat{\mathbf{X}} \in \mathbb{R}^{N \times d}, $$
where $\varepsilon > 0$ is a small constant introduced for numerical stability.
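The normalization step above can be sketched directly in NumPy. This is an illustrative sketch of the math, not the sorix implementation; variable names mirror the symbols in the formulas.

```python
import numpy as np

# Training-mode normalization: per-feature statistics across the batch.
rng = np.random.default_rng(0)
N, d = 8, 3
X = rng.normal(size=(N, d))

eps = 1e-5
mu_B = X.mean(axis=0, keepdims=True)        # (1, d) batch mean
var_B = X.var(axis=0, keepdims=True)        # (1, d) biased batch variance
X_hat = (X - mu_B) / np.sqrt(var_B + eps)   # (N, d) normalized activations

# Each column of X_hat now has (approximately) zero mean and unit variance.
print(X_hat.mean(axis=0))
print(X_hat.var(axis=0))
```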
Learnable affine transformation¶
To preserve the representational capacity of the network, batch normalization introduces two learnable parameters:
- Scale parameter: $ \boldsymbol{\gamma} \in \mathbb{R}^{1 \times d} $
- Shift parameter: $ \boldsymbol{\beta} \in \mathbb{R}^{1 \times d} $
The final output of the layer is given by
$$ \mathbf{Y} = \boldsymbol{\gamma} \odot \widehat{\mathbf{X}} + \boldsymbol{\beta}, \quad \mathbf{Y} \in \mathbb{R}^{N \times d}, $$
where $\odot$ denotes element-wise multiplication, with $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ broadcast across the batch dimension so that each feature is scaled and shifted independently.
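The claim that the affine step preserves representational capacity can be made concrete: choosing $\boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_B^2 + \varepsilon}$ and $\boldsymbol{\beta} = \boldsymbol{\mu}_B$ makes the layer compute the identity. A minimal NumPy sketch (illustrative only, not sorix code):

```python
import numpy as np

# If gamma = sqrt(var + eps) and beta = mu, the affine step exactly
# undoes normalization, so the layer can always recover its input.
rng = np.random.default_rng(1)
X = rng.normal(size=(8, 3))

eps = 1e-5
mu = X.mean(axis=0, keepdims=True)
var = X.var(axis=0, keepdims=True)
X_hat = (X - mu) / np.sqrt(var + eps)

gamma = np.sqrt(var + eps)   # scale chosen to invert the normalization
beta = mu                    # shift chosen to restore the mean
Y = gamma * X_hat + beta     # element-wise, broadcast over the batch

print(np.allclose(Y, X))     # True: the layer can represent the identity
```

In practice $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are learned, so the network itself decides how much of the normalization to keep.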
Running statistics and inference mode¶
In addition to batch statistics, BatchNorm1d maintains running estimates of the mean and variance:
$$ \boldsymbol{\mu}_{\text{run}} \in \mathbb{R}^{1 \times d}, \quad \boldsymbol{\sigma}^2_{\text{run}} \in \mathbb{R}^{1 \times d}. $$
These statistics are updated during training using an exponential moving average:
$$ \boldsymbol{\mu}_{\text{run}} \leftarrow (1 - \alpha) \boldsymbol{\mu}_{\text{run}} + \alpha \boldsymbol{\mu}_B, $$
$$ \boldsymbol{\sigma}^2_{\text{run}} \leftarrow (1 - \alpha) \boldsymbol{\sigma}^2_{\text{run}} + \alpha \boldsymbol{\sigma}^2_B, $$
where $\alpha$ is the momentum parameter. During inference (evaluation mode), normalization is performed using these constant running statistics:
$$ \widehat{\mathbf{X}} = \frac{\mathbf{X} - \boldsymbol{\mu}_{\text{run}}} {\sqrt{\boldsymbol{\sigma}^2_{\text{run}} + \varepsilon}}. $$
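The exponential-moving-average update and the inference-mode normalization can be sketched as follows. This is a hedged illustration with made-up data, assuming the same momentum convention as the formulas above ($\alpha$ weights the current batch statistics); it is not the sorix implementation.

```python
import numpy as np

rng = np.random.default_rng(2)
eps, alpha = 1e-5, 0.1
running_mean = np.zeros((1, 3))   # typical initialization: zero mean
running_var = np.ones((1, 3))     # typical initialization: unit variance

# Simulate 200 training batches drawn from N(2.0, 0.25).
for _ in range(200):
    X = 2.0 + 0.5 * rng.normal(size=(32, 3))
    running_mean = (1 - alpha) * running_mean + alpha * X.mean(axis=0, keepdims=True)
    running_var = (1 - alpha) * running_var + alpha * X.var(axis=0, keepdims=True)

# Evaluation mode: normalize a new batch with the frozen running statistics.
X_new = 2.0 + 0.5 * rng.normal(size=(4, 3))
X_hat_eval = (X_new - running_mean) / np.sqrt(running_var + eps)

print(running_mean)   # converges toward the true mean, 2.0
print(running_var)    # converges toward the true variance, 0.25
```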
Functional view¶
The BatchNorm1d layer realizes the mapping
$$ \text{BatchNorm1d}:\; \mathbb{R}^{N \times d} \;\longrightarrow\; \mathbb{R}^{N \times d}, $$
where normalization and affine reparameterization are applied independently to each feature across the batch.
Backward Pass and Gradient Computation¶
Computing the gradient $\frac{\partial \mathcal{L}}{\partial \mathbf{X}}$ in BatchNorm is more involved than for element-wise activations because every $\mathbf{x}_i$ contributes to the batch statistics ($\boldsymbol{\mu}_B$ and $\boldsymbol{\sigma}_B^2$), which in turn affect all normalized outputs $\widehat{\mathbf{X}}$ in the batch.
Applying the multivariate chain rule, the gradient with respect to a single sample $\mathbf{x}_i$ is the sum of three computational paths:
$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}_i} = \underbrace{\frac{\partial \mathcal{L}}{\partial \widehat{\mathbf{x}}_i} \frac{\partial \widehat{\mathbf{x}}_i}{\partial \mathbf{x}_i}}_{\text{Direct path}} + \underbrace{\left( \sum_{j=1}^N \frac{\partial \mathcal{L}}{\partial \widehat{\mathbf{x}}_j} \frac{\partial \widehat{\mathbf{x}}_j}{\partial \boldsymbol{\mu}_B} \right) \frac{\partial \boldsymbol{\mu}_B}{\partial \mathbf{x}_i}}_{\text{Path via Mean}} + \underbrace{\left( \sum_{j=1}^N \frac{\partial \mathcal{L}}{\partial \widehat{\mathbf{x}}_j} \frac{\partial \widehat{\mathbf{x}}_j}{\partial \boldsymbol{\sigma}_B^2} \right) \frac{\partial \boldsymbol{\sigma}_B^2}{\partial \mathbf{x}_i}}_{\text{Path via Variance}}$$
Step-by-Step Components:¶
- Derivative w.r.t. the normalized activation: from $\mathbf{y}_i = \boldsymbol{\gamma} \odot \widehat{\mathbf{x}}_i + \boldsymbol{\beta}$, we have $\nabla_{\widehat{\mathbf{x}}_i} \mathcal{L} = \nabla_{\mathbf{y}_i} \mathcal{L} \odot \boldsymbol{\gamma}$.
- The Variance path: Since $\boldsymbol{\sigma}_B^2 = \frac{1}{N} \sum (\mathbf{x}_j - \boldsymbol{\mu}_B)^2$, its contribution accounts for how the spread of the batch affects the scaling of every sample.
- The Mean path: Since $\boldsymbol{\mu}_B = \frac{1}{N} \sum \mathbf{x}_j$, its contribution accounts for how shifting the batch center affects every centered sample $(\mathbf{x}_j - \boldsymbol{\mu}_B)$.
By substituting the partial derivatives ($\frac{\partial \boldsymbol{\mu}_B}{\partial \mathbf{x}_i} = \frac{1}{N}$ and $\frac{\partial \boldsymbol{\sigma}_B^2}{\partial \mathbf{x}_i} = \frac{2(\mathbf{x}_i - \boldsymbol{\mu}_B)}{N}$), we obtain the consolidated analytical result used in the backpropagation engine:
$$\frac{\partial \mathcal{L}}{\partial \mathbf{x}_i} = \frac{\boldsymbol{\gamma}}{N \sqrt{\boldsymbol{\sigma}_B^2 + \varepsilon}} \left[ N \nabla_{\mathbf{y}_i} \mathcal{L} - \sum_{j=1}^N \nabla_{\mathbf{y}_j} \mathcal{L} - \widehat{\mathbf{x}}_i \sum_{j=1}^N \left( \nabla_{\mathbf{y}_j} \mathcal{L} \odot \widehat{\mathbf{x}}_j \right) \right]$$
This formulation captures the cross-sample coupling ("batch talk") that normalization inherently introduces: the gradient for each sample depends on the upstream gradients of every other sample in the batch.
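The consolidated formula can be verified numerically against central finite differences. The sketch below uses a toy loss $\mathcal{L} = \sum (\mathbf{Y} \odot \mathbf{W})$ with a fixed weight matrix $\mathbf{W}$ (so that $\nabla_{\mathbf{Y}} \mathcal{L} = \mathbf{W}$); names and setup are illustrative, not sorix internals.

```python
import numpy as np

rng = np.random.default_rng(3)
N, d, eps = 6, 4, 1e-5
X = rng.normal(size=(N, d))
gamma = rng.normal(size=(1, d))
beta = rng.normal(size=(1, d))
W = rng.normal(size=(N, d))          # fixed weights: dL/dY = W

def forward(X):
    mu = X.mean(axis=0, keepdims=True)
    var = X.var(axis=0, keepdims=True)
    X_hat = (X - mu) / np.sqrt(var + eps)
    return gamma * X_hat + beta, X_hat, var

Y, X_hat, var = forward(X)
dY = W

# Consolidated analytical gradient from the formula above.
dX = (gamma / (N * np.sqrt(var + eps))) * (
    N * dY - dY.sum(axis=0, keepdims=True)
    - X_hat * (dY * X_hat).sum(axis=0, keepdims=True)
)

# Central finite differences on L(X) = sum(forward(X)[0] * W).
dX_num = np.zeros_like(X)
h = 1e-6
for i in range(N):
    for j in range(d):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += h
        Xm[i, j] -= h
        dX_num[i, j] = ((forward(Xp)[0] * W).sum()
                        - (forward(Xm)[0] * W).sum()) / (2 * h)

print(np.abs(dX - dX_num).max())   # agreement up to finite-difference error
```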
Multi-device support¶
BatchNorm1d is device-aware and supports execution on CPU and GPU backends. Learnable parameters are stored as tensors on the selected device, while running statistics are maintained as NumPy or CuPy arrays.
Parameter interface¶
The trainable parameters of the layer are exposed through the parameters() method, which returns
$$ \{\boldsymbol{\gamma}, \boldsymbol{\beta}\}. $$
# Uncomment to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
from sorix import tensor
from sorix.nn import BatchNorm1d
import numpy as np
samples, features = 8, 3
X = tensor(np.random.randn(samples, features))
X
tensor([[-0.24933387, -0.28525337, -0.60415811],
[ 0.38161554, 0.59052643, 0.31669436],
[ 0.13659471, 1.19939234, -1.1958867 ],
[ 1.43433487, -1.04299107, -0.53035623],
[ 0.97062172, 0.67330625, 1.39688185],
[-1.0496965 , -0.85192622, -1.94755154],
[ 1.61150794, -0.43341638, -0.49170012],
[-0.1007724 , -0.03571685, -0.4692231 ]], dtype=sorix.float64)
bn = BatchNorm1d(features)
print(bn.gamma)
print(bn.beta)
tensor([[1., 1., 1.]], requires_grad=True)
tensor([[0., 0., 0.]], requires_grad=True)
Y = bn(X)
Y
tensor([[-0.75918656, -0.35650902, -0.17695236],
[-0.01212848, 0.83521286, 0.81969192],
[-0.30223857, 1.66373029, -0.81738385],
[ 1.23431451, -1.38760451, -0.09707613],
[ 0.68526788, 0.94785592, 1.98878547],
[-1.70683363, -1.12761203, -1.63091531],
[ 1.44409135, -0.5581226 , -0.05523838],
[-0.58328649, -0.01695091, -0.03091136]], dtype=sorix.float64, requires_grad=True)
print(bn.running_mean)
print(bn.running_var)
tensor([[ 0.0391859 , -0.00232599, -0.04406624]], dtype=sorix.float64)
tensor([[0.98152036, 0.96171969, 0.99756331]], dtype=sorix.float64)
bn.training = False
Y_eval = bn(X)
Y_eval
tensor([[-0.29122168, -0.28850176, -0.56077269],
[ 0.34563641, 0.6045331 , 0.36119913],
[ 0.09832102, 1.22539519, -1.15322056],
[ 1.40821418, -1.0611688 , -0.4868811 ],
[ 0.94015848, 0.68894388, 1.44269965],
[-1.09907951, -0.86633949, -1.90579909],
[ 1.58704643, -0.43958395, -0.448178 ],
[-0.14126898, -0.03404874, -0.42567365]], dtype=sorix.float64, requires_grad=True)