LayerNorm Scaling implementation to mitigate the Curse of Depth in LLMs.
import math

import torch
import torch.nn as nn


class LayerNormScaling(nn.Module):
    """
    LayerNorm Scaling implementation to mitigate the Curse of Depth in LLMs.

    This module applies Layer Normalization and then scales the output by
    1/sqrt(layer_depth) to prevent the variance explosion issue in deeper
    transformer layers.

    Args:
        hidden_size (int): The size of the input and output features.
        layer_idx (int): The index of the current layer, starting from 1.
        eps (float, optional): A small value added to the denominator for
            numerical stability. Default: 1e-6.
    """

    def __init__(self, hidden_size, layer_idx, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
        self.layer_idx = layer_idx
        # Scale factor based on layer depth
        self.scale_factor = 1.0 / math.sqrt(max(1, layer_idx))

    def forward(self, hidden_states):
        """
        Apply layer normalization with depth-based scaling.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape
                [batch_size, seq_length, hidden_size]

        Returns:
            torch.Tensor: Normalized and scaled tensor of the same shape
        """
        # Standard Layer Normalization
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        normalized = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
        # Apply weight and bias
        output = self.weight * normalized + self.bias
        # Apply scaling based on layer depth to mitigate CoD
        output = output * self.scale_factor
        return output


class RMSNormScaling(nn.Module):
    """
    RMSNorm Scaling implementation (as used in LLaMA) to mitigate the Curse of Depth in LLMs.

    This module applies RMS Normalization and then scales the output by
    1/sqrt(layer_depth) to prevent the variance explosion issue in deeper
    transformer layers.

    Args:
        hidden_size (int): The size of the input and output features.
        layer_idx (int): The index of the current layer, starting from 1.
        eps (float, optional): A small value added to the denominator for
            numerical stability. Default: 1e-6.
    """

    def __init__(self, hidden_size, layer_idx, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.layer_idx = layer_idx
        # Scale factor based on layer depth
        self.scale_factor = 1.0 / math.sqrt(max(1, layer_idx))

    def forward(self, hidden_states):
        """
        Apply RMS normalization with depth-based scaling.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape
                [batch_size, seq_length, hidden_size]

        Returns:
            torch.Tensor: Normalized and scaled tensor of the same shape
        """
        # RMS Normalization (as used in LLaMA)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Apply weight
        output = self.weight * hidden_states
        # Apply scaling based on layer depth to mitigate CoD
        output = output * self.scale_factor
        return output


# Example of how to use these in a transformer layer
class TransformerLayerWithScaling(nn.Module):
    """
    Example transformer layer using RMSNormScaling to mitigate the Curse of Depth.

    This is a simple implementation showing how to use the scaling in a
    Pre-LN transformer architecture.
    """

    def __init__(self, hidden_size, num_heads, ff_dim, layer_idx, dropout=0.1):
        super().__init__()
        self.layer_idx = layer_idx
        # Pre-LN with Scaling for Attention
        self.ln1 = RMSNormScaling(hidden_size, layer_idx)
        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        # Pre-LN with Scaling for Feed-forward
        self.ln2 = RMSNormScaling(hidden_size, layer_idx)
        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, ff_dim),
            nn.SiLU(),  # SiLU/Swish activation as used in LLaMA
            nn.Linear(ff_dim, hidden_size),
            nn.Dropout(dropout),
        )

    def forward(self, x, attention_mask=None):
        # Apply Pre-LN with scaling for attention
        ln_out = self.ln1(x)
        attn_out, _ = self.attention(ln_out, ln_out, ln_out, key_padding_mask=attention_mask)
        x = x + attn_out
        # Apply Pre-LN with scaling for feed-forward
        ln_out = self.ln2(x)
        ff_out = self.ff(ln_out)
        x = x + ff_out
        return x


# Example of how to apply to an existing architecture like LLaMA
def apply_cod_mitigation_to_llama(model):
    """
    Apply the Curse of Depth mitigation to an existing LLaMA model.

    This function demonstrates how to retrofit existing models.

    Args:
        model: A LLaMA model instance

    Returns:
        The modified model with LayerNorm Scaling
    """
    # Get the transformer layers
    layers = model.model.layers
    # Replace the layer norms with scaled versions
    for i, layer in enumerate(layers):
        # Layer index starts from 1 in the paper
        layer_idx = i + 1
        # Convert input LayerNorm, preserving the trained norm weights
        if hasattr(layer, "input_layernorm"):
            src = layer.input_layernorm
            hidden_size = src.weight.shape[0]
            eps = src.variance_epsilon
            new_norm = RMSNormScaling(hidden_size, layer_idx, eps).to(
                device=src.weight.device, dtype=src.weight.dtype
            )
            new_norm.weight.data.copy_(src.weight.data)
            layer.input_layernorm = new_norm
        # Convert post-attention LayerNorm, preserving the trained norm weights
        if hasattr(layer, "post_attention_layernorm"):
            src = layer.post_attention_layernorm
            hidden_size = src.weight.shape[0]
            eps = src.variance_epsilon
            new_norm = RMSNormScaling(hidden_size, layer_idx, eps).to(
                device=src.weight.device, dtype=src.weight.dtype
            )
            new_norm.weight.data.copy_(src.weight.data)
            layer.post_attention_layernorm = new_norm
    return model

The Curse of Depth in Large Language Models: Overview and Implementation

Paper Overview

The paper "The Curse of Depth in Large Language Models" introduces an important concept that affects the efficiency and performance of modern language models. The authors identify a phenomenon where nearly half of the layers in LLMs are less effective than expected, calling this the "Curse of Depth" (CoD).

Key Findings

  1. The Problem: Deeper layers in LLMs like Llama, Mistral, DeepSeek, and Qwen contribute significantly less to the final output compared to earlier layers. This creates inefficiency, as training these models requires substantial computational resources.

  2. Root Cause: The authors identify Pre-Layer Normalization (Pre-LN) as the culprit. While Pre-LN stabilizes training, it causes the output variance to grow exponentially with model depth, so deeper transformer blocks act almost like identity mappings and barely transform their inputs in meaningful ways.

  3. The Solution: LayerNorm Scaling, which scales the output of each layer normalization by the inverse square root of its depth (1/√layer_depth). This simple modification keeps the variance growth under control; a short sketch of the rule follows this list.

  4. Results: Experiments across models from 130M to 1B parameters show that LayerNorm Scaling significantly improves pre-training performance compared to Pre-LN and carries these benefits through to supervised fine-tuning.
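A minimal sketch of the scaling rule referenced in item 3 above (tensor sizes are arbitrary, and F.layer_norm stands in for the module defined in the gist):

import math
import torch
import torch.nn.functional as F

batch, seq_len, hidden = 2, 16, 64
x = torch.randn(batch, seq_len, hidden)

for layer_idx in (1, 4, 16, 64):
    normed = F.layer_norm(x, (hidden,))     # standard Pre-LN normalization
    scaled = normed / math.sqrt(layer_idx)  # LayerNorm Scaling
    print(layer_idx, scaled.std().item())   # std shrinks roughly as 1/sqrt(layer_idx)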

Implementation

The gist provides a clean, standalone PyTorch implementation of LayerNorm Scaling that can be used as a drop-in replacement for standard LayerNorm.

How to Use the Implementation

The code provides two main classes:

  1. LayerNormScaling: A drop-in replacement for standard LayerNorm with depth scaling (a quick check of the drop-in claim follows this list)
  2. RMSNormScaling: A version for LLaMA-like models that use RMSNorm instead of LayerNorm
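As a quick check of the drop-in claim in item 1 (a minimal sketch; the import assumes the gist file is saved as layernorm_scaling.py): with layer_idx=1 the scale factor is 1/sqrt(1) = 1, so the module reproduces standard LayerNorm.

import torch
import torch.nn as nn
from layernorm_scaling import LayerNormScaling

hidden = 64
x = torch.randn(2, 8, hidden)

ln = nn.LayerNorm(hidden, eps=1e-6)
ln_scaled = LayerNormScaling(hidden, layer_idx=1, eps=1e-6)

# Both modules start from unit weight and zero bias, so the outputs match.
print(torch.allclose(ln(x), ln_scaled(x), atol=1e-5))  # True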

The implementation can be used in three ways:

1. For a New Model

When designing a new model, you can directly use the LayerNormScaling or RMSNormScaling classes:

from layernorm_scaling import RMSNormScaling

# In your transformer layer
self.input_layernorm = RMSNormScaling(hidden_size, layer_idx=current_layer_index)
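When stacking blocks, each block's norm receives its own 1-based index, so the first block is scaled by 1/sqrt(1) and the last by 1/sqrt(num_layers). A minimal sketch (num_layers and hidden_size are illustrative; the import again assumes the file name layernorm_scaling.py):

import torch.nn as nn
from layernorm_scaling import RMSNormScaling

num_layers, hidden_size = 12, 768

# One scaled norm per block, indexed from 1 as in the paper
input_norms = nn.ModuleList(
    [RMSNormScaling(hidden_size, layer_idx=i + 1) for i in range(num_layers)]
)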

2. To Retrofit an Existing Model

The apply_cod_mitigation_to_llama function shows how to convert an existing LLaMA model to use the scaling technique:

from transformers import LlamaForCausalLM
from layernorm_scaling import apply_cod_mitigation_to_llama

# Load your model
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Apply the CoD mitigation
model = apply_cod_mitigation_to_llama(model)
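A quick sanity check after the conversion (assuming the standard Hugging Face LlamaForCausalLM layout with model.model.layers) confirms that every decoder layer now carries the scaled norms with the expected depth index:

from layernorm_scaling import RMSNormScaling

for i, layer in enumerate(model.model.layers, start=1):
    assert isinstance(layer.input_layernorm, RMSNormScaling)
    assert isinstance(layer.post_attention_layernorm, RMSNormScaling)
    assert layer.input_layernorm.layer_idx == i

Note that swapping in the new scaling changes the activations of an already trained checkpoint, so some continued pre-training or fine-tuning is generally needed before evaluating a retrofitted model.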

3. As a Reference for Custom Implementations

The TransformerLayerWithScaling class demonstrates how to integrate the scaling into a transformer layer architecture.
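For instance, a small Pre-LN stack can be assembled from it as follows (a minimal sketch with arbitrary sizes; the import assumes the file name layernorm_scaling.py):

import torch
import torch.nn as nn
from layernorm_scaling import TransformerLayerWithScaling

hidden_size, num_heads, ff_dim, num_layers = 256, 4, 1024, 8

layers = nn.ModuleList(
    [
        TransformerLayerWithScaling(hidden_size, num_heads, ff_dim, layer_idx=i + 1)
        for i in range(num_layers)
    ]
)

x = torch.randn(2, 32, hidden_size)  # [batch_size, seq_length, hidden_size]
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([2, 32, 256])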

Benefits of LayerNorm Scaling

  1. Improved Performance: The paper shows consistent improvements in perplexity across model sizes
  2. Resource Efficiency: Makes better use of all layers, improving training efficiency
  3. Simple Implementation: Requires minimal code changes and no additional parameters
  4. Compatible with Existing Models: Can be retrofitted to already trained models

This implementation follows the paper's approach while providing flexibility for different use cases and model architectures.

Citation

Based on the original work:

@article{sun2025curse,
  title={The Curse of Depth in Large Language Models},
  author={Sun, Wenfang and Song, Xinyuan and Li, Pengxiang and Yin, Lu and Zheng, Yefeng and Liu, Shiwei},
  journal={arXiv preprint arXiv:2502.05795},
  year={2025}
}