@garrett361
Last active May 12, 2026 13:25
Model init

Summarizing what I've found about standard model initialization practices in public code bases and papers.

Code Bases

torchtitan (commit 8599818f)

Shared depth-scaling utility (param_init.py): depth_scaled_std(base_std, layer_id) = base_std / sqrt(2 * (layer_id + 1)) (0-indexed layers).

Two distinct init philosophies coexist:

GPT-OSS (gpt_oss/__init__.py) — all weights depth-scaled:

  • Embeddings: normal_(std=0.02)
  • Q, K, V, and output (wo) projections: depth-scaled trunc_normal_(std=depth_scaled_std(0.02, layer_id))
  • All FFN/MoE weights (mlp1, mlp2) and router gate: depth-scaled
  • LM head: trunc_normal_(std=dim^{-0.5}, a=−3·std, b=3·std)

Llama3 / Llama4 / DeepSeekV3 / Qwen3 — output projections only depth-scaled:

  • Embeddings: normal_(std=1.0), or skip_param_init for weight-tied variants (e.g. Llama3 1B/3B)
  • Q/K/V projections and FFN gate (w1): flat trunc_normal_(std=0.02)
  • Attention output (wo), FFN down + up projections (w2, w3): depth-scaled
  • MoE experts: same split — w1 flat, w2/w3 depth-scaled; router gate depth-scaled
  • LM head: trunc_normal_(std=dim^{-0.5}, a=−3·std, b=3·std)

RMSNorm weights: ones_ across all models.

Previously (pre-adbaa061): Llama3TransformerBlock.Config carried a depth_init: bool = True flag that selected between two scaling modes:

  • True (default): per-layer normal_(std=0.02 / sqrt(2 * (layer_id + 1)))
  • False: total-depth normal_(std=0.02 / sqrt(2 * n_layers)) — same std for every layer, scaled by total depth

The param_init refactor (adbaa061) ported only the True branch into depth_scaled_std and dropped the total-depth option with no replacement.
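A minimal stdlib-only sketch of the two modes, for seeing how they differ. `per_layer_std` follows the `depth_scaled_std` formula quoted above; `total_depth_std` is a hypothetical helper standing in for the removed `depth_init=False` branch. Neither is torchtitan code.

```python
import math


def per_layer_std(base_std: float, layer_id: int) -> float:
    """Per-layer depth scaling with a 0-indexed layer_id (old depth_init=True)."""
    return base_std / math.sqrt(2 * (layer_id + 1))


def total_depth_std(base_std: float, n_layers: int) -> float:
    """Total-depth scaling: one std shared by every layer (old depth_init=False)."""
    return base_std / math.sqrt(2 * n_layers)


n_layers = 8
for layer_id in range(n_layers):
    print(f"layer {layer_id}: per-layer={per_layer_std(0.02, layer_id):.5f} "
          f"total-depth={total_depth_std(0.02, n_layers):.5f}")
```

The two modes coincide only at the last layer (layer_id = n_layers − 1); every earlier layer gets a wider init under per-layer scaling.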

OLMo (commit 090253da)

Three implemented strategies via InitFnType (olmo/config.py:203). All schemes use init_normal() which dispatches to normal_ or trunc_normal_ based on the optional init_cutoff_factor config field. Biases always zeroed.
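The normal-vs-truncated dispatch can be sketched in plain Python. This `init_normal` is a hypothetical stand-in (it shares OLMo's function name but not its signature). Truncation is done by rejection sampling, which gives the same distribution as `torch.nn.init.trunc_normal_` with `a=-c·std`, `b=c·std`. The helper `truncated_std_factor` computes the closed-form ratio of realized to nominal std, a reminder that a truncated init is slightly narrower than its `std` argument.

```python
import math
import random
from typing import List, Optional


def init_normal(n: int, std: float, cutoff_factor: Optional[float] = None,
                seed: int = 0) -> List[float]:
    """Sample n weights from N(0, std^2), truncated to +/- cutoff_factor * std if set."""
    rng = random.Random(seed)
    if cutoff_factor is None:
        return [rng.gauss(0.0, std) for _ in range(n)]
    bound = cutoff_factor * std
    out: List[float] = []
    while len(out) < n:  # rejection sampling: keep only draws inside the bounds
        x = rng.gauss(0.0, std)
        if -bound <= x <= bound:
            out.append(x)
    return out


def truncated_std_factor(c: float) -> float:
    """Realized std / nominal std for a zero-mean normal truncated to [-c*std, c*std]."""
    pdf_c = math.exp(-c * c / 2) / math.sqrt(2 * math.pi)
    mass = math.erf(c / math.sqrt(2))  # P(|Z| <= c)
    return math.sqrt(1 - 2 * c * pdf_c / mass)


print(truncated_std_factor(3.0), truncated_std_factor(2.0))
```

At a 3σ cutoff the realized std is about 1.3% below nominal; at a 2σ cutoff it is about 12% below.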

normal (default):

  • All weights: normal_(std=init_std), default init_std=0.02; embeddings override with emb_init_std if set

mitchell (attributed to Mitchell Wortsman):

  • Embeddings + LM head: trunc_normal_(std=d_model^{-0.5})
  • Input projections (Q/K/V, FFN gate): trunc_normal_(std=d_model^{-0.5}) — width-scaled, no depth
  • Output projections (attn_out, ff_out): trunc_normal_(std=(2 · fan_in · (layer_id+1))^{-0.5}) — width- and depth-scaled

full_megatron (metaseq "full megatron init", used for Llama 2):

  • Embeddings: trunc_normal_(std=init_std), optionally × sqrt(d_model) if scale_emb_init=True; override with emb_init_std
  • Input projections (Q/K/V, FFN gate): flat trunc_normal_(std=init_std)
  • Output projections (attn_out, ff_out): trunc_normal_(std=init_std / sqrt(2 · n_layers)) — total-depth scaling
  • LM head: trunc_normal_(std=d_model^{-0.5})

LayerNorm weights: ones_ across all schemes.
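For a quick side-by-side comparison, the per-role stds of the three schemes can be computed from the formulas above. `olmo_stds` and its argument names are hypothetical. It ignores the `emb_init_std` / `scale_emb_init` overrides and the normal-vs-truncated distinction. `out_fan_in` must be supplied because `mitchell` uses the output projection's own fan-in (d_model for attn_out, the FFN hidden size for ff_out).

```python
import math
from typing import Dict


def olmo_stds(scheme: str, d_model: int, n_layers: int, layer_id: int,
              out_fan_in: int, init_std: float = 0.02) -> Dict[str, float]:
    """Per-role init stds for the three OLMo schemes described above."""
    if scheme == "normal":
        return {role: init_std for role in ("emb", "in_proj", "out_proj", "lm_head")}
    if scheme == "mitchell":
        width = d_model ** -0.5
        return {"emb": width, "in_proj": width,
                # width- and per-layer depth-scaled (0-indexed layer_id)
                "out_proj": (2 * out_fan_in * (layer_id + 1)) ** -0.5,
                "lm_head": width}
    if scheme == "full_megatron":
        return {"emb": init_std, "in_proj": init_std,
                # total-depth scaled: the same std for every layer
                "out_proj": init_std / math.sqrt(2 * n_layers),
                "lm_head": d_model ** -0.5}
    raise ValueError(f"unknown scheme: {scheme}")


for scheme in ("normal", "mitchell", "full_megatron"):
    print(scheme, olmo_stds(scheme, d_model=2048, n_layers=16, layer_id=3, out_fan_in=2048))
```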

Megatron-LM (commit 5fe3f0665)

Architecture: two callable config slots in TransformerConfig (megatron/core/transformer/transformer_config.py), init_method and output_layer_init_method, select init for all layers, with a third, embedding_init_method, for embeddings. Core functions are in megatron/core/utils.py.

Default (no flags):

  • Embeddings: normal_(std=0.02)
  • Q/K/V, FFN gate/up (fc1), MoE router: normal_(std=0.02)
  • Attention output proj, FFN down (fc2): normal_(std=0.02 / sqrt(2 · n_layers)) — same std for all layers, scaled by total depth; multiplier is 1.0 instead of 2.0 for hybrid (SSM+transformer) models
  • LM head: normal_(std=0.02) (uses embedding_init_method; or init_method when weights are tied)
  • Norms: PyTorch defaults (ones_/zeros_)

--init-method-xavier-uniform (argparse flag, also YAML init_method: "xavier_uniform"): overrides both init_method and output_layer_init_method to xavier_uniform_, discarding depth scaling.

MuP (use_mup=True): implements Maximal Update Parametrization (Yang & Hu 2021), keeping hidden-layer feature updates O(1) as width scales via coupled init, LR, and forward-pass adjustments (width_mult = hidden_size / mup_base_hidden_size):

  • Init: embeddings and LM head stay at normal_(std=0.02); hidden layers use normal_(std=0.02 / sqrt(width_mult)); output projections use normal_(std=0.02 / (sqrt(2 · n_layers) · sqrt(width_mult))). Since fan-in scales proportionally with width_mult, the reduced std exactly cancels that growth: effectively equivalent to normal_(std=0.02) at the base fan-in in terms of pre-activation output variance, keeping it O(1) regardless of width.
  • Forward pass: LM head logits multiplied by 1 / width_mult (auto-set via mup_output_mult). The LM head init is unscaled so its pre-activation std grows as sqrt(width_mult); the 1/width_mult multiplier then brings the effective output std to 0.02 / sqrt(width_mult) at the base fan-in — effectively equivalent to normal_(std=0.02 / sqrt(width_mult)) with no multiplier, deliberately weaker than hidden activations. Attention softmax uses 1 / d_head instead of 1 / sqrt(d_head) (mup_attn_scale_power=1.0); embedding outputs optionally multiplied by mup_embedding_mult (defaults to 1.0, not auto-scaled)
  • LR (Adam/AdamW): hidden 2D weight matrices get lr / width_mult and eps / width_mult; 1D params (biases, norms) and embeddings/output layer keep base lr. Implemented via per-param-group overrides in get_mup_config_overrides (megatron/core/optimizer/__init__.py:130). SGD inverts: vector-like params get lr × width_mult.
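The width-invariance claims above follow from Var(y) = fan_in · std² · Var(x) for y = Wx with i.i.d. zero-mean weights and independent inputs. The sketch below checks them in closed form, assuming unit-variance inputs and an illustrative base fan-in of 256. `mup_stds` and `mup_adam_lr` are hypothetical helpers, not Megatron-LM functions, and the eps rescaling is not modeled.

```python
import math


def preact_std(fan_in: int, w_std: float, x_std: float = 1.0) -> float:
    """Std of y = W x for i.i.d. zero-mean weights and independent inputs."""
    return math.sqrt(fan_in) * w_std * x_std


def mup_stds(width_mult: float, n_layers: int, base_std: float = 0.02) -> dict:
    """Nominal init stds under the Megatron-LM MuP rules above."""
    return {
        "embedding": base_std,
        "hidden": base_std / math.sqrt(width_mult),
        "output_proj": base_std / (math.sqrt(2 * n_layers) * math.sqrt(width_mult)),
        "lm_head": base_std,  # unscaled init; logits are multiplied by 1/width_mult
    }


def mup_adam_lr(param_kind: str, base_lr: float, width_mult: float) -> float:
    """Hidden 2D matrices get lr / width_mult; everything else keeps the base lr."""
    return base_lr / width_mult if param_kind == "hidden_matrix" else base_lr


base_fan_in = 256  # illustrative base width
for m in (1, 4, 16):
    s = mup_stds(m, n_layers=24)
    hidden = preact_std(base_fan_in * m, s["hidden"])
    logits = preact_std(base_fan_in * m, s["lm_head"]) / m  # forward multiplier 1/m
    print(f"m={m}: hidden pre-act std={hidden:.4f}, logit std={logits:.4f}")
```

The hidden pre-activation std stays fixed as m grows, while the logit std shrinks as 1/√m, matching the "deliberately weaker" output described above.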

transformers (commit 95933eb6f4)

Default (PreTrainedModel._init_weights, modeling_utils.py:2362) — used by most models (Llama, Mistral, GPT-NeoX, and the majority of the library):

  • All linears and embeddings: normal_(std=initializer_range), default initializer_range=0.02
  • Norms (LayerNorm, RMSNorm): ones_ (weight), zeros_ (bias)
  • Biases: zeros_

The following are rare exceptions; the large majority of models in the library use the default above.

Exceptions

GPT-2 (modeling_gpt2.py:433):

  • All linears and embeddings: normal_(std=0.02)
  • Residual output projections (c_proj.weight) only: normal_(std=0.02 / sqrt(2 · n_layer)) — total-depth scaled; comment credits OpenAI GPT-2 paper and Megatron-LM
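A sketch of the GPT-2 rule as a name-based dispatch. `gpt2_init_std` is hypothetical and simply keys on the `c_proj.weight` suffix described above. The example parameter names follow GPT-2's `c_attn` / `c_fc` / `c_proj` module naming.

```python
import math


def gpt2_init_std(param_name: str, n_layer: int, initializer_range: float = 0.02) -> float:
    """Flat std for every weight, except residual output projections (c_proj)."""
    if param_name.endswith("c_proj.weight"):
        return initializer_range / math.sqrt(2 * n_layer)
    return initializer_range


for name in ["wte.weight", "h.0.attn.c_attn.weight", "h.0.attn.c_proj.weight",
             "h.0.mlp.c_fc.weight", "h.0.mlp.c_proj.weight"]:
    print(name, round(gpt2_init_std(name, n_layer=12), 5))
```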

CLIP (and ~12 derivatives: AltCLIP, CLIPSEG, CLAP, OWLv2, OWLViT, X-CLIP, etc.; modeling_clip.py:410): All stds multiplied by initializer_factor (default 1.0, exists for testing). Asymmetric depth scaling:

  • Text/position embeddings: normal_(std=0.02 · factor)
  • Q/K/V projections and FFN down (fc2): normal_(std=d_model^{-0.5} · (2 · n_layers)^{-0.5} · factor) — width- and total-depth-scaled
  • Attention output proj: normal_(std=d_model^{-0.5} · factor) — width-scaled only
  • FFN up (fc1): normal_(std=(2 · d_model)^{-0.5} · factor) — width-scaled only

ModernBERT (modeling_modernbert.py:360): Uses trunc_normal_ throughout with configurable cutoff (default 3σ):

  • Embeddings and input projections (Wqkv, FFN Wi): trunc_normal_(std=initializer_range)
  • Output projections (attn Wo, FFN Wo), prediction head: trunc_normal_(std=initializer_range / sqrt(2 · n_layers)) — total-depth scaled
  • Final classifier head: trunc_normal_(std=hidden_size^{-0.5})
  • LayerNorm: ones_

T5 (modeling_t5.py:528): No depth scaling; fan-in/out width scaling throughout, multiplied by initializer_factor (default 1.0):

  • Q: normal_(std=factor · (d_model · d_kv)^{-0.5})
  • K, V: normal_(std=factor · d_model^{-0.5})
  • Attention output (o): normal_(std=factor · (n_heads · d_kv)^{-0.5})
  • FFN wi (up): normal_(std=factor · d_model^{-0.5})
  • FFN wo (down): normal_(std=factor · d_ff^{-0.5})

Gemma / Gemma2 / Gemma3 (modeling_gemma.py:365): Base class init for all weights except: RMSNorm weights initialized to zeros_ rather than ones_ because Gemma RMSNorm computes (1 + weight) · x, so zero weight = identity at init.

lm-engine (commit 02b5957, github.com/open-lm-engine/lm-engine)

Three base init methods (lm_engine/hf_models/modeling_utils/init_utils.py), selected via init_method / embedding_init_method config fields. use_depth_scaled_init=True (default) applies total-depth scaling to output projections only, orthogonally on top of any base method.

"normal" (default):

  • Embeddings and LM head: normal_(std=0.02)
  • Q/K/V, FFN gate/up: normal_(std=0.02)
  • Attention output (c_proj), FFN down: normal_(std=0.02 / sqrt(2 · n_layers)) — total-depth scaled

"fan_in":

  • Embeddings: normal_(std=embed_dim^{-0.5})
  • Q/K/V, FFN gate/up: normal_(std=fan_in^{-0.5})
  • Attention output (c_proj), FFN down: normal_(std=fan_in^{-0.5} / sqrt(2 · n_layers)) — total-depth scaled

"mup":

  • Embeddings and LM head: normal_(std=0.02) — intentionally unscaled; logits scaled by 1/m_width in the forward pass, effectively equivalent to normal_(std=0.02/sqrt(m_width)) with no multiplier
  • Q/K/V, FFN gate/up: normal_(std=0.02 / sqrt(m_width)) — scaled init cancels fan-in growth, effectively equivalent to normal_(std=0.02) at base fan-in
  • Attention output (c_proj), FFN down: normal_(std=0.02 / (sqrt(m_width) · sqrt(2 · n_layers))) — effectively equivalent to normal_(std=0.02 / sqrt(2 · n_layers)) at base fan-in
  • Output projection weights additionally marked for per-parameter LR scaling

Three optional forward-pass multipliers (all default None): m_emb scales embedding outputs; m_residual scales each residual branch output (attn and FFN separately); m_width scales final hidden states and logits by 1/m_width.

nanotron (commit 2411b02)

Two init strategies, selected via ModelArgs.init_method (src/nanotron/scaling/parametrization.py).

RandomInit (configurable std; scaling_method=NUM_LAYERS default; examples use std=0.025):

  • Embeddings: normal_(std=std)
  • Q/K/V, FFN gate/up, router: normal_(std=std) — flat
  • Attention output, FFN down: normal_(std=std / sqrt(2 · n_layers)) — total-depth scaled
  • MoE: gate/up flat; down total-depth scaled (same split as dense)
  • LM head: normal_(std=std) — flat
  • RMSNorm: ones_
  • LAYER_INDEX per-layer scaling is defined in the enum but raises NotImplementedError

SpectralMupInit (Yang et al., "A Spectral Condition for Feature Learning", 2023) — a distinct variant from standard MuP (Yang & Hu 2021): no forward-pass output multiplier; init and LR are coupled per-layer by fan shape rather than a global width ratio:

  • All linears: normal_(std=(1.0 / sqrt(fan_in)) · min(1, sqrt(fan_out / fan_in)))
  • Embeddings and LM head: normal_(std=1.0) — hardcoded, not coupled to width
  • RMSNorm: ones_
  • Forward pass: attention softmax scale 1 / d_head instead of 1 / sqrt(d_head); no logit multiplier
  • Per-parameter LR: lr · (fan_out / fan_in) for linears; global lr for embeddings and norms
  • MoE and router not supported under SpectralMupInit
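The SpectralMupInit std and LR rules reduce to two one-line formulas. The sketch below uses hypothetical helper and layer names with illustrative dims (d=1024, d_ff=4096). A square matrix gets the fan-in std and the base LR. An up-projection (fan_out > fan_in) also gets the fan-in std but a larger LR. A down-projection (fan_out < fan_in) gets both a contracted std and a smaller LR.

```python
import math


def spectral_std(fan_in: int, fan_out: int) -> float:
    """(1 / sqrt(fan_in)) * min(1, sqrt(fan_out / fan_in))."""
    return (1.0 / math.sqrt(fan_in)) * min(1.0, math.sqrt(fan_out / fan_in))


def spectral_lr(base_lr: float, fan_in: int, fan_out: int) -> float:
    """Per-parameter LR for a linear layer: lr * fan_out / fan_in."""
    return base_lr * fan_out / fan_in


d, d_ff = 1024, 4096  # illustrative dims
layers = {"q_proj": (d, d), "up_proj": (d, d_ff), "down_proj": (d_ff, d)}  # (fan_in, fan_out)
for name, (fan_in, fan_out) in layers.items():
    print(name, spectral_std(fan_in, fan_out), spectral_lr(1e-3, fan_in, fan_out))
```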

LLM Foundry (MosaicML/Databricks, commit 0cdb2f4, llmfoundry/models/utils/param_init_fns.py)

Eight named schemes, all sharing the same dispatch: weights tagged _is_residual = True (set on out_proj and down_proj only) are depth-scaled by dividing by div_is_residual after base init. Default div_is_residual = sqrt(2 · n_layers) (controlled by init_div_is_residual). Norms: ones_. Biases: zeros_. Embedding init uses the base function by default, overridable via emb_init_std or emb_init_uniform_lim.

baseline_ (configurable init_std):

  • Embeddings: normal_(std=init_std)
  • Q/K/V, FFN gate/up: normal_(std=init_std) — flat
  • Attention output (out_proj), FFN down (down_proj): normal_(std=init_std / sqrt(2 · n_layers)) — total-depth scaled

small_init_ (Nguyen & Salazar 2019, "Transformers without Tears"):

  • Same as baseline_ with init_std = sqrt(2/(5 · d_model))

neox_init_ (GPT-NeoX-20B, Black et al. 2022):

  • Base: same as small_init_, i.e. normal_(std=sqrt(2/(5 · d_model)))
  • Residual divisor overridden to n_layers / sqrt(10) instead of the default sqrt(2 · n_layers)
  • out_proj, down_proj: normal_(std=sqrt(2/(5 · d_model)) · sqrt(10) / n_layers) = normal_(std=2 / (n_layers · sqrt(d_model))) — jointly depth- and width-scaled

Also available: kaiming_uniform_, kaiming_normal_, xavier_uniform_, xavier_normal_ — all support the same _is_residual depth-scaling mechanism.
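The shared `_is_residual` mechanism can be sketched as "base std for every weight, then divide the residual-tagged ones". `foundry_style_stds` and the role names `qkv` / `up_proj` are hypothetical; only `out_proj` and `down_proj` come from the text above. The test below confirms the closed form given for neox_init_.

```python
import math
from typing import Dict

# Hypothetical role names; only out_proj and down_proj are tagged _is_residual.
IS_RESIDUAL = {"qkv": False, "up_proj": False, "out_proj": True, "down_proj": True}


def foundry_style_stds(scheme: str, d_model: int, n_layers: int,
                       init_std: float = 0.02) -> Dict[str, float]:
    """Base std for every weight, then divide residual-tagged weights by div_is_residual."""
    if scheme == "baseline_":
        base, div_is_residual = init_std, math.sqrt(2 * n_layers)
    elif scheme == "small_init_":
        base, div_is_residual = math.sqrt(2 / (5 * d_model)), math.sqrt(2 * n_layers)
    elif scheme == "neox_init_":
        base, div_is_residual = math.sqrt(2 / (5 * d_model)), n_layers / math.sqrt(10)
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return {role: base / div_is_residual if residual else base
            for role, residual in IS_RESIDUAL.items()}


print(foundry_style_stds("neox_init_", d_model=1024, n_layers=24))
```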

Cerebras ModelZoo (Cerebras, commit f1fd1e0, src/cerebras/modelzoo/models/nlp/gpt2/gpt2_model.py)

Four separate initializer config slots: embedding_initializer, initializer (Q/K/V and FFN gate/up), output_layer_initializer (attention output), ffn_output_layer_initializer (FFN down). Each accepts any named initializer via create_initializer. MuP applied via scale_initializers_by_dimension which multiplies the std by width_scale × depth_scale.

Default (no MuP; initializer_range=0.02):

  • Embeddings: trunc_normal_(std=0.02, a=−0.04, b=0.04) — flat
  • Q/K/V, FFN gate/up: trunc_normal_(std=0.02) — flat
  • Attention output, FFN down: trunc_normal_(std=0.02 / sqrt(2 · N)) — total-depth scaled; applied automatically when no explicit output_layer_initializer is set
  • Norms: ones_, biases: zeros_

MuP (mup_base_hidden_size set; m = hidden_size/mup_base_hidden_size; empirical base std 0.08 from Cerebras-GPT):

  • Embeddings: trunc_normal_(std=base_std) — intentionally unscaled
  • Q/K/V, FFN gate/up: trunc_normal_(std=base_std / sqrt(m)) — width-scaled
  • Attention output, FFN down: trunc_normal_(std=base_std / (sqrt(m) · sqrt(2 · N))), width- and total-depth scaled simultaneously, as in the Megatron-LM and lm-engine MuP schemes above
  • Forward: logits × output_logits_alpha / m (default; or / sqrt(m) if scale_output_logits_by_d=False)
  • LR: attention and FFN input groups → lr / m; separate mup_base_filter_size allows independent width multiplier for FFN output when FFN hidden dim scales differently from hidden_size

tools/convert_config_to_mup.py converts an existing SP config to MuP; ships with empirically tuned default base_init_std=0.08 (from Cerebras-GPT paper, arXiv:2304.03208).

MaxText (Google, commit f44b4236e, src/maxtext/layers/)

JAX/Flax framework (NNX + Linen bridge). Init uses nd_dense_init(scale, mode, distribution) which wraps jax.nn.initializers.variance_scaling; with scale=1.0 and mode="fan_in" this is LeCun init: std = √(scale/fan_in). Global dense_init_scale config (default 1.0, base.yml:172) scales variance uniformly when overridden by individual models. No depth scaling anywhere.

  • Embeddings: normal_(std=1/√d_model) via variance_scaling(1.0, "fan_in", "normal", out_axis=0); out_axis=0 makes fan_in = d_model for the (vocab_size, d_model) embedding matrix
  • LM head: weight-tied to embedding (Embed.attend); no separate init
  • Q projection: normal_(std=1/√(d_model · head_dim)) — T5-style, divides kernel by √head_dim at init to absorb the attention scale; disabled when use_qk_norm=True or a custom query_pre_attn_scalar is set
  • K, V projections: normal_(std=1/√fan_in) (fan_in = d_model)
  • Attention output projection: normal_(std=1/√fan_in) (fan_in = n_heads · head_dim)
  • FFN gate/up (wi), FFN down (wo): trunc_normal_(std=1/√fan_in) via nd_dense_init(1.0, "fan_in", "truncated_normal")
  • Norms: JAX defaults (ones for weight, zeros for bias where applicable)

Code Bases Summary Table

(eff. X) = for MuP schemes, the effective std, i.e. the flat std that gives the same pre-activation scale at the base fan-in (m = 1); for non-MuP schemes the effective std equals the nominal one. l = 0-indexed layer_id · N = n_layers · d = hidden_dim · m = width/base_width · σ = configurable · f_i/f_o = fan_in/fan_out · d_h = head_dim · d_kv = per-head key/value dim · n_h = n_heads · d_ff = FFN hidden dim · n/d = not documented.

Scheme | Input proj (Q/K/V, FFN gate/up) | Output proj (attn out, FFN down) | Depth scaling | Embed | LM head
torchtitan GPT-OSS | trunc_normal_(0.02/√(2(l+1))) | same | Per-layer, all weights | normal_(0.02) | trunc_normal_(1/√d)
torchtitan Llama3/4/DS/Qwen | trunc_normal_(0.02); FFN up (w3) is depth-scaled instead | trunc_normal_(0.02/√(2(l+1))) | Per-layer, output + FFN up | normal_(1.0) | trunc_normal_(1/√d)
OLMo normal | normal_(0.02) | normal_(0.02) | None | normal_(0.02) | normal_(0.02)
OLMo mitchell | trunc_normal_(1/√d) | trunc_normal_(1/√(2·f_i·(l+1))) | Per-layer + width, output | trunc_normal_(1/√d) | trunc_normal_(1/√d)
OLMo full_megatron | trunc_normal_(0.02) | trunc_normal_(0.02/√(2N)) | Total-depth, output | trunc_normal_(0.02) | trunc_normal_(1/√d)
Megatron-LM default | normal_(0.02) | normal_(0.02/√(2N)) | Total-depth, output | normal_(0.02) | normal_(0.02)
Megatron-LM MuP | normal_(0.02/√m) (eff. 0.02) | normal_(0.02/(√m·√(2N))) (eff. 0.02/√(2N)) | Total-depth, output | normal_(0.02) | normal_(0.02); fwd logit×1/m
transformers default | normal_(0.02) | normal_(0.02) | None | normal_(0.02) | normal_(0.02) or tied
transformers GPT-2 | normal_(0.02) | normal_(0.02/√(2N)) | Total-depth, c_proj only | normal_(0.02) | tied
transformers ModernBERT | trunc_normal_(0.02) | trunc_normal_(0.02/√(2N)) | Total-depth, output | trunc_normal_(0.02) | trunc_normal_(1/√d)
transformers T5 | Q: normal_(1/√(d·d_kv)); K/V: normal_(1/√d) | attn: normal_(1/√(n_h·d_kv)); FFN: normal_(1/√d_ff) | None | normal_(1/√d) | tied
lm-engine normal | normal_(0.02) | normal_(0.02/√(2N)) | Total-depth, output | normal_(0.02) | normal_(0.02)
lm-engine fan_in | normal_(1/√f_i) | normal_(1/(√f_i·√(2N))) | Total-depth, output | normal_(1/√d) | normal_(1/√d)
lm-engine mup | normal_(0.02/√m) (eff. 0.02) | normal_(0.02/(√m·√(2N))) (eff. 0.02/√(2N)) | Total-depth, output | normal_(0.02) | normal_(0.02); fwd logit×1/m
nanotron RandomInit | normal_(σ) | normal_(σ/√(2N)) | Total-depth, output | normal_(σ) | normal_(σ), flat
nanotron SpectralMupInit | normal_((1/√f_i)·min(1,√(f_o/f_i))) | same formula | None (fan-shape) | normal_(1.0) | normal_(1.0)
LLM Foundry baseline_ | normal_(σ) | normal_(σ/√(2N)) | Total-depth, output | normal_(σ) | n/d
LLM Foundry small_init_ | normal_(√(2/(5d))) | normal_(√(2/(5d))/√(2N)) | Total-depth, output | same | n/d
LLM Foundry neox_init_ | normal_(√(2/(5d))) | normal_(2/(N·√d)) | Total-depth (1/N) + width, output | same | n/d
Cerebras default | trunc_normal_(0.02) | trunc_normal_(0.02/√(2N)) | Total-depth, output | trunc_normal_(0.02) | n/d
Cerebras MuP | trunc_normal_(0.08/√m) (eff. 0.08) | trunc_normal_(0.08/(√m·√(2N))) (eff. 0.08/√(2N)) | Total-depth, output | trunc_normal_(0.08) | fwd logit×α_out/m
MaxText | K/V: normal_(1/√d); Q: normal_(1/√(d·d_h)) (T5-style); FFN: trunc_normal_(1/√f_i) | attn: normal_(1/√f_i); FFN: trunc_normal_(1/√f_i) | None | normal_(1/√d) | tied

Papers

Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer (arxiv:2203.03466)

Introduces μP (Maximal Update Parametrization) to enable zero-shot HP transfer: HPs tuned at small base width n₀ transfer to large width n. Parameters are split into input (embeddings), hidden (Q/K/V, attention output, all FFN), and output (LM head) classes. With m = n/n₀:

SP (Standard Parametrization, baseline):

  • All weights: normal_(std=fan_in^{-0.5}); attention scale 1/sqrt(d_head); uniform LR η

μP:

  • Embeddings and biases: normal_(std=n^{-0.5}) — same as SP; LR = η
  • Hidden (Q/K/V, attention output, all FFN): normal_(std=fan_in^{-0.5}) — same init as SP; width-coupling is entirely via LR scaling, not init std; LR = η/m
  • LM head: normal_(std=√n₀ / n) = normal_(std=fan_in^{-0.5} / sqrt(m)) with fan_in = n; narrower than SP by √m and equal to SP at m=1; LR = η/m
  • Attention softmax scale: 1/d_head (not 1/sqrt(d_head))
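  • Forward: logits × output_mult/m, where output_mult is a transferable HP tuned at small scale. This multiplier is an equivalent alternative to narrowing the LM head init: with it, the LM head keeps a width-independent init std and its Adam LR stays η (the form used by the Megatron-LM and lm-engine MuP schemes above)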
  • LR (Adam): hidden and LM head weights → lr/m; embeddings and biases → lr
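The LM head prescription can be written either as a narrowed init with LR η/m, or as a width-independent init with a 1/m logit multiplier and LR η. Under Adam (ignoring epsilon), the update size is set by the LR, so a forward multiplier c rescales both the effective init and the effective step by c, and the two forms coincide. The sketch below checks that equivalence numerically. Helper names and the widths (n₀=256, n=4096) are illustrative assumptions.

```python
import math


def lm_head_init_form(n: int, n0: int, eta: float) -> dict:
    """Narrowed init, no multiplier: std = sqrt(n0)/n, Adam LR = eta/m."""
    m = n / n0
    return {"std": math.sqrt(n0) / n, "mult": 1.0, "lr": eta / m}


def lm_head_multiplier_form(n: int, n0: int, eta: float) -> dict:
    """Base-width init std 1/sqrt(n0), logits x 1/m, Adam LR = eta."""
    m = n / n0
    return {"std": n0 ** -0.5, "mult": 1.0 / m, "lr": eta}


def effective(p: dict) -> tuple:
    # A forward multiplier c scales both the effective init std and the
    # effective (Adam, eps ignored) update of the product c * W by c.
    return p["std"] * p["mult"], p["lr"] * p["mult"]


a = lm_head_init_form(4096, 256, 1e-2)
b = lm_head_multiplier_form(4096, 256, 1e-2)
print(effective(a), effective(b))
```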

A Spectral Condition for Feature Learning (arxiv:2310.17813)

Derives μP-equivalent init and LR prescriptions from an elementary spectral condition: ||W||_* = Θ(√(fan_out/fan_in)). Handles arbitrary fan shapes and sparse inputs natively, without Tensor Programs machinery. Reduces to μP (Table 3 of Tensor Programs V) when all hidden widths are equal and input/output dims are Θ(1).

Parametrization 1:

  • All linear layers: normal_(std=(1/√fan_in) · min(1, √(fan_out/fan_in))) — equals fan-in init when fan_out ≥ fan_in; contracted to normal_(std=√fan_out / fan_in) when fan_out < fan_in
  • Embeddings (one-hot/sparse input): natural input norm = 1 regardless of vocab size, effective fan_in = 1; formula gives normal_(std=1.0) — independent of d_model
  • LM head (dense → vocab): fan_out ≥ fan_in in typical LLMs (vocab ≥ d_model), so normal_(std=fan_in^{-0.5}) — same as hidden layer
  • No forward-pass multipliers — output-layer contraction absorbed into init formula
  • Per-layer LR: lr · (fan_out/fan_in) for linears; global lr for embeddings and norms
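The spectral condition can be checked empirically. The sketch below samples a Gaussian matrix with the Parametrization 1 std and estimates its largest singular value by power iteration on WᵀW (pure Python, small illustrative shapes). The ratio ||W||_* / √(fan_out/fan_in) should come out as an O(1) constant; the Gaussian estimate ||W||_* ≈ std·(√fan_in + √fan_out) predicts roughly 1.5 for both shapes used here.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def spectral_mup_std(fan_in: int, fan_out: int) -> float:
    return (1.0 / math.sqrt(fan_in)) * min(1.0, math.sqrt(fan_out / fan_in))


def sample_weight(fan_in: int, fan_out: int, seed: int = 1) -> Matrix:
    """W has shape (fan_out, fan_in), entries i.i.d. N(0, spectral_mup_std^2)."""
    rng = random.Random(seed)
    std = spectral_mup_std(fan_in, fan_out)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


def spectral_norm(w: Matrix, iters: int = 100, seed: int = 0) -> float:
    """Largest singular value via power iteration on W^T W."""
    rng = random.Random(seed)
    n_rows, n_cols = len(w), len(w[0])
    v = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    sigma = 0.0
    for _ in range(iters):
        u = [sum(row[j] * v[j] for j in range(n_cols)) for row in w]              # u = W v
        v = [sum(w[i][j] * u[i] for i in range(n_rows)) for j in range(n_cols)]  # v = W^T u
        norm = math.sqrt(sum(x * x for x in v))  # tends to sigma_max^2 as v converges
        v = [x / norm for x in v]
        sigma = math.sqrt(norm)
    return sigma


for fan_in, fan_out in [(128, 32), (32, 128)]:
    ratio = spectral_norm(sample_weight(fan_in, fan_out)) / math.sqrt(fan_out / fan_in)
    print(f"fan_in={fan_in} fan_out={fan_out} ratio={ratio:.2f}")
```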

Language Models are Unsupervised Multitask Learners (GPT-2, Radford et al. 2019)

Origin of the 1/sqrt(2N) depth-scaled init for residual output projections, subsequently adopted by Megatron-LM and most modern LLMs.

  • Embeddings (token and position): normal_(std=0.02) — flat
  • Q/K/V, FFN gate/up: normal_(std=0.02) — flat
  • Attention output projection and FFN down projection: normal_(std=0.02 / sqrt(2 · N)) — total-depth scaled; N = n_layers; factor of 2 counts both residual contributions per block (attention and MLP)
  • LM head: normal_(std=0.02) — flat
  • LayerNorm: ones_ (weight), zeros_ (bias)

Trinity (arxiv:2602.17004)

Uses a single uniform init std for all weight matrices derived from Spike No More's stability analysis. Depth coupling appears only in post-norm RMSNorm gains (not weight inits), via the depth-scaled sandwich norm from Pangu Ultra.

  • All weight matrices (Q/K/V, attention output, FFN gate/up/down, embeddings, LM head): trunc_normal_(std=0.5/√d, a=−1.5/√d, b=1.5/√d) — width-scaled; approximates Spike No More's √(2/(5d)) ≈ 0.632/√d, rounded to 0.5/√d to also match DeepSeek-V3
  • Pre-norm RMSNorm gains: ones_
  • Post-norm RMSNorm gains: 1/√L with L = n_layers; total-depth scaled (depth-scaled sandwich norm)
  • Forward: embedding activations scaled by √d at runtime, following Spike No More
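The Trinity std and its 3σ truncation bounds follow directly from d. The sketch below (hypothetical helper, illustrative d=4096) also compares them with the Spike No More std √(2/(5d)), showing the fixed ratio √(2/5)/0.5 ≈ 1.26 between the two at any width.

```python
import math


def trinity_init(d: int):
    """Return (std, a, b) for trunc_normal_(std=0.5/sqrt(d), a=-3*std, b=3*std)."""
    std = 0.5 / math.sqrt(d)
    return std, -3 * std, 3 * std


d = 4096
std, a, b = trinity_init(d)
spike_std = math.sqrt(2 / (5 * d))
print(std, a, b, spike_std / std)
```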

Spike No More: Stabilizing the Pre-Training of Large Language Models (arxiv:2312.16903)

Analyzes Pre-LN transformer Jacobians and shows that loss spikes arise when the embedding std is comparable to (i.e. as small as) the weight std: the identity shortcut in each residual block must dominate the sublayer contribution. Proposes changing only the embedding initialization; all other weights keep the baseline below.

Baseline (SmallInit-scale stds with the Megatron-LM-style 1/√(2N) output scaling; not changed by this paper):

  • Q/K/V, FFN gate/up: normal_(std=√(2/(5d))) ≈ normal_(std=0.632/√d) — width-scaled
  • Attention output, FFN down: normal_(std=√(1/(5dN))) = normal_(std=0.632/√d / sqrt(2N)) — total-depth scaled; N = n_layers

Proposed (embedding only; two equivalent methods):

  • Scaled Embed: multiply the embedding outputs by √d in the forward pass; the effective embedding std becomes √(2/5) ≈ 0.632, constant rather than width-dependent
  • Embed LN: add a LayerNorm immediately after the embedding lookup; normalizes embedding std to ≈ 1 at init
  • All other weights remain at the baseline stds above
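The effect of Scaled Embed is simple arithmetic: multiplying a √(2/(5d))-std embedding by √d removes the width dependence. The sketch below uses illustrative widths; the helper name is hypothetical.

```python
import math


def embedding_std(d: int, scaled_embed: bool) -> float:
    """Embedding std under the baseline init, optionally times sqrt(d) (Scaled Embed)."""
    base = math.sqrt(2 / (5 * d))
    return base * math.sqrt(d) if scaled_embed else base


for d in (1024, 4096, 16384):
    print(d, round(embedding_std(d, False), 5), round(embedding_std(d, True), 5))
```

Without scaling, the embedding std equals the sublayer weight std and shrinks with width; with it, the embedding std stays at ≈0.632 for every d.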

The Principles of Deep Learning Theory (arxiv:2106.10165)

Theory textbook (Roberts, Yaida, Hanin) using effective field theory / mean-field methods to analyze MLP initialization and signal propagation. Characterizes criticality conditions (He/Kaiming init as the ReLU-optimal variance) from first principles. Not transformer-specific; no per-weight-type practical recipe for LLMs.

AutoInit: Automatic Initialization via Jacobian Tuning (arxiv:2206.13568)

Proposes automatically tuning per-layer init multipliers so that all averaged partial Jacobian norms equal 1 (criticality), without closed-form formulas. Evaluated on MLP and CV architectures (ResMLP, VGG); no transformer experiments or explicit per-weight-type stds.

Critical Initialization of Wide and Deep Neural Networks using Partial Jacobians (arxiv:2111.12143)

Develops averaged partial Jacobian norm (APJN) theory showing that Pre-LN + residual networks with μ=1 (μ being the strength of the residual connection) lie in an "everywhere-critical" regime at init regardless of weight std. Characterizes He/Fixup/ReZero as special cases. No transformer-specific practical prescriptions.

DeepNet: Scaling Transformers to 1,000 Layers (arxiv:2203.00555)

Proposes DeepNorm: modified residual x_{l+1} = LN(α·x_l + G_l(x_l, θ)) where α > 1 up-scales the skip branch, with weights initialized to Xavier Normal × β (β < 1). Q and K projections are unscaled (β = 1); V, attention output, and FFN weights are scaled by β:

Decoder-only (M layers): α = (2M)^(1/4), β = (8M)^(-1/4)
Encoder-only (N layers): α = (2N)^(1/4), β = (8N)^(-1/4)
Encoder-decoder (N enc layers, M dec layers):

  • Encoder: α = 0.81·(N^4·M)^(1/16), β = 0.87·(N^4·M)^(-1/16)

  • Decoder: α = (3M)^(1/4), β = (12M)^(-1/4)

  • Q, K projections: xavier_normal_(std=√(2/(fan_in+fan_out))) — unscaled (β = 1)

  • V projection, attention output, FFN up/down: xavier_normal_ × β; depth-dependent but weaker than GPT-2-style 1/√(2N) — the 4th-root vs square-root dependence on depth

  • Embeddings: standard xavier (unscaled, not mentioned in β prescription)

GLM-130B (arxiv:2210.02414) adopted DeepNorm. GLM-4 subsequently dropped it with no stated replacement.
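The DeepNorm constants above, written out as a small function (hypothetical name and arch labels). The loop compares the decoder β with the GPT-2-style 1/√(2N) factor: the two are equal at N = 2, and β is the larger (weaker) scaling for any deeper model.

```python
import math


def deepnorm_constants(arch: str, n_enc: int = 0, n_dec: int = 0) -> dict:
    """Return {stack: (alpha, beta)} following the DeepNet formulas above."""
    if arch == "decoder_only":
        return {"dec": ((2 * n_dec) ** 0.25, (8 * n_dec) ** -0.25)}
    if arch == "encoder_only":
        return {"enc": ((2 * n_enc) ** 0.25, (8 * n_enc) ** -0.25)}
    if arch == "encoder_decoder":
        x = n_enc ** 4 * n_dec
        return {"enc": (0.81 * x ** (1 / 16), 0.87 * x ** (-1 / 16)),
                "dec": ((3 * n_dec) ** 0.25, (12 * n_dec) ** -0.25)}
    raise ValueError(f"unknown arch: {arch}")


for n in (2, 12, 100):
    alpha, beta = deepnorm_constants("decoder_only", n_dec=n)["dec"]
    print(n, round(alpha, 3), round(beta, 3), round(1 / math.sqrt(2 * n), 3))
```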

Transformers without Tears: Improving the Normalization of Self-Attention (arxiv:1910.05895)

Proposes three complementary techniques — ScaleNorm, FixNorm, and SmallInit — to stabilize Pre-LN transformers without warmup. The init contribution is SmallInit; no depth scaling.

SmallInit:

  • All attention linear weights: normal_(std=√(2/(5·d_model))), derived from the variance of attention outputs; numerically equal to the Xavier std √(2/(fan_in+fan_out)) with fan_in + fan_out = d + 4d
  • Embeddings: FixNorm (L2-normalized at runtime, not std-based); not part of SmallInit proper
  • No FFN-specific formula stated; same SmallInit std applied throughout

LLM Foundry's small_init_ uses exactly √(2/(5·d_model)) and credits this paper.

Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention (arxiv:1908.11365)

Proposes DS-Init: per-layer init where deeper layers receive narrower initialization. Scaling is per-layer by index, not total-depth — each layer has a different std.

DS-Init:

  • All weight matrices: uniform_(-γ/√l, γ/√l) where γ = √(6/(fan_in+fan_out)) is the Glorot bound and l is the 1-indexed layer position; hyperparameter α (default 1.0) gives U(-γα/√l, γα/√l) in full generality
  • Equivalent normal std: normal_(std=√(2/(l·(fan_in+fan_out)))) — per-layer, not total-depth
  • Applies to all weight matrices in both encoder and decoder simultaneously; applying to only one side causes instability
  • Embeddings not mentioned; standard init implied

Derivation sketch (empirical, not formal): The paper tracks Var(r) at each residual connection output and observes it grows with depth under standard init, which they attribute to each sublayer adding variance. Citing Glorot & Bengio (2010), they note activation variance is proportional to parameter variance. The fix: scale parameter variance by 1/l at layer l, which proportionally reduces the sublayer's variance contribution and pushes Var(r) toward 1 across all layers. The 1/√l on std (equivalently 1/l on variance) is motivated by wanting each layer's contribution to be comparable in magnitude to 1/l of the baseline, but there is no closed-form recurrence derivation — it is validated empirically rather than proven.
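A sketch of DS-Init's uniform bound and its equivalent normal std (a uniform U(-b, b) has std b/√3). Helper names are hypothetical; the empirical check samples from `random.uniform` with a fixed seed and illustrative dims.

```python
import math
import random


def ds_init_bound(fan_in: int, fan_out: int, layer: int, alpha: float = 1.0) -> float:
    """Uniform bound gamma * alpha / sqrt(l), with gamma the Glorot bound and l 1-indexed."""
    gamma = math.sqrt(6.0 / (fan_in + fan_out))
    return gamma * alpha / math.sqrt(layer)


def ds_init_equiv_std(fan_in: int, fan_out: int, layer: int, alpha: float = 1.0) -> float:
    """The std of U(-b, b) is b / sqrt(3)."""
    return ds_init_bound(fan_in, fan_out, layer, alpha) / math.sqrt(3)


rng = random.Random(0)
b = ds_init_bound(512, 512, layer=4)
samples = [rng.uniform(-b, b) for _ in range(20000)]
empirical = math.sqrt(sum(x * x for x in samples) / len(samples))
print(b, ds_init_equiv_std(512, 512, 4), empirical)
```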

u-μP: The Unit-Scaled Maximal Update Parametrization (arxiv:2407.17465)

Reformulates μP to eliminate the base-shape hyperparameter and simplify implementation. All activations, weights, and gradients begin at O(1) scale. Weights are parametrized as W = A_W·w₀ where w₀ ~ N(0,1) and A_W is a fixed per-layer scalar — effectively just a different way to specify the init std. Effective distributions:

  • Embeddings: normal_(std=1) (A_W = 1) — independent of width
  • Hidden (Q, K, V, attn output, FFN up/down): normal_(std=1/√fan_in) (A_W = 1/√fan_in) — LeCun init
  • Output/LM head: normal_(std=1/fan_in) (A_W = 1/fan_in) — narrower than hidden by √fan_in
  • Residual-branch output projections: A_W additionally multiplied by √(base_depth/depth) — total-depth scaling relative to a base depth

Attention scale: no explicit 1/d_head or 1/√d_head; absorbed into the unit-scaled Q init and associated operation-level scaling (α_attn-softmax is a transferable HP, not fixed).

LR (Adam) per weight type:

  • Embeddings: η/√fan_out, which scales as 1/√d_model (new vs μP, which uses a constant η̂_emb)
  • Hidden: η/√fan_in
  • Output/LM head: η (base rate, not η/m)
  • Residual: η·√(base_depth/depth)

Versus standard μP: no base_shape HP; embedding LR uses fan_out rather than a fixed constant; output LR is base η rather than η/m; hidden and LM head inits depend only on fan_in (1/√fan_in and 1/fan_in respectively) rather than on a width ratio m.
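The per-role A_W scalars listed above, as a small lookup (hypothetical function and role names). Since w₀ ~ N(0, 1), A_W is exactly the effective init std; the example shows the LM head sitting √fan_in below the hidden layers, and the residual-branch factor √(base_depth/depth) applied on top of the hidden scale.

```python
import math


def u_mup_scale(role: str, fan_in: int, depth: int = 1, base_depth: int = 1) -> float:
    """A_W for W = A_W * w0 with w0 ~ N(0, 1)."""
    if role == "embedding":
        return 1.0
    if role == "hidden":
        return 1 / math.sqrt(fan_in)
    if role == "residual_out":
        # hidden-layer scale times the total-depth factor sqrt(base_depth / depth)
        return math.sqrt(base_depth / depth) / math.sqrt(fan_in)
    if role == "lm_head":
        return 1 / fan_in
    raise ValueError(f"unknown role: {role}")


for role in ("embedding", "hidden", "residual_out", "lm_head"):
    print(role, u_mup_scale(role, fan_in=4096, depth=32, base_depth=4))
```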

Papers Summary Table

(eff. X) = effective std at base width for MuP schemes. n = current width · n₀ = base width · m = n/n₀ · N = n_layers · l = layer index (1-indexed) · d = hidden_dim · d_h = head_dim · f_i/f_o = fan_in/fan_out · n/s = not specified.

Scheme | Input proj | Output proj | Depth | Embed | LM head | LR adj. | Fwd pass
SP (μP baseline) | normal_(1/√f_i) | same | None | normal_(1/√n) | normal_(1/√n) | η uniform | 1/√d_h attn
μP | normal_(1/√f_i) (same as SP) | same as SP | None | normal_(1/√n) | normal_(√n₀/n) (= SP at m=1) | η/m hidden+output; η embed | logits×mult/m (alternative to the narrowed LM head init); 1/d_h attn
Spectral MuP | normal_((1/√f_i)·min(1,√(f_o/f_i))) | same formula | None (fan-shape only) | normal_(1.0) | same formula | lr·(f_o/f_i) per layer | no multipliers
GPT-2 | normal_(0.02) | normal_(0.02/√(2N)) | Total-depth, output | normal_(0.02) | normal_(0.02) | uniform | 1/√d_h attn
Trinity | trunc_normal_(0.5/√d) | same | None (post-norm RMS gains 1/√N) | same | same | uniform | embed ×√d
Spike No More | normal_(√(2/(5d))) | normal_(√(2/(5d))/√(2N)) | Total-depth, output | ×√d (Scaled Embed) or +LN (Embed LN) | n/s | uniform | embed ×√d
DeepNet (decoder) | Q/K: xavier_normal_; V/FFN: xavier_normal_·(8N)^{-1/4} | same as V/FFN | 4th-root, V/output/FFN | xavier_normal_ | n/s | uniform | DeepNorm: residual ×α
SmallInit | normal_(√(2/(5d))) | same (no depth scaling) | None | FixNorm | n/s | larger η, no warmup | ScaleNorm
DS-Init | U(±√(6/(l·(f_i+f_o)))), equiv. normal_(√(2/(l·(f_i+f_o)))) | same | Per-layer (1/√l), all | n/s | n/s | uniform | standard
u-μP | normal_(1/√f_i) (same as SP) | same as SP | √(base_N/N) on resid branch | normal_(1.0) | normal_(1/f_i) | η/√f_o embed; η/√f_i hidden; η LM head | attn via A_W, no fixed scale

Depth Scaling: Criteria and Scalings

Different papers and frameworks are implicitly optimizing different criteria, which leads to distinct scalings. Two broad philosophies: (1) scale weights to control variance accumulation in the forward pass, (2) add a structural element that makes depth irrelevant at init.

Bounded final-layer variance (total-depth, 1/√N) — GPT-2, Megatron-LM, OLMo full_megatron, nanotron RandomInit, lm-engine, Cerebras, LLM Foundry baseline_: Each of the 2N residual sublayers (one attention and one MLP per block) contributes σᵢ² to the accumulated variance. Setting σᵢ = σ/√(2N) uniformly gives Σσᵢ² = σ², a clean global budget that requires knowing N upfront. This is the dominant scheme in practice.
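The budget argument in closed form, under simplifying assumptions chosen here for illustration: the input to every output projection has unit variance, and sublayer contributions are independent, so each of the 2N sublayers adds fan_in · σ² to the residual stream. With 1/√(2N) scaling the total added variance is the same at every depth; without it, the variance grows linearly in N. The helper name and dims are hypothetical.

```python
import math


def residual_variance(n_layers: int, out_std: float, fan_in: int,
                      embed_var: float = 1.0, depth_scaled: bool = False) -> float:
    """Residual-stream variance after n_layers blocks (2 sublayers each).

    Assumes unit-variance inputs to each output projection and independent
    contributions, so each sublayer adds fan_in * std**2.
    """
    std = out_std / math.sqrt(2 * n_layers) if depth_scaled else out_std
    return embed_var + 2 * n_layers * fan_in * std ** 2


for n in (12, 48, 192):
    print(n, residual_variance(n, 0.02, 4096),
          residual_variance(n, 0.02, 4096, depth_scaled=True))
```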

Per-layer convergent series — DS-Init (1908.11365): If you want Var(h_l) bounded at every intermediate layer (not just the final one), you need Σσᵢ² < ∞, which requires σ_l = O(l^{-(1/2+ε)}) for some ε > 0. DS-Init's σ_l ∝ 1/√l sits exactly at the boundary: Σσ_l² is then the harmonic series, which diverges (logarithmically), so it does not actually bound intermediate variances. A fully safe per-layer choice would be, e.g., σ_l ∝ 1/l. The 1/√l prescription is depth-agnostic (each layer only needs its own index, not total N), but this appears to be an incidental property rather than the stated motivation.
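The divergence argument is easy to see numerically: partial sums of σ_l² for σ_l = l^{-1/2} (DS-Init) keep growing like ln N, while σ_l = l^{-1} converges to π²/6. A minimal sketch, assuming a unit base std:

```python
import math


def accumulated_variance(n_layers: int, exponent: float) -> float:
    """Sum of sigma_i**2 for sigma_i = i**(-exponent), i = 1..n_layers."""
    return sum(i ** (-2 * exponent) for i in range(1, n_layers + 1))


for n in (10, 100, 1000, 10000):
    print(n, round(accumulated_variance(n, 0.5), 3), round(accumulated_variance(n, 1.0), 4))
print("limit for exponent 1:", round(math.pi ** 2 / 6, 4))
```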

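Neural ODE / Euler discretization: Modeling the forward pass as an ODE h' = F(h) discretized with step size 1/N gives each sublayer a 1/N prefactor, implying σ_l = O(1/N). This is much more conservative than 1/√N and rarely used in practice, although LLM Foundry's neox_init_ residual divisor (n_layers/√10) also yields 1/N scaling of the output-projection std.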

Gradient norm preservation without normalization — Fixup / T-Fixup (Zhang et al. 2019, Huang et al. 2020): Derived by bounding gradient magnitudes for networks without LayerNorm. Fixup gets L^{-1/4} (4th root) for ResNets; T-Fixup adapts this to transformers. The 4th root arises from the same forward/backward interaction as DeepNet: a structural term (here the residual skip rather than an explicit α) takes half the depth burden, leaving the init to handle the other half.

Zero-initialized residual gate — ReZero / SkipInit (Bachlechner 2021, De & Smith 2020): Sidestep the problem entirely: add a scalar gate α_l initialized to 0 at each residual branch, so h_l = h_{l-1} + α_l·G_l. The network starts as pure identity regardless of weight scale; stability comes from the gate rather than the init std. No depth scaling on weights needed.
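A toy illustration of the zero-gate argument, in plain Python with lists as vectors: however badly scaled the sublayers are, α_l = 0 makes the whole stack the identity. The sublayer here is an arbitrary stand-in, not a transformer block, and the function names are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def rezero_forward(h: Vector, sublayers: List[Callable[[Vector], Vector]],
                   alphas: List[float]) -> Vector:
    """Apply h <- h + alpha_l * G_l(h) for each gated residual sublayer."""
    for g, alpha in zip(sublayers, alphas):
        out = g(h)
        h = [hi + alpha * oi for hi, oi in zip(h, out)]
    return h


def big_sublayer(v: Vector) -> Vector:
    # Deliberately badly scaled stand-in for an attention/MLP branch.
    return [1000.0 * x + 5.0 for x in v]


x = [0.1, -0.2, 0.3]
layers = [big_sublayer] * 24
print(rezero_forward(x, layers, [0.0] * 24))   # identity at init
print(rezero_forward(x, layers, [1e-4] * 24))  # nonzero gates: output moves away from x
```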

Fan-shape / spectral condition — Spectral MuP (Yang et al. 2023, arxiv:2310.17813): Ignores depth entirely. Prescribes std from fan shape alone: σ = (1/√fan_in)·min(1, √(fan_out/fan_in)). Depth stability is delegated to LN and residuals.

Summary: Most modern LLMs use total-depth scaling on output projections only (GPT-2 convention) and implicitly rely on LN to handle the rest. Per-layer (DS-Init) and structural (ReZero) alternatives exist but have seen limited adoption in large-scale training.

Tech Reports

Most LLM tech reports do not document weight initialization. The following were searched and found to contain no init scheme:

The following disclose init:

DeepSeek-V2 (arxiv:2405.04434, d_model=5120):

  • All learnable parameters: normal_(std=0.006) — stated as a fixed constant; 0.5/√5120 ≈ 0.007, so not an exact fan-in formula for V2's dimensions; likely inherited from earlier work

DeepSeek-V3 (arxiv:2412.19437, d_model=7168):

  • All learnable parameters: normal_(std=0.006) — matches 0.5/√7168 ≈ 0.006 (Trinity cites this equivalence explicitly); consistent with the Trinity/Spike No More-style 0.5/√d formula for V3's hidden dim
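The arithmetic behind both DeepSeek bullets, for re-checking the 0.5/√d match (the √(2/(5d)) value is printed for context):

```python
import math

for name, d in [("DeepSeek-V2", 5120), ("DeepSeek-V3", 7168)]:
    half_rule = 0.5 / math.sqrt(d)
    spike_rule = math.sqrt(2 / (5 * d))
    print(f"{name}: d={d} 0.5/sqrt(d)={half_rule:.4f} "
          f"sqrt(2/(5d))={spike_rule:.4f} stated=0.006")
```

Rounded to three decimals, 0.5/√d gives 0.007 for V2 and 0.006 for V3, which is why only V3's stated constant matches the formula.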