Normalization in Transformers
Explore the differences between batch normalization and layer normalization in transformer architectures. Understand how normalization stabilizes training, why layer norm suits variable batch sizes and autoregressive generation, and the impact of PreNorm versus PostNorm placements. Build a deeper understanding of normalization's role in large-scale transformer training.
We'll cover the following...
- What is normalization, and why do neural networks need it?
- How does batch normalization work, and why does it struggle with transformers?
- How does layer normalization work?
- Why do transformers use layer normalization instead of batch normalization?
- What is the difference between PreNorm and PostNorm placement?
- Conclusion
In roles that involve tinkering with LLMs and their architectures, interviewers often ask why transformers use layer normalization instead of batch normalization. The question assesses your understanding of training dynamics and architectural design, and it reveals whether you can reason about model behavior under real-world constraints, such as distributed training or autoregressive decoding. Beyond naming models that use layer norm, strong candidates explain how batch and layer norm differ, why batch norm’s assumptions break in transformer settings, and how normalization choices interact with sequence modeling and parallel computation.
What is normalization, and why do neural networks need it?
In machine learning, normalization broadly refers to rescaling data so that its values fall within a common range or distribution. For inputs, normalization may map features into [0, 1] or standardize them to have a mean of 0 and a variance of 1, helping gradient descent converge faster. Inside neural networks, normalization layers (like batch normalization and layer normalization) perform a similar standardization on the layer’s activations: they typically subtract the mean and divide by the standard deviation of certain activations. The goal is to stabilize training by preventing internal covariate shift, i.e., large shifts in a layer’s input distribution from one batch to the next, which can slow training and cause gradients to explode or vanish.
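In symbols, both kinds of normalization layers apply the same standardization and differ only in which axis the statistics are computed over. A small constant ε keeps the division numerically stable, and learnable parameters γ and β let the network rescale and shift the result:

$$
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta
$$

Here, μ and σ² are the mean and variance over the chosen normalization axis: across the batch for batch normalization, and across the features of each sample for layer normalization.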
Normalization comes in many flavors (batch, layer, instance, group norm, etc.), but we focus on batch vs. layer for Transformers. Batch normalization (BN) computes the mean and variance of each feature across the batch. Layer normalization (LN) computes the mean and variance across the features of a single example. A key difference is that BN’s behavior depends on having a meaningful batch of data, whereas LN works on each sample independently. This difference will be crucial when we look at sequence models.
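As a quick sanity check of which axis each layer normalizes over, here is a minimal PyTorch sketch; the batch size, feature dimension, and random data are purely illustrative:

```python
import torch
import torch.nn as nn

# A toy batch of 4 samples, each with 8 features (e.g., token embeddings).
x = torch.randn(4, 8)

# BatchNorm1d normalizes each of the 8 features across the 4 samples in the batch.
bn = nn.BatchNorm1d(num_features=8)
# LayerNorm normalizes the 8 features of each sample independently.
ln = nn.LayerNorm(normalized_shape=8)

x_bn = bn(x)  # statistics computed over dim 0 (the batch dimension)
x_ln = ln(x)  # statistics computed over dim 1 (the feature dimension)

print(x_bn.mean(dim=0))  # per-feature means are ~0 after batch norm
print(x_ln.mean(dim=1))  # per-sample means are ~0 after layer norm
```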
Educative byte: The term “internal covariate shift” was introduced in the original BatchNorm paper (Ioffe & Szegedy, 2015) to describe how the distribution of inputs to each layer changes during training as parameters update. While later research questioned whether this is the primary mechanism behind BN’s success, the intuition remains useful: normalization layers keep activations in a well-behaved range, preventing the “activation explosion” that can destabilize deep networks.
To clarify this distinction, consider a simple example. Imagine two input vectors (e.g., token embeddings):
x1 = [2.0, 4.0, 6.0]
x2 = [1.0, 3.0, 5.0]
With BN, the mean and variance are computed per feature across the batch:
Feature-wise mean = [(2+1)/2, (4+3)/2, (6+5)/2] = [1.5, 3.5, 5.5]
Feature-wise std = [0.5, 0.5, 0.5]
x1 → x1' = [1, 1, 1]
x2 → x2' = [−1, −1, −1]
With LN, each vector (row) is normalized independently:
x1 → mean = 4.0, std = 1.63 → x1' = [−1.225, 0, 1.225]
x2 → mean = 3.0, std = 1.63 → x2' = [−1.225, 0, 1.225]
So BN normalizes each feature by comparing across tokens, while LN normalizes across features within a single token.
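To make the arithmetic above concrete, here is a small NumPy sketch that reproduces both results (it omits the ε term and the learnable γ and β that real normalization layers also apply):

```python
import numpy as np

# The two token embeddings from the example, stacked as a batch of 2.
X = np.array([[2.0, 4.0, 6.0],
              [1.0, 3.0, 5.0]])

# Batch norm: statistics per feature (column), computed across the batch (axis=0).
bn_mean = X.mean(axis=0)               # [1.5, 3.5, 5.5]
bn_std = X.std(axis=0)                 # [0.5, 0.5, 0.5]
X_bn = (X - bn_mean) / bn_std          # [[ 1.,  1.,  1.], [-1., -1., -1.]]

# Layer norm: statistics per sample (row), computed across the features (axis=1).
ln_mean = X.mean(axis=1, keepdims=True)  # [[4.0], [3.0]]
ln_std = X.std(axis=1, keepdims=True)    # [[1.633], [1.633]]
X_ln = (X - ln_mean) / ln_std            # [[-1.225, 0., 1.225], [-1.225, 0., 1.225]]

print(X_bn)
print(X_ln)
```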
Educative byte: Think of BN as curving each exam so that the class average on that exam is zero; a student's score is adjusted relative to the rest of the class. In contrast, LN lets each student (token) normalize their own scores across all exams, achieving a zero mean and unit variance independently of everyone else.
This distinction becomes crucial for Transformers, where we often process one token at a time or use variable-length sequences. In the next few sections, we’ll explore why LN fits these constraints better.
Quick answer for interview: Normalization in neural networks standardizes activations by subtracting the mean and dividing by the standard deviation. This stabilizes training by preventing internal covariate shift—large swings in layer input distributions that can cause exploding or vanishing gradients. The key difference between batch and layer normalization is what they normalize over: BN computes statistics across the batch dimension for each feature, while LN computes statistics across features for each sample independently. This distinction is critical for Transformers because LN works with any batch size, including batch size 1 during autoregressive generation.
How does batch normalization work, and why does it struggle with transformers?
Batch normalization (BN) ...