avataRL RL-Based Pretraining Plan

The Symphony of a Guess

How a Transformer Learns to Predict the Next Token

Notation (Mathematical Symbols Explained)

hₜ    vector at position t (like GPS coordinates for word #t)
eₖ    embedding for token k (the number list that represents word k)
ℝᵈ    d-dimensional real space (space with d coordinate axes, like 3D but with more dimensions)
α     attention weights (how much focus to put on each word)
∇θ    gradient with respect to θ (which direction to adjust each parameter θ)
⊙     element-wise product (multiply lists item by item: [1,2] ⊙ [3,4] = [3,8])
┌─────────────────────────────────────────────────────────────────────────────┐
│                    COMPLETE TRANSFORMER GLOSSARY                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  BASIC MATH:                                                               │
│  • Vector: List of numbers [1, 2, 3] representing a point in space        │
│  • Matrix: Grid of numbers [[1,2], [3,4]] like a spreadsheet              │
│  • Dot-product: [1,2]·[3,4] = 1×3 + 2×4 = 11 (measures alignment)        │
│  • Gradient: Direction pointing uphill (how to improve the model)          │
│  • Tensor: Multi-dimensional array (vector=1D, matrix=2D, tensor=3D+)     │
│                                                                             │
│  CORE CONCEPTS:                                                            │
│  • Token: Piece of text (word or subword) converted to an ID number       │
│  • Embedding: Lookup table giving each token coordinates in space         │
│  • Logits: Raw confidence scores (any real numbers)                       │
│  • Softmax: Converts any numbers to valid percentages that sum to 100%    │
│  • Sparse: Mostly zeros [0,0,1,0] vs Dense: many non-zeros [0.1,0.8,0.3]  │
│                                                                             │
│  ATTENTION MECHANISM:                                                      │
│  • Query (Q): "What am I looking for?" (like a search question)           │
│  • Key (K): "What do I offer?" (like a book's index entry)                │
│  • Value (V): "What do I actually contribute?" (like book's content)      │
│  • Multi-head: Multiple attention mechanisms running in parallel          │
│  • Causal mask: Prevents cheating by hiding future words                  │
│                                                                             │
│  POSITION & NORMALIZATION:                                                │
│  • RoPE: Rotary Position Embedding (like clock positions 12,1,2...)      │
│  • RMS Norm: Rescale numbers to similar sizes (like volume control)       │
│  • Residual: x + f(x) allows information highway around layers            │
│                                                                             │
│  LEARNING:                                                                 │
│  • Cross-entropy: "How surprised am I?" loss function                     │
│  • Perplexity: exp(loss) = "How many choices does model think it has?"    │
│  • AdamW: Smart optimizer with momentum and adaptive learning rates       │
│  • Weight decay: Rubber band pulling parameters toward zero               │
│  • Autoregressive: Generating one token at a time, left to right          │
│                                                                             │
│  ARCHITECTURE:                                                             │
│  • FFN: Feed-Forward Network (384→1536→384 per-token computer)           │
│  • ReLU²: Activation function (replace negatives with 0, then square)     │
│  • KL-divergence: How different two probability distributions are         │
│  • Manifold: Curved surface where natural language patterns live          │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

🎼 Prelude: The First Breath

Imagine a single bar of music, four quiet notes:

"the   quick   brown   —"

The transformer inhales them, holds the silence where the fifth should be, then exhales a chord of probabilities in which fox rings almost inevitable. This is not memorization. This is not calculation. This is the sculpture of meaning itself, carved from the mathematics of attention.

To understand how that chord is composed, we must first see how music becomes mathematics.

Movement I: The Score — When Words Become Geometry

Every symphony begins with the act of notation. Raw text dissolves into tokens (individual pieces of text like words or subwords), tokens lift into vectors (lists of numbers representing points in space), vectors arrange themselves into a universe of meaning. But why this sequence? Why not work directly with the tokens?

┌─────────────────────────────────────────────────────────────────────────────┐
│                      TEXT → TOKENS → VECTORS                                │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Raw Text:  "The quick brown fox jumps"                                    │
│      ↓                                                                      │
│  Tokens:    ["The", "quick", "brown", "fox", "jumps"]                     │
│      ↓                                                                      │
│  Token IDs: [1629, 4662, 17041, 21831, 35308]                             │
│      ↓                                                                      │
│  Vectors:   Each token → 384-dimensional vector                           │
│                                                                             │
│      1629 → [0.12, -0.45, 0.78, 0.23, ..., 0.91]  (384 numbers)          │
│      4662 → [0.34, -0.12, 0.56, 0.89, ..., 0.45]  (384 numbers)          │
│     17041 → [0.67, -0.23, 0.34, 0.12, ..., 0.78]  (384 numbers)          │
│     21831 → [0.89, -0.56, 0.12, 0.34, ..., 0.23]  (384 numbers)          │
│     35308 → [0.23, -0.78, 0.91, 0.56, ..., 0.12]  (384 numbers)          │
│                                                                             │
│  Result: Each word becomes a point in 384-dimensional space                │
│          Similar words cluster together in this space                      │
└─────────────────────────────────────────────────────────────────────────────┘
# The opening movement: text becomes mathematics
# Using custom BPE tokenizer for better compression
def get_batch(split):  # Lines 552-565 in train_modal_standalone.py
    if split == 'train':
        data = np.memmap('/data/shakespeare_tokens_bpe1024/train.bin', dtype=np.uint16, mode='r')  # Line 555
    else:
        data = np.memmap('/data/shakespeare_tokens_bpe1024/val.bin', dtype=np.uint16, mode='r')    # Line 557
    
    # Sample random positions in the dataset
    ix = torch.randint(len(data) - cfg['block_size'], (cfg['batch_size'],))                      # Line 558
    x = torch.stack([torch.from_numpy((data[i:i+cfg['block_size']]).astype(np.int64)) for i in ix])        # Line 559
    y = torch.stack([torch.from_numpy((data[i+1:i+1+cfg['block_size']]).astype(np.int64)) for i in ix])    # Line 560
    return x, y  # context and targets
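
The one-position shift between x and y is what turns raw text into a next-token prediction task: every position in x is trained to predict the token sitting one step later in y. A minimal, self-contained sketch with a toy token sequence (illustrative only, not taken from the training script) makes the alignment explicit:

# Hypothetical toy example of the x/y shift (not from train_modal_standalone.py)
import numpy as np
import torch

data = np.array([1629, 4662, 17041, 21831, 35308], dtype=np.uint16)  # "the quick brown fox jumps"
block_size = 4
i = 0  # start of one sampled window

x = torch.from_numpy(data[i:i + block_size].astype(np.int64))          # context: [1629, 4662, 17041, 21831]
y = torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64))  # targets: [4662, 17041, 21831, 35308]

for t in range(block_size):
    print(f"given {x[:t + 1].tolist()} -> predict {y[t].item()}")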

The first insight: integer tokens are too sparse (meaning mostly empty - like a vector [0,0,1,0,0] that has only one non-zero entry). Token 27 ("the") and token 28 ("then") are numerically close (just 1 apart) but mean completely different things. We need a representation where similar meanings are actually close together.

┌─────────────────────────────────────────────────────────────────────────────┐
│                        SPARSE vs DENSE REPRESENTATIONS                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Sparse Representation (bad for AI):                                       │
│  "cat" = [0, 0, 1, 0, 0, 0, ...]  ← only one "1", rest are zeros          │
│  "dog" = [0, 0, 0, 1, 0, 0, ...]  ← completely different, no similarity    │
│                                                                             │
│  Problem: "cat" and "dog" look totally unrelated                           │
│                                                                             │
│  Dense Representation (good for AI):                                       │
│  "cat" = [0.2, 0.8, 0.9, 0.1, ...]  ← many non-zero numbers               │
│  "dog" = [0.3, 0.7, 0.8, 0.2, ...]  ← similar to cat (both animals)       │
│                                                                             │
│  Benefit: Similar meanings have similar numbers                            │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Each token k lifts into high-dimensional space through an embedding table (think: a giant lookup dictionary where each word has its own unique coordinates in a 384-dimensional space). Why this lookup? If "fox" and "dog" often appear in similar contexts ("The ___ runs fast"), gradient descent (the learning algorithm that adjusts the model) will nudge their vectors toward each other in this space. The model learns to reuse what it knows about one animal when predicting another. Dense vectors (number lists with many non-zero values) capture similarity; sparse integers (mostly zeros) cannot.

Intuition: Imagine every word as a point in a vast space where similar words cluster together. "Cat" and "dog" might be close to each other, while "cat" and "equation" are far apart. This spatial representation lets the model reason about meaning through geometry.

# Token to vector: capturing similarity
# In our GPT model, this is the wte (word token embedding) layer
self.wte = nn.Embedding(vocab_size, n_emb)  # Line 254: 1024 vocab -> 384 dimensions
token_vectors = self.wte(tokens)            # Line 283: [27, 412, 51, 843] → shape (4, 384)
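
To make "similar words cluster together" concrete, here is a small hypothetical probe of an embedding table with the same shape as wte. The token IDs are made up, and at random initialization the similarities are meaningless; the clustering only emerges once training has nudged co-occurring tokens toward each other:

# Hypothetical probe; token IDs are illustrative, not from the real BPE tokenizer
import torch
import torch.nn.functional as F

wte = torch.nn.Embedding(1024, 384)   # same shape as the model's embedding table
fox_id, dog_id, the_id = 3, 7, 42     # illustrative IDs

fox, dog, the = wte(torch.tensor([fox_id, dog_id, the_id]))

# Cosine similarity: 1.0 = same direction, 0.0 = unrelated, -1.0 = opposite
print(F.cosine_similarity(fox, dog, dim=0))  # after training, expected to be higher...
print(F.cosine_similarity(fox, the, dim=0))  # ...than this one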

But now we face a deeper problem. A pure set of vectors is orderless—the dot-product (multiply corresponding elements and sum them up: [1,2]·[3,4] = 1×3 + 2×4 = 11) sees only which token, never where. Unlike RNNs (Recurrent Neural Networks that read sequences one word at a time like humans), self-attention (the core mechanism in transformers) looks at all positions simultaneously. It needs explicit position information.

┌─────────────────────────────────────────────────────────────────────────────┐
│                      WHY WORD ORDER MATTERS                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Consider these sentences:                                                  │
│  "The dog chased the cat" ≠ "The cat chased the dog"                      │
│                                                                             │
│  Same words, completely different meaning!                                 │
│                                                                             │
│  RNN approach (like reading a book):                                       │
│  Read "The" → remember it                                                  │
│  Read "dog" → remember "The dog"                                           │
│  Read "chased" → remember "The dog chased"                                 │
│  Read "the" → remember "The dog chased the"                                │
│  Read "cat" → understand full sentence                                     │
│                                                                             │
│  Transformer approach (like looking at entire page at once):              │
│  See all words: ["The", "dog", "chased", "the", "cat"]                   │
│  Problem: Without position info, it just sees a bag of words              │
│  Solution: Add position encoding so model knows word order                │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│                           THE POSITION PROBLEM                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Without Position Info:                                                     │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐                           │
│  │ Vector  │ │ Vector  │ │ Vector  │ │ Vector  │                           │
│  │ "The"   │ │ "quick" │ │ "brown" │ │ "fox"   │                           │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘                           │
│       ↑           ↑           ↑           ↑                                │
│    These could be in ANY order! Model can't tell                          │
│                                                                             │
│  With Position Info (RoPE):                                                │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐                           │
│  │ Vector  │ │ Vector  │ │ Vector  │ │ Vector  │                           │
│  │ "The"   │ │ "quick" │ │ "brown" │ │ "fox"   │                           │
│  │ @Pos 0  │ │ @Pos 1  │ │ @Pos 2  │ │ @Pos 3  │                           │
│  │ Rot 0°  │ │ Rot 15° │ │ Rot 30° │ │ Rot 45° │                           │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘                           │
│       ↑           ↑           ↑           ↑                                │
│    Each vector rotated by position-dependent angle                        │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

The solution: Rotary Position Embedding (RoPE). Instead of adding position numbers to vectors, we rotate the vectors by position-dependent angles. This creates a multiplicative position encoding that naturally decays with distance.

┌─────────────────────────────────────────────────────────────────────────────┐
│                             RoPE EXPLAINED                                  │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Problem: How do we tell the model "this word is at position 3"?           │
│                                                                             │
│  Old way (Adding): vector + position_number                                │
│  • Word vector: [0.5, 0.8]                                                │
│  • Position 3: [0.3, 0.0]                                                 │
│  • Result: [0.8, 0.8] ← Just addition                                     │
│                                                                             │
│  RoPE way (Rotating): Spin the vector by position angle                   │
│  • Position 0: No rotation (0°)                                           │
│  • Position 1: Small rotation (15°)                                       │
│  • Position 2: Bigger rotation (30°)                                      │
│  • Position 3: Even bigger (45°)                                          │
│                                                                             │
│  Think of it like a clock:                                                 │
│  • 12 o'clock = Position 0                                                │
│  • 1 o'clock = Position 1                                                 │
│  • 2 o'clock = Position 2                                                 │
│  • Nearby times are close, distant times are far apart                    │
│                                                                             │
│  Benefit: Nearby positions naturally have similar "rotations"             │
│  Words close together will attend to each other more                       │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Intuition: Think of each position as a compass direction. Position 0 points north, position 1 points slightly northeast, position 2 points more northeast, etc. By rotating each word's "attention vector" based on its position, nearby words naturally attend to each other more than distant ones, just like how nearby compass directions are more similar.

# Position matters: RoPE rotates Q and K by position-dependent angles
class RotaryCache(nn.Module):  # Lines 68-81 in train_modal_standalone.py
    def __init__(self, head_dim: int, max_len: int):                                     # Line 69
        super().__init__()                                                               # Line 70
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))          # Line 71
        t = torch.arange(max_len)                                                        # Line 72
        freqs = torch.einsum("i,j->ij", t, inv_freq)                                    # Line 73
        sin, cos = freqs.sin(), freqs.cos()                                             # Line 74
        self.register_buffer("sin_base", sin, persistent=False)                         # Line 75
        self.register_buffer("cos_base", cos, persistent=False)                         # Line 76

# The fundamental equation: meaning + rotated position
# h = token_vectors (no positional addition!)
# Position encoded via rotation during attention
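
The class above builds the sin/cos tables, but the excerpt never shows _rotate_half or how those tables meet the queries and keys. A minimal sketch of the missing pieces, following the standard half-split RoPE formulation (an assumption about this codebase, not a quote from it):

import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and swap the halves with a sign flip:
    # (x1, x2) -> (-x2, x1). Combined with the sin/cos tables this rotates
    # each pair of features by a position-dependent angle.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # sin/cos come out of RotaryCache with shape (seq_len, head_dim // 2);
    # duplicate them so they line up with the full head dimension.
    sin = torch.cat((sin, sin), dim=-1)
    cos = torch.cat((cos, cos), dim=-1)
    return (q * cos) + (_rotate_half(q) * sin)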

The model now possesses a geometry of narrative: each token knows not just what it is, but where it stands in the unfolding sentence. "The" at position 0 and "the" at position 10 occupy different regions of this space, carrying their temporal context.

This is the opening chord—meaning and position unified. Now comes the conversation between the notes.

Movement II: Attention — The Conversation of Voices

In a symphony, instruments don't play in isolation; they listen to each other, respond, harmonize. Self-attention is the transformer's way of letting tokens converse across the sequence. As Vaswani et al. noted in their groundbreaking paper "Attention is All You Need": "The first is a multi-head self-attention mechanism" that revolutionized how models process sequences¹.

But why this particular mechanism? As Jay Alammar explains in The Illustrated Transformer: "Self-attention is the method the Transformer uses to bake the 'understanding' of other relevant words into the one we're currently processing"². The key insight is that "when the model is processing the word 'it', self-attention allows it to associate 'it' with 'animal'" - creating contextual understanding that was impossible with previous architectures.

┌─────────────────────────────────────────────────────────────────────────────┐
│                           ATTENTION MECHANISM                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Token "fox" asking: "What modifies me?"                                   │
│                                                                             │
│      ┌─────┐      ┌─────┐      ┌─────┐      ┌─────┐                        │
│      │"The"│      │"quick"     │"brown"     │"fox" │                        │
│      │     │      │     │      │     │      │ ??? │                        │
│      └──┬──┘      └──┬──┘      └──┬──┘      └──┬──┘                        │
│         │            │            │            │                            │
│         ▼            ▼            ▼            ▼                            │
│    ┌────────┐   ┌────────┐   ┌────────┐   ┌────────┐                       │
│    │10% attn│   │20% attn│   │60% attn│   │10% self│                       │
│    └────────┘   └────────┘   └────────┘   └────────┘                       │
│         │            │            │            │                            │
│         └────────────┼────────────┼────────────┘                            │
│                      │            │                                         │
│                      ▼            ▼                                         │
│               ┌─────────────────────────┐                                   │
│               │ Weighted Combination:   │                                   │
│               │ 10%"The" + 20%"quick"   │                                   │
│               │ + 60%"brown" + 10%"fox" │                                   │
│               │ = Enhanced "fox" vector │                                   │
│               └─────────────────────────┘                                   │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

The challenge: each token needs to aggregate information from all relevant previous tokens. A naive approach would be to learn fixed weights for each position pair, but that doesn't generalize (work well on new, unseen data) to unseen sequence lengths. Instead, we make the connections content-dependent.

Alammar emphasizes the scoring mechanism: "The score determines how much focus to place on other parts of the input sequence as we encode a word at a certain position"². The resulting "softmax score determines how much each word will be expressed at this position" - a weighted combination where "the intuition here is to keep intact the values of the word(s) we want to focus on, and drown-out irrelevant words"².

Each position broadcasts three signals, each serving a distinct purpose in the attention dance:

  • A query Q: "What am I looking for?" — The question each position asks
  • A key K: "What do I have to offer?" — The advertisement each position broadcasts
  • A value V: "What do I actually contribute?" — The payload each position carries
┌─────────────────────────────────────────────────────────────────────────────┐
│                        QUERY-KEY-VALUE MECHANISM                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Networking Event Analogy:                                                 │
│                                                                             │
│  You (fox token):                                                          │
│  ┌─────────────────┐                                                       │
│  │ QUERY (Q):      │ "I'm looking for adjectives that describe me"         │
│  │ "Looking for    │                                                       │
│  │  descriptors"   │                                                       │
│  └─────────────────┘                                                       │
│                                                                             │
│  Others at the event:                                                      │
│                                                                             │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐                        │
│  │ "The"   │  │ "quick" │  │ "brown" │  │ "jumps" │                        │
│  ├─────────┤  ├─────────┤  ├─────────┤  ├─────────┤                        │
│  │KEY (K): │  │KEY (K): │  │KEY (K): │  │KEY (K): │                        │
│  │"I'm an  │  │"I'm a   │  │"I'm a   │  │"I'm a   │                        │
│  │article" │  │speed    │  │color    │  │verb"    │                        │
│  │         │  │adjective│  │adjective│  │         │                        │
│  ├─────────┤  ├─────────┤  ├─────────┤  ├─────────┤                        │
│  │VALUE(V):│  │VALUE(V):│  │VALUE(V):│  │VALUE(V):│                        │
│  │Grammar  │  │Motion   │  │Visual   │  │Action   │                        │
│  │info     │  │concepts │  │concepts │  │concepts │                        │
│  └─────────┘  └─────────┘  └─────────┘  └─────────┘                        │
│       ↓            ↓            ↓            ↓                              │
│   No match    Good match!  Great match!  No match                          │
│     5%           20%          70%          5%                               │
│                                                                             │
│  Final blend: 5%×Grammar + 20%×Motion + 70%×Visual + 5%×Action             │
│               = Rich "fox" representation with color emphasis               │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Real-world analogy: Imagine you're at a networking event. Your query is "I'm looking for someone who knows about web development." Others broadcast their keys like "I'm a React expert" or "I know Python." But their values are their actual expertise and advice they can share. You match based on keys, but extract knowledge from values.

The key insight: keys and values are fundamentally different yet inseparable. Keys are used for matching — they determine the attention weights through similarity with queries. Values are used for mixing — they are what actually gets blended together based on those weights.

Think of it as a library: each book has a key (its catalog entry describing what it contains) and a value (its actual content). You search using the key, but you read the value. The attention mechanism separates these roles — allowing a token to advertise one thing (key) while contributing something potentially different (value).

# Three perspectives on each token - using fused QKV projection
def attention(self, x: torch.Tensor):  # Lines 177-197 in train_modal_standalone.py
    B, T, C = x.shape                                                      # Line 178
    
    # Single linear layer projects to Q, K, V simultaneously
    qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)        # Line 183
    q, k, v = qkv.unbind(dim=2)  # Split into Q, K, V                     # Line 185
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)     # Line 186
    
    # head_dim = n_emb // n_heads (384 // 6 = 64)
    # This gives us 6 heads of 64 dimensions each

The core operation computes similarity between queries and keys:

    # Apply RoPE rotation to Q and K
    sin, cos = self.rope(T)                          # Line 188 in train_modal_standalone.py
    q = (q * cos) + (_rotate_half(q) * sin)          # Line 189
    k = (k * cos) + (_rotate_half(k) * sin)          # Line 190
    
    # RMS normalization instead of scaling
    q, k = norm(q), norm(k)  # RMS norm for stability # Line 192

The RMS normalization serves a similar purpose to the √dₖ scaling in classical attention.

┌─────────────────────────────────────────────────────────────────────────────┐
│                          RMS NORMALIZATION                                  │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Problem: Vector numbers can become too big or too small                   │
│  Example: [100, 200, 0.1, 50] ← Numbers are all different scales          │
│                                                                             │
│  RMS Norm = Root Mean Square normalization                                 │
│  Step 1: Square all numbers: [10000, 40000, 0.01, 2500]                  │
│  Step 2: Take average (mean): (10000+40000+0.01+2500)/4 = 13125           │
│  Step 3: Take square root: √13125 = 114.6                                 │
│  Step 4: Divide original by this: [100/114.6, 200/114.6, 0.1/114.6, ...]  │
│  Result: [0.87, 1.75, 0.001, 0.44] ← All numbers now similar scale        │
│                                                                             │
│  Why this helps:                                                           │
│  • Prevents any number from dominating others                              │
│  • Makes gradients more stable (learning doesn't explode)                 │
│  • Like adjusting volume levels so all instruments can be heard            │
│                                                                             │
│  Think of it as: "Make sure all numbers are roughly the same size"         │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

RMS norm rescales vectors to have consistent scale, stabilizing gradients (directions of steepest change that guide learning). This is crucial because RoPE can amplify or diminish the magnitude of vectors depending on their position.

According to Zhang & Sennrich's research, "RMSNorm regularizes the summed inputs to a neuron in one layer according to their root mean square" and provides "re-scaling invariance and implicit learning rate adaptation"³. Their key insight is that "re-centering invariance in LayerNorm is dispensable" - the mean-centering step in LayerNorm isn't actually necessary, making RMSNorm both simpler and more efficient, "reducing running time by 7%~64% on different models while achieving comparable performance"³.
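
The norm helper called in the snippets (norm(q), norm(k), and later norm(x)) is not shown in this excerpt. A minimal functional RMS-norm sketch consistent with the walkthrough above (the real code may add a learned scale parameter):

import torch

def norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide by the root-mean-square of the last dimension so every vector
    # ends up at roughly unit scale, as in the worked example above.
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
    return x / rms

print(norm(torch.tensor([100.0, 200.0, 0.1, 50.0])))  # ≈ [0.87, 1.75, 0.001, 0.44]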

    # Flash attention: fused, memory-efficient implementation
    # The causal mask prevents tokens from seeing future tokens
    out = F.scaled_dot_product_attention(                      # Line 194 in train_modal_standalone.py
        q, k, v, 
        is_causal=True,  # Causal mask: tokens can only attend to previous positions
        dropout_p=self.dropout.p if self.training else 0.0
    )
    

┌─────────────────────────────────────────────────────────────────────────────┐
│                      CAUSAL MASK: NO CHEATING ALLOWED!                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  GOAL: Predict the next token without seeing the future                     │
│                                                                             │
│  WITHOUT CAUSAL MASK (Cheating Mode):                                       │
│  Input sequence: ["The", "quick", "brown", "fox", "jumps"]                  │
│  Predicting after "fox", the model can see every token, including "jumps"   │
│  Prediction: "jumps" ← Too easy! It simply copies the answer it can see     │
│  Problem: The model learns to cheat, not to understand language             │
│                                                                             │
│  WITH CAUSAL MASK (Fair Training):                                          │
│  Model can see: [The] [quick] [brown] [fox] [MASKED]                        │
│  Prediction: "jumps" ← It had to learn real patterns                        │
│  Benefit: The model learns genuine language understanding                   │
│                                                                             │
│  ATTENTION MATRIX: Who Can See Whom                                         │
│                                                                             │
│            The   quick brown  fox  jumps                                    │
│         ┌─────┬─────┬─────┬─────┬─────┐                                     │
│  The    │  ✓  │ -∞  │ -∞  │ -∞  │ -∞  │ ← Only sees itself                  │
│  quick  │  ✓  │  ✓  │ -∞  │ -∞  │ -∞  │ ← Sees "The" + itself               │
│  brown  │  ✓  │  ✓  │  ✓  │ -∞  │ -∞  │ ← Sees "The", "quick" + itself      │
│  fox    │  ✓  │  ✓  │  ✓  │  ✓  │ -∞  │ ← Sees all previous + itself        │
│  jumps  │  ✓  │  ✓  │  ✓  │  ✓  │  ✓  │ ← Sees everything (last token)      │
│         └─────┴─────┴─────┴─────┴─────┘                                     │
│                                                                             │
│  ✓ = Allowed attention   -∞ = Masked (blocked with negative infinity)       │
│  Lower-triangular matrix = each token sees only the past + itself           │
│                                                                             │
│  IMPLEMENTATION DETAIL:                                                     │
│  1. Compute attention scores: Q @ K^T                                       │
│  2. Apply causal mask: set the upper triangle to -∞                         │
│  3. Softmax: -∞ becomes 0 probability                                       │
│  4. Result: future tokens get zero attention weight                         │
│                                                                             │
│  Mathematical effect:                                                       │
│  softmax([-∞, -∞, 2.1, 3.2]) ≈ [0.0, 0.0, 0.25, 0.75]                       │
│                                                                             │
│  The causal mask enables autoregressive generation! 🎯                      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘


    # Why causal masking? In language modeling, we predict the next token
    # based only on previous tokens. If "fox" could see "jumps" (the next word),
    # prediction would be trivial. The causal mask ensures fair prediction.
    
    # Reshape back to original dimensions
    out = out.transpose(1, 2).contiguous().view(B, T, C)       # Line 195
    return self.o_proj(out)  # Final linear projection          # Line 197

The attention weights create a weighted average, but here's where the key/value separation becomes powerful:

Example: "the quick brown fox"

  • Token "fox" has a query asking "what modifies me?"
  • Token "brown" has a key advertising "I'm a color adjective"
  • Token "brown" has a value containing rich color semantics
  • Token "quick" has a key advertising "I'm a speed adjective"
  • Token "quick" has a value containing motion/speed concepts

The keys determine the attention pattern (fox attends 70% to brown, 20% to quick), but the values determine what information actually flows (color semantics and speed concepts blend into fox's representation). This separation allows the model to learn "what to attend to" (key matching) independently from "what to extract" (value content).
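
A tiny numeric sketch of that separation, with made-up 2-D keys and 3-D values for a single head and no scaling: the keys only set the attention weights, while the values supply what is actually mixed:

# Made-up numbers for illustration only
import torch
import torch.nn.functional as F

keys = torch.tensor([[0.1, 0.0],    # "the"
                     [0.6, 0.2],    # "quick"
                     [0.9, 0.4],    # "brown"
                     [0.2, 0.1]])   # "fox" itself
values = torch.tensor([[1.0, 0.0, 0.0],   # grammar-flavoured features
                       [0.0, 1.0, 0.0],   # motion-flavoured features
                       [0.0, 0.0, 1.0],   # colour-flavoured features
                       [0.3, 0.1, 0.1]])  # fox's own features

query = torch.tensor([0.8, 0.5])          # "fox" asking "what modifies me?"

weights = F.softmax(keys @ query, dim=0)  # keys decide how much attention each token gets
mixed = weights @ values                  # values decide what information actually flows
print(weights)  # "brown" gets the largest share
print(mixed)    # fox's new representation: a value blend dominated by colour features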

But why multiple attention heads? Because different aspects of language require different types of attention. Empirically, head 3 in GPT-2 fires on closing brackets—its attention matrix lights up between every '(' and ')'. This emerges because predicting a closing bracket is easy once you know where the opener sits; one head can dedicate itself to that pattern while others chase syntax or long-range coreference.

As Alammar notes, multi-head attention "expands the model's ability to focus on different positions" and "gives the attention layer multiple 'representation subspaces'"². The original Transformer paper demonstrated that "The Transformer outperforms the Google Neural Machine Translation model in specific tasks. The biggest benefit, however, comes from how The Transformer lends itself to parallelization"².

Intuition: Think of attention heads like different specialists in a team. One person tracks grammatical structure ("where are the subjects and verbs?"), another tracks references ("what does 'it' refer to?"), and another tracks punctuation patterns. Each specialist can focus on their expertise while all contribute to understanding the sentence.

# Multi-head attention: parallel conversations (6 heads in our model)
class OptimizedAttention(nn.Module):  # Lines 142-158 in train_modal_standalone.py
    def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):  # Line 143
        super().__init__()                                                                 # Line 144
        self.n_head = n_head  # 6 heads                                                   # Line 145
        self.n_emb = n_emb    # 384 dimensions                                            # Line 146
        self.head_dim = n_emb // n_head  # 64 dimensions per head                        # Line 147
        self.qkv = nn.Linear(n_emb, 3 * n_emb, bias=False)  # Fused QKV projection      # Line 148
        self.o_proj = nn.Linear(n_emb, n_emb, bias=False)   # Output projection         # Line 149
        self.rope = RotaryCache(self.head_dim, context_len)  # RoPE cache                # Line 152

Six conversations happen simultaneously in our model, each capturing different linguistic relationships. The result is a richer, more nuanced understanding of how tokens relate to each other.

Movement III: The Deepening — How Layers Build Understanding

A single attention layer creates one moment of conversation. But understanding deepens through repetition, through layers of increasingly sophisticated dialogue. Why stack layers at all?

┌─────────────────────────────────────────────────────────────────────────────┐
│                          LAYER-BY-LAYER UNDERSTANDING                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Input: "the quick brown fox"                                               │
│                                                                             │
│  Layer 1 (Local Patterns):                                                 │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                                          │
│  │"the"│→│"quick"│→│"brown"│→│"fox"│                                          │
│  └─────┘ └─────┘ └─────┘ └─────┘                                          │
│                     ↗          ↑                                            │
│                   High attn   Focus here                                    │
│  Result: "fox" learns about its immediate modifier "brown"                 │
│                                                                             │
│  Layer 2 (Medium-Range Dependencies):                                      │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                                          │
│  │"the"│ │"quick"│ │"brown"│ │"fox"│                                          │
│  └─────┘ └─────┘ └─────┘ └─────┘                                          │
│            ↗                 ↑                                              │
│          Med attn           Focus here                                      │
│  Result: "fox" learns about speed quality "quick"                         │
│                                                                             │
│  Layer 3 (Long-Range Grammar):                                             │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                                          │
│  │"the"│ │"quick"│ │"brown"│ │"fox"│                                          │
│  └─────┘ └─────┘ └─────┘ └─────┘                                          │
│     ↗                         ↑                                            │
│   Low attn                   Focus here                                     │
│  Result: "fox" learns it's the main noun (from determiner "the")          │
│                                                                             │
│  Final Understanding: "fox" = brown + quick + definite noun               │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Consider the phrase "the quick brown fox." In layer 1, "fox" might attend to "brown" (adjacent modifier). In layer 2, it might attend to "quick" (distant modifier). In layer 3, it might attend to "the" (grammatical determiner). Each layer allows for more complex dependency parsing.

┌─────────────────────────────────────────────────────────────────────────────┐
│                      TRANSFORMER BLOCK INTERNALS                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Pre-Norm Architecture (used in our model):                                │
│                                                                             │
│  Input x: [batch, seq_len, 384]                                           │
│    │                                                                        │
│    ▼                                                                        │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ ATTENTION SUB-BLOCK                                                     │
│  │                                                                         │
│  │ x_residual = x  ←────────────────────────────────┐ (SKIP CONNECTION)   │
│  │    │                                             │                     │
│  │    ▼                                             │                     │
│  │ ┌─────────────────────────────────────────────┐   │                     │
│  │ │ norm(x)  ← RMS Normalization                │   │                     │
│  │ │    ↓                                        │   │                     │
│  │ │ Multi-Head Attention:                       │   │                     │
│  │ │   ├─ x → Q, K, V projections                │   │                     │
│  │ │   ├─ Apply RoPE to Q, K                     │   │                     │
│  │ │   ├─ Compute attention: softmax(QK^T)       │   │                     │
│  │ │   ├─ Apply causal mask                      │   │                     │
│  │ │   └─ Aggregate: attention_weights @ V       │   │                     │
│  │ │    ↓                                        │   │                     │
│  │ │ attention_output                             │   │                     │
│  │ └─────────────────────────────────────────────┘   │                     │
│  │    │                                             │                     │
│  │    ▼                                             │                     │
│  │    +  ←──────────────────────────────────────────┘                     │
│  │    │                                                                   │
│  │    ▼                                                                   │
│  │ x₁ = x + attention_output  (First residual connection)                 │
│  └─────────────────────────────────────────────────────────────────────────┤
│    │                                                                        │
│    ▼                                                                        │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ FEED-FORWARD SUB-BLOCK                                                 │
│  │                                                                         │
│  │ x₁_residual = x₁  ←──────────────────────────────┐ (SKIP CONNECTION)   │
│  │    │                                             │                     │
│  │    ▼                                             │                     │
│  │ ┌─────────────────────────────────────────────┐   │                     │
│  │ │ norm(x₁)  ← RMS Normalization               │   │                     │
│  │ │    ↓                                        │   │                     │
│  │ │ Feed-Forward Network:                       │   │                     │
│  │ │   ├─ Linear: 384 → 1536                    │   │                     │
│  │ │   ├─ ReLU²: activation                     │   │                     │
│  │ │   ├─ Linear: 1536 → 384                    │   │                     │
│  │ │   └─ Dropout                               │   │                     │
│  │ │    ↓                                        │   │                     │
│  │ │ ffn_output                                  │   │                     │
│  │ └─────────────────────────────────────────────┘   │                     │
│  │    │                                             │                     │
│  │    ▼                                             │                     │
│  │    +  ←──────────────────────────────────────────┘                     │
│  │    │                                                                   │
│  │    ▼                                                                   │
│  │ x₂ = x₁ + ffn_output  (Second residual connection)                     │
│  └─────────────────────────────────────────────────────────────────────────┤
│    │                                                                        │
│    ▼                                                                        │
│  Output x₂: [batch, seq_len, 384]                                         │
│                                                                             │
│  Key Properties:                                                           │
│  • Two residual connections per block                                      │
│  • RMS normalization before (not after) each sub-layer                    │
│  • Information can flow directly through skip connections                  │
│  • Each sub-layer adds refinements to the representation                   │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
# The deepening: layer by layer, meaning grows
class TransformerBlock(nn.Module):  # Lines 227-242 in train_modal_standalone.py
    def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):  # Line 228
        super().__init__()                                                                 # Line 229
        self.attn = OptimizedAttention(n_emb, n_head, context_len, dropout)               # Line 230
        # Pre-norm architecture: normalize before, not after
        self.ffn = nn.Sequential(                                                          # Line 232
            nn.Linear(n_emb, 4 * n_emb, bias=False),      # 384 -> 1536                  # Line 233
            ReLUSquared(),                                 # ReLU^2 activation            # Line 234
            nn.Linear(4 * n_emb, n_emb, bias=False),      # 1536 -> 384                  # Line 235
            nn.Dropout(dropout)                                                            # Line 236
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # Line 239
        # Pre-norm: normalize input before attention and FFN
        x = x + self.attn(norm(x))  # RMS norm, not LayerNorm  # Line 240
        x = x + self.ffn(norm(x))   # RMS norm, not LayerNorm  # Line 241
        return x                                                # Line 242
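
ReLUSquared is referenced above but not defined in this excerpt. Given the glossary entry ("replace negatives with 0, then square"), a minimal sketch would be:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUSquared(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Zero out the negatives, then square what remains
        return F.relu(x).square()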

The residual connections (x + ...) are crucial. They create skip paths that allow gradients to flow directly from later layers to earlier ones, preventing the vanishing gradient problem (where gradients become too small to update early layers effectively) that plagued deep networks.

The ResNet paper by He et al. was revolutionary in showing that "skip connections solve the degradation problem" by enabling training of networks with hundreds of layers⁴. As machine learning researchers note: "residual connections fix the problem" of vanishing gradients by providing "a gradient highway" that ensures "gradient flows backwards through each step" without degradation⁴.

┌─────────────────────────────────────────────────────────────────────────────┐
│                          RESIDUAL CONNECTIONS                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Without Residual Connections (Traditional Deep Network):                  │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Input                                                                   │
│  │   ↓                                                                     │
│  │ Layer 1 → output₁                                                      │
│  │   ↓                                                                     │
│  │ Layer 2 → output₂                                                      │
│  │   ↓                                                                     │
│  │ Layer 3 → output₃                                                      │
│  │   ↓                                                                     │
│  │ ... (gradient gets weaker and weaker going backwards)                  │
│  │   ↓                                                                     │
│  │ Final Output (early layers barely learn!)                              │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  With Residual Connections (ResNet/Transformer Architecture):              │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Input x                                                                 │
│  │   ↓                                                                     │
│  │   ┌─────────────────────────────────────────────────────────────────┐   │
│  │   │ ┌─────────────────────────────────────────────────────────────┐ │   │
│  │   │ │                    Layer 1                                  │ │   │
│  │   │ │              (Attention/FFN)                                │ │   │
│  │   │ │                      ↓                                      │ │   │
│  │   │ │                 processed_x                                 │ │   │
│  │   │ └─────────────────────────────────────────────────────────────┘ │   │
│  │   └─────────────────────────┬───────────────────────────────────────┘   │
│  │   ↓                         ↓                                           │
│  │   x ─────────────────────── + ←──── Addition (Residual Connection)     │
│  │   ↓                                                                     │
│  │   output₁ = x + processed_x                                             │
│  │   ↓                                                                     │
│  │   ┌─────────────────────────────────────────────────────────────────┐   │
│  │   │ ┌─────────────────────────────────────────────────────────────┐ │   │
│  │   │ │                    Layer 2                                  │ │   │
│  │   │ │              (Attention/FFN)                                │ │   │
│  │   │ │                      ↓                                      │ │   │
│  │   │ │                 processed_x₂                                │ │   │
│  │   │ └─────────────────────────────────────────────────────────────┘ │   │
│  │   └─────────────────────────┬───────────────────────────────────────┘   │
│  │   ↓                         ↓                                           │
│  │   output₁ ──────────────── + ←──── Another Residual Connection          │
│  │   ↓                                                                     │
│  │   output₂ = output₁ + processed_x₂                                      │
│  │                                                                         │
│  │ Key Insight: Original input x can flow directly to any later layer!    │
│  │ Gradient highway: ∂L/∂x = ∂L/∂output + ∂L/∂processed_parts            │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Intuition: Residual connections are like having backup routes on a highway. If the main road (through the layer) gets congested, traffic (gradients) can still flow via the bypass route (the skip connection). This prevents the "vanishing gradient" traffic jam that would stop learning in deep networks.
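
A quick autograd check of that claim, using a deliberately weak toy layer f (purely illustrative) to show that the gradient through x + f(x) always keeps the direct identity path alive:

# Toy demonstration of the gradient highway; f is a stand-in, not a real layer
import torch

x = torch.tensor([1.0, 0.5, -0.3, 0.8], requires_grad=True)
f = lambda v: 0.01 * v   # a "layer" whose own gradient is tiny (0.01)

plain = f(f(f(x))).sum()                   # three stacked layers, no skip connections
residual = (x + f(x + f(x + f(x)))).sum()  # the same three layers, each wrapped in a skip

plain.backward()
print(x.grad)   # ≈ 1e-6 per element: the signal has all but vanished

x.grad = None
residual.backward()
print(x.grad)   # ≈ 1.01 per element: the identity path keeps gradients flowing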

┌─────────────────────────────────────────────────────────────────────────────┐
│                    HOW RESIDUAL CONNECTIONS ACTUALLY WORK                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  The Key Insight: Residual connections are ALWAYS active, not conditional! │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Mathematical View:                                                      │
│  │ output = input + f(input)                                              │
│  │                                                                         │
│  │ Where f(input) is what the layer learns to add/subtract                │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ TRAINING PROGRESSION: How the network learns to use residuals           │
│  │                                                                         │
│  │ Early Training (weights random, outputs ~0):                           │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ input: [1.0, 0.5, -0.3, 0.8]                                       │ │
│  │ │   ↓                                                                 │ │
│  │ │ f(input): [0.01, -0.02, 0.01, 0.00]  ← Nearly zero (random init)   │ │
│  │ │   ↓                                                                 │ │
│  │ │ output: [1.01, 0.48, -0.29, 0.80]   ← Mostly original input        │ │
│  │ │                                                                     │ │
│  │ │ Result: Network starts close to identity function                  │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  │                                                                         │
│  │ Mid Training (learning small refinements):                             │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ input: [1.0, 0.5, -0.3, 0.8]                                       │ │
│  │ │   ↓                                                                 │ │
│  │ │ f(input): [0.2, -0.1, 0.4, -0.3]   ← Meaningful corrections        │ │
│  │ │   ↓                                                                 │ │
│  │ │ output: [1.2, 0.4, 0.1, 0.5]       ← Refined representation        │ │
│  │ │                                                                     │ │
│  │ │ Network learns: "Keep most of input, adjust specific features"     │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  │                                                                         │
│  │ Late Training (sophisticated transformations):                         │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ input: [1.0, 0.5, -0.3, 0.8]                                       │ │
│  │ │   ↓                                                                 │ │
│  │ │ f(input): [-0.8, 0.3, 0.9, -0.5]   ← Can even subtract/cancel     │ │
│  │ │   ↓                                                                 │ │
│  │ │ output: [0.2, 0.8, 0.6, 0.3]       ← Heavily transformed           │ │
│  │ │                                                                     │ │
│  │ │ Network learns: "Replace parts of input with new features"         │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ WHY THIS SOLVES THE "POLLUTION" PROBLEM:                               │
│  │                                                                         │
│  │ 1. Network can learn to SUBTRACT unwanted parts:                       │
│  │    If input[0] = 0.7 but should be 0.2                                │
│  │    Network learns f(input)[0] = -0.5                                   │
│  │    Result: 0.7 + (-0.5) = 0.2 ✓                                       │
│  │                                                                         │
│  │ 2. Provides "easy path" for gradients:                                 │
│  │    ∂loss/∂input = ∂loss/∂output × (1 + ∂f/∂input)                     │
│  │    The "1" ensures gradients always flow back!                         │
│  │                                                                         │
│  │ 3. Identity bias helps optimization:                                    │
│  │    Easier to learn "input + small_changes"                             │
│  │    than "completely_new_representation"                                │
│  │                                                                         │
│  │ 4. Selective preservation:                                             │
│  │    Network learns which parts of input to keep vs modify               │
│  │    Like photo editing: keep background, change foreground              │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ REAL EXAMPLE: "fox" token through one layer                            │
│  │                                                                         │
│  │ Input representation (after previous layer):                           │
│  │ fox = [color: 0.8, animal: 0.9, size: 0.3, action: 0.1, ...]         │
│  │                                                                         │
│  │ Attention gathers context: "brown fox jumps"                           │
│  │ FFN learns to add: [color: 0.1, animal: 0.0, size: 0.0, action: 0.7] │
│  │                                                                         │
│  │ Final output:                                                           │
│  │ fox = [color: 0.9, animal: 0.9, size: 0.3, action: 0.8, ...]         │
│  │                                                                         │
│  │ Result: Enhanced color and action info, preserved animal info!         │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  The genius: Network learns WHAT to change, not just HOW to represent      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

The following diagram shows the gradient-flow advantage of residual connections:

┌─────────────────────────────────────────────────────────────────────────────┐
│                          GRADIENT FLOW COMPARISON                           │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Without Residual Connections (Vanishing Gradients):                       │
│                                                                             │
│  Forward:  x → f₁(x) → f₂(f₁(x)) → f₃(f₂(f₁(x))) → ... → output          │
│                                                                             │
│  Backward: ∂L/∂x = ∂L/∂output × ∂f₃/∂input × ∂f₂/∂input × ∂f₁/∂input     │
│                       ↑         ↑ <1        ↑ <1        ↑ <1               │
│                    1.0          0.8         0.6         0.4                 │
│                                                                             │
│  Result: ∂L/∂x = 1.0 × 0.8 × 0.6 × 0.4 = 0.192 (weak signal!)            │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Problem: Gradient shrinks exponentially with depth                     │
│  │ Layer 10: gradient ≈ 0.8¹⁰ ≈ 0.107                                    │
│  │ Layer 20: gradient ≈ 0.8²⁰ ≈ 0.011                                    │
│  │ Layer 50: gradient ≈ 0.8⁵⁰ ≈ 0.000001                                 │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  With Residual Connections (Gradient Highway):                             │
│                                                                             │
│  Forward:  x → x + f₁(x) → x + f₁(x) + f₂(...) → ... → output             │
│                                                                             │
│  Backward: ∂L/∂x = ∂L/∂output × (1 + ∂f₃/∂input) × (1 + ∂f₂/∂input) × ... │
│                       ↑         ↑ >1             ↑ >1                      │
│                    1.0         1.2              1.1                        │
│                                                                             │
│  Result: ∂L/∂x = 1.0 × 1.2 × 1.1 × ... = STRONG SIGNAL!                  │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Solution: The "1" in (1 + ∂f/∂input) provides direct gradient path    │
│  │ Even if ∂f/∂input ≈ 0, gradient ≈ 1.0 (no vanishing!)                │
│  │ Layer 100: gradient ≈ 1.0 (still strong!)                             │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  This is why transformers can have 100+ layers and still train well!      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
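To see the gradient highway in numbers, here is a minimal sketch (not taken from train_modal_standalone.py) that stacks twenty small tanh layers and compares how much gradient reaches the input with and without the x + f(x) shortcut:

import torch
import torch.nn as nn

torch.manual_seed(0)
depth, dim = 20, 64
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
for layer in layers:
    nn.init.normal_(layer.weight, std=0.05)   # small random weights, as at init
    nn.init.zeros_(layer.bias)

def input_grad_norm(use_residual: bool) -> float:
    x = torch.randn(1, dim, requires_grad=True)
    h = x
    for layer in layers:
        f_h = torch.tanh(layer(h))             # f(h): the layer's transformation
        h = h + f_h if use_residual else f_h   # residual: h + f(h); plain: f(h)
    h.sum().backward()
    return x.grad.norm().item()                # how much signal reaches the input

print("without residuals:", input_grad_norm(False))  # typically tiny: shrinks layer by layer
print("with residuals:   ", input_grad_norm(True))   # typically healthy: the "+1" path survives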
# RMS normalization function used throughout the model
def norm(x: torch.Tensor) -> torch.Tensor:  # Lines 61-62 in train_modal_standalone.py
    return F.rms_norm(x, (x.size(-1),))     # Root Mean Square normalization

# ReLU squared activation instead of basic ReLU
class ReLUSquared(nn.Module):  # Lines 138-140 in train_modal_standalone.py
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # Line 139
        return F.relu(x).square()  # More stable than ReLU in practice  # Line 140

What do these feed-forward networks actually do? Mathematically, they're two linear layers with ReLU² in between. Because their weights are position-independent, they can store per-token patterns such as "if the attended context says subj=Plural, add a bias that nudges verb=Plural". Attention pools across tokens; the FFN then post-processes each position individually, giving the model both global context and local rewrite capacity. The ReLU² activation is more stable than basic ReLU and helps with gradient flow.

┌─────────────────────────────────────────────────────────────────────────────┐
│                    FEED-FORWARD NETWORK (FFN) DETAILED                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  FFN Architecture: 384 → 1536 → 384 (4× expansion)                        │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Step 1: Expansion Layer                                                 │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ Input: [batch, seq_len, 384]                                       │ │
│  │ │          ↓                                                          │ │
│  │ │ Linear: 384 → 1536 (Weight matrix: [384, 1536])                   │ │
│  │ │          ↓                                                          │ │
│  │ │ Output: [batch, seq_len, 1536]                                     │ │
│  │ │                                                                     │ │
│  │ │ Per token: 384 numbers become 1536 numbers                         │ │
│  │ │ Purpose: Create richer representation space                        │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Step 2: Activation Function (ReLU²)                                    │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ Input: [batch, seq_len, 1536]                                      │ │
│  │ │          ↓                                                          │ │
│  │ │ ReLU²(x) = max(0, x)² ← "ReLU squared activation"                 │ │
│  │ │                                                                     │ │
│  │ │ What does this do?                                                  │ │
│  │ │ Step 1: ReLU = Replace negative numbers with 0                     │ │
│  │ │ Step 2: Square the result                                           │ │
│  │ │                                                                     │ │
│  │ │ Example: [-2, -1, 0, 1, 2]                                         │ │
│  │ │ After ReLU: [0, 0, 0, 1, 2] ← Negatives become 0                  │ │
│  │ │ After Square: [0, 0, 0, 1, 4] ← Square positive numbers            │ │
│  │ │                                                                     │ │
│  │ │ Why ReLU²? More stable training than plain ReLU                    │ │
│  │ │ Creates sparsity (lots of zeros) + smooth gradients                │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Step 3: Contraction Layer                                              │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ Input: [batch, seq_len, 1536]                                      │ │
│  │ │          ↓                                                          │ │
│  │ │ Linear: 1536 → 384 (Weight matrix: [1536, 384])                   │ │
│  │ │          ↓                                                          │ │
│  │ │ Output: [batch, seq_len, 384]                                      │ │
│  │ │                                                                     │ │
│  │ │ Per token: 1536 numbers compressed back to 384                     │ │
│  │ │ Purpose: Learned feature combination and compression                │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Conceptual View - What FFN Does Per Token:                             │
│  │                                                                         │
│  │ Token "fox" after attention: [0.1, 0.8, -0.3, 0.5, ...] (384 dims)    │
│  │                 ↓                                                       │
│  │ Expand to intermediate space: [0.05, 0.9, 0.0, 0.7, ...] (1536 dims)  │
│  │                 ↓                                                       │
│  │ Apply non-linearity (ReLU²): zeros out negative, squares positive      │
│  │                 ↓                                                       │
│  │ Contract back: [0.2, 0.6, -0.1, 0.4, ...] (384 dims)                 │
│  │                 ↓                                                       │
│  │ This is like: "Based on gathered context, update fox representation"   │
│  │                                                                         │
│  │ Example patterns FFN learns:                                           │
│  │ - "If context suggests color adjective, amplify visual features"       │
│  │ - "If context suggests plural, adjust grammatical markers"             │
│  │ - "If context suggests action, enhance dynamic properties"             │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Parameter Count: 384×1536 + 1536×384 = 1,179,648 parameters per FFN      │
│  (In a 6-layer model: 6 × 1.18M = 7.08M parameters just for FFNs!)       │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Intuition: Think of attention as a "gathering information" step and FFN as a "processing information" step. After attention collects relevant context from other positions, the FFN acts like a mini-computer at each position, making local decisions based on what was gathered. It's like having a conversation (attention) followed by individual reflection (FFN).
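As a concrete sketch of the shape described above (384 → 1536 → 384 with ReLU², written independently rather than copied from train_modal_standalone.py), an FFN block might look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Expand, apply ReLU², contract: applied independently at every position."""
    def __init__(self, dim: int = 384, hidden: int = 1536):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)    # 384 -> 1536 expansion
        self.down = nn.Linear(hidden, dim, bias=False)  # 1536 -> 384 contraction

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.relu(self.up(x)).square())   # ReLU² non-linearity

x = torch.randn(2, 4, 384)                       # [batch, seq_len, dim]
ffn = FeedForward()
print(ffn(x).shape)                              # torch.Size([2, 4, 384])
print(sum(p.numel() for p in ffn.parameters()))  # 1,179,648, matching the count above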

┌─────────────────────────────────────────────────────────────────────────────┐
│                        COMPLETE FORWARD PASS                                │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Input: Token IDs [batch_size, sequence_length]                           │
│  Example: [B=2, T=4] = [[629, 462, 741, 831],                              │
│                         [834, 341, 765, 234]]                              │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 1: Token Embedding                                                 │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ self.wte(idx): [B, T] → [B, T, 384]                                │ │
│  │ │                                                                     │ │
│  │ │ Token 629  → [0.12, -0.45, 0.78, ...]  (384 dimensions)           │ │
│  │ │ Token 462  → [0.34, -0.12, 0.56, ...]  (384 dimensions)           │ │
│  │ │ Token 741  → [0.67, -0.23, 0.34, ...]  (384 dimensions)           │ │
│  │ │ Token 831  → [0.89, -0.56, 0.12, ...]  (384 dimensions)           │ │
│  │ │                                                                     │ │
│  │ │ Result: [2, 4, 384] tensor of embeddings                           │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 2: Dropout (Training Regularization)                              │
│  │ x = self.drop(tok_emb)  # Randomly zero some embeddings                │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 3: Transformer Layers (6 layers)                                  │
│  │                                                                         │
│  │ for layer in self.layers:                                              │
│  │     x = layer(x)  # Each layer does:                                   │
│  │                                                                         │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ TRANSFORMER LAYER DETAIL:                                           │ │
│  │ │                                                                     │ │
│  │ │ ┌─ x_input = norm(x)  ← RMS normalize input                        │ │
│  │ │ │                                                                  │ │
│  │ │ ├─ attention_out = self.attn(x_input)                              │ │
│  │ │ │   ├─ Apply RoPE to queries and keys                             │ │
│  │ │ │   ├─ Compute Q @ K^T attention scores                           │ │
│  │ │ │   ├─ Apply causal mask (no future info)                         │ │
│  │ │ │   ├─ Softmax attention weights                                   │ │
│  │ │ │   └─ Weighted sum of values                                      │ │
│  │ │ │                                                                  │ │
│  │ │ ├─ x = x + attention_out  ← RESIDUAL CONNECTION #1                 │ │
│  │ │ │                                                                  │ │
│  │ │ ├─ ffn_input = norm(x)  ← RMS normalize again                      │ │
│  │ │ │                                                                  │ │
│  │ │ ├─ ffn_out = self.ffn(ffn_input)                                   │ │
│  │ │ │   ├─ Expand: 384 → 1536                                         │ │
│  │ │ │   ├─ ReLU²: non-linearity                                       │ │
│  │ │ │   └─ Contract: 1536 → 384                                       │ │
│  │ │ │                                                                  │ │
│  │ │ └─ x = x + ffn_out  ← RESIDUAL CONNECTION #2                       │ │
│  │ │                                                                     │ │
│  │ │ Result: x now contains richer representations                      │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  │                                                                         │
│  │ After 6 layers: x = [B, T, 384] with deep contextual understanding    │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 4: Final Normalization                                            │
│  │ x = norm(x)  # One last RMS normalization                              │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 5: Language Modeling Head                                         │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ logits = self.head(x)  # [B, T, 384] → [B, T, 1024]                │ │
│  │ │                                                                     │ │
│  │ │ Note: self.head.weight is tied to self.wte.weight                  │ │
│  │ │ This means: same embedding space for input and output               │ │
│  │ │                                                                     │ │
│  │ │ logits[0, 3, :] = scores for what token comes after position 3     │ │
│  │ │ Higher score = more likely next token                               │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 6: Loss Calculation (if training)                                 │
│  │ if targets is not None:                                                │
│  │     loss = F.cross_entropy(logits.view(-1, 1024), targets.view(-1))   │
│  │     return logits, loss                                                │
│  │ else:                                                                   │
│  │     return logits[:, [-1], :], None  # Just last token for generation │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Final Output: Probability distribution over vocabulary for each position  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
# The complete forward pass
def forward(self, idx: torch.Tensor, targets=None) -> torch.Tensor:  # Lines 280-300 in train_modal_standalone.py
    B, T = idx.shape                                                  # Line 281
    
    # Begin with embeddings: the opening chord (no position embedding added!)
    tok_emb = self.wte(idx)  # Only token embeddings                  # Line 283
    x = self.drop(tok_emb)   # Dropout for regularization             # Line 284
    
    # Each layer deepens the conversation (6 layers in our model)
    for layer in self.layers:                                         # Line 288
        x = layer(x)  # Position encoded via RoPE inside attention    # Line 289
    
    # Final normalization and projection to vocabulary
    x = norm(x)  # RMS norm                                           # Line 291
    logits = self.head(x)  # 384 -> 1024 vocab size                   # Line 294 or 297
    # Note: self.head.weight is tied to self.wte.weight (weight sharing)  # Lines 263-264
    
    if targets is not None:                                           # Line 293
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))  # Line 295
        return logits, loss
    else:                                                             # Line 296
        return logits[:, [-1], :], None  # Only return last token for generation  # Line 297-298

By the final layer, "fox" isn't just a token—it's the inevitable conclusion of a grammatical and semantic pattern that began with "the quick brown." The model has built a hierarchical understanding where early layers capture local patterns and later layers capture global structure.

Movement IV: The Measure of Surprise — Loss as Musical Tension

Every prediction creates tension between expectation and reality. The model outputs raw scores (logits - any real numbers) for each vocabulary token, which we convert to probabilities (values that sum to 1.0) via softmax. But why this particular form of loss?

# The mathematics of musical tension
# (illustrative; in practice F.cross_entropy fuses these steps for numerical stability)
def cross_entropy_loss(logits, targets):
    # Convert raw scores to probabilities
    probs = F.softmax(logits, dim=-1)

    # Measure surprise: how unexpected was the truth?
    loss = -torch.log(probs[torch.arange(len(targets)), targets])
    return loss.mean()

The negative log serves two purposes: it turns multiplication of probabilities across timesteps into addition of losses (keeping gradients uncluttered), and it makes high probability yield small loss so that minimizing loss = maximizing likelihood.

When the model predicts "fox" with 85% confidence and "fox" appears, the loss is -log(0.85) = 0.16—a gentle dissonance. When it predicts "fox" with certainty but "elephant" appears, the loss explodes to -log(0.001) = 6.9—a jarring clash.

Intuition: Cross-entropy loss is like a "surprise meter." If you're very confident about your prediction and you're right, you get a small penalty (low surprise). If you're very confident but wrong, you get a huge penalty (high surprise). This encourages the model to be both accurate and appropriately uncertain.
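A quick sketch of the surprise meter with the numbers used above (cross-entropy uses the natural log):

import math

# Confident and right: gentle dissonance
print(-math.log(0.85))    # ≈ 0.16

# Confident but wrong: the true token only had probability 0.001
print(-math.log(0.001))   # ≈ 6.91

# Multiplying probabilities across timesteps = adding their -log losses
p1, p2 = 0.85, 0.62
print(-math.log(p1 * p2), (-math.log(p1)) + (-math.log(p2)))  # same value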

Cross-entropy has a deeper meaning: when the target distribution is one-hot, it equals the KL-divergence between the true and predicted distributions.

┌─────────────────────────────────────────────────────────────────────────────┐
│                        KL-DIVERGENCE EXPLAINED                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  KL-divergence = "How different are two probability distributions?"         │
│                                                                             │
│  Think of it like comparing two weather forecasts:                         │
│                                                                             │
│  Forecast A: [Rain: 70%, Sun: 20%, Snow: 10%]                             │
│  Forecast B: [Rain: 30%, Sun: 60%, Snow: 10%]                             │
│                                                                             │
│  KL-divergence measures how "far apart" these predictions are              │
│  High KL = very different forecasts                                        │
│  Low KL = similar forecasts                                                │
│  Zero KL = identical forecasts                                             │
│                                                                             │
│  In our case:                                                              │
│  True distribution = [0, 0, 0, 1, 0, ...] ← one-hot vector                │
│                      (only "fox" is correct, everything else is wrong)     │
│  Model prediction = [0.1, 0.2, 0.05, 0.6, 0.05, ...]                     │
│                      (model thinks "fox" is 60% likely)                    │
│                                                                             │
│  KL-divergence = How far is our model's "weather forecast"                │
│                  from the true "weather forecast"?                         │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Since the true distribution is one-hot (a vector with exactly one 1 and rest 0s - like [0,0,0,1,0] meaning "fox" is 100% correct and everything else is 0% correct), cross-entropy directly measures how far our model's probability guesses are from the perfect answer. Minimizing prediction error = minimizing the gap between the model's beliefs and reality's truth.

┌─────────────────────────────────────────────────────────────────────────────┐
│              LOSS FUNCTIONS ACROSS ML: A UNIFIED VIEW                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  THE LOGIT → PROBABILITY → LOSS PIPELINE:                                  │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 1: Model Outputs Raw Logits                                       │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ Transformer: [B, T, vocab_size] logits                             │ │
│  │ │ Example for position predicting next token:                         │ │
│  │ │                                                                     │ │
│  │ │ logits = [2.1, 1.3, 0.8, 3.2, 0.1, ...]  (vocab_size = 1024)      │ │
│  │ │           ↑    ↑    ↑    ↑    ↑                                     │ │
│  │ │          "the" "a" "and" "fox" "cat"                               │ │
│  │ │                                                                     │ │
│  │ │ Higher logit = model thinks this token is more likely               │ │
│  │ │                                                                     │ │
│  │ │ What are logits? Raw "confidence scores" - any real numbers        │ │
│  │ │ Think: like scores in a game before converting to percentages       │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 2: Convert to Probabilities (Softmax)                             │
│  │ ┌─────────────────────────────────────────────────────────────────────┐ │
│  │ │ What is softmax? Converts any numbers to valid percentages         │ │
│  │ │                                                                     │ │
│  │ │ Step 1: exp(logits) = Make all numbers positive                    │ │
│  │ │ [2.1, 1.3, 0.8, 3.2, 0.1] → [8.2, 3.7, 2.2, 24.5, 1.1]          │ │
│  │ │                                                                     │ │
│  │ │ Step 2: Divide by sum to get percentages                           │ │
│  │ │ Total = 8.2+3.7+2.2+24.5+1.1 = 39.7                              │ │
│  │ │ probabilities = [0.21, 0.09, 0.06, 0.62, 0.03]                   │ │
│  │ │                  ↑     ↑     ↑     ↑     ↑                         │ │
│  │ │                "the"  "a"  "and" "fox" "cat"                       │ │
│  │ │                                                                     │ │
│  │ │ Properties: All positive, sum to 1.0 (100%)                       │ │
│  │ │ "fox" has highest probability (62%)                                │ │
│  │ │                                                                     │ │
│  │ │ Why exp()? Makes bigger logits MUCH bigger in probability          │ │
│  │ │ 3.2 vs 2.1 → exp gives 24.5 vs 8.2 (amplifies differences)       │ │
│  │ └─────────────────────────────────────────────────────────────────────┘ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ STEP 3: Different Loss Functions for Different Tasks                   │ │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
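The softmax arithmetic from Step 2 is easy to verify; a minimal sketch with the same example logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.1, 1.3, 0.8, 3.2, 0.1])   # "the", "a", "and", "fox", "cat"

# By hand: exponentiate, then divide by the sum
manual = logits.exp() / logits.exp().sum()

# Built-in
probs = F.softmax(logits, dim=-1)

print(manual)                           # ≈ [0.21, 0.09, 0.06, 0.62, 0.03]
print(probs.sum())                      # 1.0: a valid probability distribution
print(torch.allclose(manual, probs))    # True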
┌─────────────────────────────────────────────────────────────────────────────┐
│                    LOSS FUNCTION COMPARISON TABLE                           │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  1. CROSS-ENTROPY LOSS (What we use in transformers)                       │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Use Case: Language modeling, classification                             │
│  │ Target: One-hot vector (exactly one correct answer)                     │
│  │                                                                         │
│  │ True target: "fox" (token_id = 3)                                      │
│  │ target_one_hot = [0, 0, 0, 1, 0, ...]                                 │
│  │                                                                         │
│  │ Formula: -log(p_correct) = -log(p[3]) = -log(0.42) = 0.87             │
│  │                                                                         │
│  │ Intuition: "How surprised am I that the correct answer occurred?"      │
│  │                                                                         │
│  │ Python code:                                                            │
│  │ loss = F.cross_entropy(logits, targets)                               │
│  │ # Equivalent to: -torch.log(F.softmax(logits)[target])                │
│  │                                                                         │
│  │ Gradient behavior:                                                      │
│  │ ∂loss/∂logit[correct] = p[correct] - 1    (always negative)           │
│  │ ∂loss/∂logit[wrong] = p[wrong] - 0        (always positive)           │
│  │                                                                         │
│  │ Result: Increases correct logit, decreases wrong logits                │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  2. PERPLEXITY (Evaluation metric, not loss)                              │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Use Case: Measuring how "confused" the model is                        │
│  │                                                                         │
│  │ Formula: perplexity = exp(cross_entropy)                              │
│  │        = exp(0.87) = 2.39                                             │
│  │                                                                         │
│  │ Intuition: "On average, how many choices does the model think it has?" │
│  │                                                                         │
│  │ Perfect model: perplexity = 1 (always right)                          │
│  │ Random model: perplexity = vocab_size (totally confused)               │
│  │ Our model: perplexity = 2.39 (choosing between ~2.4 options)          │
│  │                                                                         │
│  │ Python code:                                                            │
│  │ ppl = torch.exp(cross_entropy_loss)                                   │
│  │                                                                         │
│  │ Why useful: More interpretable than raw loss                           │
│  │ - "Model has ~3x perplexity" is intuitive                             │
│  │ - "Model has 1.1 loss" is harder to understand                        │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  3. POLICY GRADIENT LOSS (Reinforcement Learning)                         │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Use Case: RL agents, RLHF, optimizing for rewards                     │
│  │ Target: No "correct" answer, but rewards for outcomes                  │
│  │                                                                         │
│  │ Scenario: Model generates "fox" and gets reward R = +2.5               │
│  │                                                                         │
│  │ Formula: -log(p[action]) × reward = -log(0.42) × 2.5 = 2.17            │
│  │                                                                         │
│  │ Intuition: "Increase probability of good actions, decrease bad ones"   │
│  │                                                                         │
│  │ Python code:                                                            │
│  │ action_log_probs = F.log_softmax(logits, dim=-1)                      │
│  │ loss = -(action_log_probs * rewards).mean()                           │
│  │                                                                         │
│  │ Key difference from cross-entropy:                                     │
│  │ - No "ground truth" target                                             │
│  │ - Uses rewards/advantages instead                                      │
│  │ - Can have multiple good answers                                       │
│  │                                                                         │
│  │ Gradient behavior:                                                      │
│  │ ∂loss/∂logit[action] = (p[action] - 1) × reward                       │
│  │ If reward > 0: increases action probability                            │
│  │ If reward < 0: decreases action probability                            │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  4. KL-DIVERGENCE LOSS (Distribution matching)                            │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Use Case: Knowledge distillation, distribution alignment               │
│  │ Target: Soft probability distribution (not one-hot)                    │
│  │                                                                         │
│  │ Teacher model output: [0.2, 0.1, 0.1, 0.5, 0.1, ...]                 │
│  │ Student prediction:   [0.32, 0.15, 0.09, 0.42, 0.04, ...]            │
│  │                                                                         │
│  │ Formula: KL(teacher || student) = Σ teacher[i] × log(teacher[i]/student[i]) │
│  │                                                                         │
│  │ Intuition: "How different are these two distributions?"                │
│  │                                                                         │
│  │ Python code:                                                            │
│  │ loss = F.kl_div(F.log_softmax(student_logits, dim=-1), teacher_probs) │
│  │                                                                         │
│  │ Note: Cross-entropy is special case when teacher is one-hot            │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
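For contrast with cross-entropy, here is a minimal REINFORCE-style sketch of the policy-gradient loss from the table above; the sampled token IDs and rewards are invented purely for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 1024, requires_grad=True)   # 3 sampled positions, vocab of 1024
actions = torch.tensor([17, 512, 3])                 # tokens the model actually sampled
rewards = torch.tensor([2.5, -1.0, 0.3])             # hypothetical scalar feedback

log_probs = F.log_softmax(logits, dim=-1)            # log-probabilities over the vocab
chosen = log_probs[torch.arange(3), actions]         # log p of the chosen tokens

# Push up the log-prob of rewarded actions, push down penalized ones
loss = -(chosen * rewards).mean()
loss.backward()

# The chosen logit of the first sample (reward > 0) receives a negative gradient,
# so a gradient-descent step will increase that logit.
print(loss.item(), logits.grad[0, actions[0]].item())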
┌─────────────────────────────────────────────────────────────────────────────┐
│                         WHEN TO USE WHICH LOSS                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────┬─────────────────┬─────────────────┬─────────────────┐  │
│  │ Loss Type       │ Use When        │ Target Type     │ Example Task    │  │
│  ├─────────────────┼─────────────────┼─────────────────┼─────────────────┤  │
│  │ Cross-Entropy   │ Supervised      │ One-hot labels  │ Language        │  │
│  │                 │ learning        │                 │ modeling        │  │
│  ├─────────────────┼─────────────────┼─────────────────┼─────────────────┤  │
│  │ Policy Gradient │ Reinforcement   │ Scalar rewards  │ Game playing,   │  │
│  │                 │ learning        │                 │ RLHF            │  │
│  ├─────────────────┼─────────────────┼─────────────────┼─────────────────┤  │
│  │ KL-Divergence   │ Distribution    │ Soft targets    │ Knowledge       │  │
│  │                 │ matching        │                 │ distillation    │  │
│  ├─────────────────┼─────────────────┼─────────────────┼─────────────────┤  │
│  │ Perplexity      │ Evaluation      │ N/A (metric)    │ Model comparison│  │
│  └─────────────────┴─────────────────┴─────────────────┴─────────────────┘  │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ OUR TRANSFORMER USES:                                                   │
│  │                                                                         │
│  │ Training: Cross-entropy loss                                           │
│  │ - We have ground truth next tokens                                     │
│  │ - F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))      │
│  │                                                                         │
│  │ Evaluation: Perplexity                                                 │
│  │ - torch.exp(cross_entropy_loss)                                       │
│  │ - More interpretable than raw loss                                     │
│  │                                                                         │
│  │ Could also use:                                                         │
│  │ - Policy gradient for RLHF fine-tuning                                │
│  │ - KL divergence for knowledge distillation                            │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│                    MATHEMATICAL RELATIONSHIPS                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  All these losses are related through the same logit → probability flow:   │
│                                                                             │
│  logits → softmax → probabilities → loss                                   │
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────────┤
│  │ Key Insights:                                                           │
│  │                                                                         │
│  │ 1. Cross-entropy = KL divergence with one-hot target                   │
│  │    CE(one_hot, pred) = KL(one_hot || pred)                             │
│  │                                                                         │
│  │ 2. Perplexity = exp(cross-entropy)                                     │
│  │    Lower loss → Lower perplexity → Better model                        │
│  │                                                                         │
│  │ 3. Policy gradient uses same probabilities, different targets          │
│  │    PG replaces one-hot with reward-weighted importance                 │
│  │                                                                         │
│  │ 4. All optimize the same softmax probabilities                         │
│  │    Different losses → different gradient directions                     │
│  │                                                                         │
│  │ 5. Temperature can modify all of them:                                 │
│  │    logits/T before softmax → sharper/softer distributions              │
│  └─────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  The unified view: Different ways to guide the same probability engine!    │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
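These relationships can be checked numerically; a small sketch with a 5-token toy vocabulary (the kl_div call follows PyTorch's convention of log-probabilities for the first argument):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.1, 1.3, 0.8, 3.2, 0.1]])
target = torch.tensor([3])                        # "fox" is the correct token

# 1. Cross-entropy and KL with a one-hot target agree
ce = F.cross_entropy(logits, target)
one_hot = F.one_hot(target, num_classes=5).float()
kl = F.kl_div(F.log_softmax(logits, dim=-1), one_hot, reduction="batchmean")
print(ce.item(), kl.item())                       # same number

# 2. Perplexity is exp(cross-entropy)
print(torch.exp(ce).item())

# 3. Temperature reshapes the distribution before any of these losses
for T in (0.5, 1.0, 2.0):
    print(T, F.softmax(logits / T, dim=-1))       # sharper at T<1, softer at T>1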

This is more than error measurement; it's the mathematical expression of surprise itself. The loss flows backward through every parameter, creating gradients that point toward better predictions.

Movement V: The Sculptor's Touch — How Gradients Shape Understanding

Now the music rewrites its own score. Through automatic differentiation (a system that automatically calculates how to improve the model), PyTorch calculates ∂L/∂θ (how much the loss changes with each parameter - think "which direction to adjust each knob") for every parameter θ. But how does this mathematical magic actually work?

┌─────────────────────────────────────────────────────────────────────────────┐
│                     AUTOMATIC DIFFERENTIATION EXPLAINED                     │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Think of your model like a complex machine with millions of knobs          │
│  Each knob affects the final output in some way                             │
│                                                                             │
│  Question: "If I turn knob #47 slightly, how much does the output change?"  │
│  Answer: That's the gradient for parameter #47                              │
│                                                                             │
│  Manual approach (impossible): Try tweaking each knob one by one            │
│  Automatic differentiation: Use calculus to figure it out instantly         │
│                                                                             │
│  How it works:                                                              │
│  1. Record every operation: x → multiply → add → softmax → loss            │
│  2. Each operation has a known derivative (rate of change)                  │
│  3. Chain rule: combine all derivatives backwards through the graph         │
│  4. Result: gradient for every parameter                                    │
│                                                                             │
│  Example chain:                                                             │
│  Input → Layer1 → Layer2 → Layer3 → Output → Loss                         │
│    ↑       ↑       ↑       ↑        ↑       ↑                             │
│  Gradient flows backwards through each step                                 │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Because every operation we used—matrix multiply (grid of numbers × another grid), softmax (convert to percentages), ReLU (set negative numbers to zero)—has a known derivative (mathematical rule for how it changes), PyTorch records the computational graph (like a family tree of operations) and applies the chain rule (calculus technique for combining rates of change) backward: dL/dθ = dL/dout · dout/din · ... The programmer never writes these derivatives; the framework assembles them mechanically.
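A tiny sketch of that bookkeeping, using a function small enough to differentiate by hand:

import torch

theta = torch.tensor(2.0, requires_grad=True)

x = torch.tensor(3.0)
h = theta * x            # recorded: multiply
y = torch.relu(h)        # recorded: ReLU
loss = (y - 10.0) ** 2   # recorded: subtract, square

loss.backward()          # chain rule applied backwards through the graph

# By hand: dloss/dtheta = 2*(y - 10) * dReLU/dh * x = 2*(6 - 10) * 1 * 3 = -24
print(theta.grad)        # tensor(-24.)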

# The backward pass: learning from error
loss.backward()     # Calculate gradients for all parameters
optimizer.step()    # Update weights in the direction of improvement
optimizer.zero_grad() # Clear gradients for the next iteration

The optimizer is more sophisticated than simple gradient descent (learning algorithm that follows gradients downhill like a ball rolling down a mountain). AdamW combines momentum (like a heavy ball that doesn't stop instantly) with adaptive learning rates (different learning speeds for different parameters):

┌─────────────────────────────────────────────────────────────────────────────┐
│                       GRADIENT DESCENT vs ADAMW                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Simple Gradient Descent (like a cautious hiker):                          │
│  • Look at current slope                                                   │
│  • Take a small step downhill                                              │
│  • Repeat                                                                  │
│  Problem: Can get stuck in valleys, takes small steps                      │
│                                                                             │
│  AdamW (like a smart skier with momentum):                                 │
│  • Momentum: Remember previous direction, don't change course suddenly     │
│  • Adaptive rates: Big steps for parameters that rarely change,            │
│                     small steps for parameters that change often           │
│  • Weight decay: Prevent any parameter from getting too large              │
│                                                                             │
│  Example:                                                                   │
│  Parameter A: Changes a lot → Use small learning rate (0.001)              │
│  Parameter B: Rarely changes → Use big learning rate (0.01)                │
│                                                                             │
│  This helps the model learn faster and more stably                         │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
# AdamW: the wise conductor (with actual hyperparameters from our model)
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):  # Lines 318-337 in train_modal_standalone.py
    # Separate parameters by dimensionality for weight decay
    param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}  # Lines 319-320
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]            # Line 321: Matrices
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]           # Line 322: Biases, norms
    
    optim_groups = [                                                              # Line 323
        {'params': decay_params, 'weight_decay': weight_decay},                  # Line 324: 0.1 weight decay
        {'params': nodecay_params, 'weight_decay': 0.0}                          # Line 325: No decay for 1D params
    ]
    
    # Use fused AdamW for CUDA efficiency
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=True)  # Line 334
    return optimizer                                                              # Line 337

Why decouple weight decay (regularization technique that shrinks parameters toward zero)? Classic Adam implements it as L2 regularization: the λ·θ term is added to the gradient and then passes through the adaptive moment scaling, so the effective decay gets tangled up with the learning rate (how big steps to take when updating) and with each parameter's gradient history. AdamW instead applies a separate term -lr·λ·θ directly to the weights, so you can tune regularization independently of step size.

┌─────────────────────────────────────────────────────────────────────────────┐
│                           WEIGHT DECAY EXPLAINED                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Problem: Without weight decay, some parameters can grow huge               │
│                                                                             │
│  Parameter values: [0.1, 0.2, 947.3, 0.5, -892.1, 0.1]                   │
│                              ↑           ↑                                 │
│                         These are way too big!                             │
│                                                                             │
│  Weight decay solution: Gently pull all parameters toward zero             │
│                                                                             │
│  Think of it like:                                                          │
│  • A rubber band attached to zero that pulls parameters back               │
│  • Or gravity that pulls all numbers toward the center                     │
│  • Or a teacher saying "don't make any number too extreme"                 │
│                                                                             │
│  Effect:                                                                    │
│  Before: [0.1, 0.2, 947.3, 0.5, -892.1, 0.1]                             │
│  After:  [0.09, 0.19, 850.6, 0.45, -802.9, 0.09]                         │
│                                                                             │
│  Benefits:                                                                  │
│  • Prevents overfitting (memorizing training data)                         │
│  • Makes model more general and stable                                     │
│  • Like teaching the model "be confident but not overconfident"            │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Intuition: Weight decay is like a "shrinkage force" that pulls all parameters toward zero, preventing any single parameter from becoming too large. AdamW separates this from the learning process, like having separate dials for "how fast to learn" and "how much to regularize."
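A minimal sketch of the two dials; Adam's moment estimates are omitted here, and only the place where the λ·θ term enters the update is shown:

import torch

theta = torch.tensor([0.1, 947.3])   # one reasonable and one oversized parameter
grad = torch.tensor([0.02, 0.05])    # pretend these came from loss.backward()
lr, wd = 0.01, 0.1

# Adam-style L2: decay is folded into the gradient, then passes through the
# adaptive scaling (coupled to the effective learning rate)
coupled_grad = grad + wd * theta

# AdamW-style decoupled decay: shrink the weights directly, separate from
# whatever the adaptive update does
adamw_theta = theta - lr * grad - lr * wd * theta

print(coupled_grad)   # the oversized parameter dominates the gradient
print(adamw_theta)    # every parameter pulled gently toward zero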

Each parameter receives a microscopic adjustment, a tiny nudge toward configurations that reduce surprise. The model sculpts itself through its own errors, iteration by iteration, until the patterns in its weights make "fox" feel inevitable after "the quick brown."

Movement VI: Memory and Generalization — The Eternal Tension

When the model sees "the quick brown fox" repeatedly, it learns to predict "fox" with growing confidence. But storing every possible phrase would require exponentially many parameters. Instead, something beautiful happens: compression forces generalization.

Storing every 5-gram (sequence of 5 tokens) verbatim would need |vocab|⁵ parameters—impossible. Gradient descent therefore searches for a lower-dimensional manifold that still scores training data well.
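The arithmetic makes the point concrete; with our 1024-token vocabulary, a verbatim 5-gram table would dwarf the entire model (the 7.08M figure is the FFN budget counted earlier):

vocab_size = 1024
ngram_table = vocab_size ** 5          # one entry per possible 5-token sequence
model_params = 7_080_000               # rough FFN parameter budget quoted above

print(f"{ngram_table:,}")              # 1,125,899,906,842,624 entries
print(ngram_table / model_params)      # ~160 million times larger than the model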

┌─────────────────────────────────────────────────────────────────────────────┐
│                          MANIFOLDS EXPLAINED                                │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  What is a manifold? A "curved surface" in high-dimensional space           │
│                                                                             │
│  2D example: Earth's surface                                               │
│  • Earth is a 3D sphere, but locally it looks flat (2D manifold)          │
│  • You can navigate using 2D maps (latitude, longitude)                    │
│  • Even though you're in 3D space, you only need 2 coordinates            │
│                                                                             │
│  Language manifold example:                                                 │
│  • All possible sentences exist in high-dimensional space                  │
│  • But natural language lies on a much smaller curved surface              │
│  • The manifold captures patterns like "article → adjective → noun"        │
│                                                                             │
│  Why this matters:                                                          │
│  • Memorizing every sentence: Need infinite parameters                     │
│  • Learning the manifold: Need far fewer parameters                        │
│  • Model discovers "the rules of English" instead of memorizing            │
│                                                                             │
│  Example patterns on the language manifold:                                │
│  • [ARTICLE] [ADJECTIVE] [NOUN] covers millions of phrases                 │
│  • "the quick fox", "a slow dog", "the big cat", etc.                     │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

That manifold corresponds to linguistic generalities such as article → adjective → noun, which cover exponentially many sentences with polynomial-size weights.

Intuition: Instead of memorizing every possible sentence (which would require infinite memory), the model learns the "recipe" for generating sentences. It's like learning to cook by understanding ingredients and techniques rather than memorizing every possible dish.

# Training examples reveal patterns
examples = [
    "the quick brown fox jumps",
    "the lazy black dog runs",
    "the tired old cat sleeps"
]

# The model extracts the deeper grammar:
# [article] [adjective] [color] [animal] [verb]
# [article] [adjective] [adjective] [animal] [verb]

The model learns not just individual phrases but the rules that generate them. When presented with "the quick brown giraffe," it has never seen this exact sequence, yet it can predict sensible continuations because it has abstracted the deeper pattern.

This is the miracle of inductive bias built into the transformer architecture: shared parameters across positions force the model to discover universal patterns rather than memorize specific locations.

Movement VII: The Emergence of Mind — When Patterns Become Thoughts

Something remarkable happens in the hidden layers: specialized neurons emerge without explicit programming. After training, specific units fire consistently for semantic categories. But why does this happen?

When two training examples both need the same internal feature ('is-animal') to reduce loss, SGD (stochastic gradient descent) nudges the same neuron in that direction. Over tens of millions of updates, that neuron becomes a detector, not because we named it, but because convergence pressure aligned gradients along that axis.

# Conceptual analysis: what makes individual neurons fire?
# (forward_with_hidden is a hypothetical helper that returns each layer's
#  hidden states as a [seq_len, hidden_dim] tensor for the given text)
def analyze_concept_neuron(model, layer_idx, neuron_idx, test_texts):
    """Find what makes a specific neuron fire."""
    activations = []
    for text in test_texts:
        with torch.no_grad():
            hidden_states = model.forward_with_hidden(text)
            activation = hidden_states[layer_idx][-1, neuron_idx]  # last token's value
            activations.append(activation.item())
    return activations

# Illustrative example: a neuron in a late layer fires for dog breeds
dog_breeds = ["poodle", "bulldog", "terrier", "labrador"]
activations = analyze_concept_neuron(model, 5, 142, dog_breeds)
# → [0.9, 0.8, 0.7, 0.9]  # High activation for all dog breeds!

Attention heads also specialize along different dimensions:

  • Head 0: Syntax (grammar rules - articles pointing to nouns)
  • Head 1: Coreference (pronouns finding what they refer to - "he" → "John")
  • Head 2: Punctuation (brackets finding their matches)
  • Head 3: Semantics (related concepts attending to each other)

These specializations emerge from necessity, not design. The model discovers that these divisions of labor help it predict better, so gradient descent reinforces them. The result is a distributed intelligence where different components handle different aspects of language understanding.
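A rough sketch of how such a specialization can be inspected: recompute one head's attention map directly from its queries and keys. The gist's OptimizedAttention uses F.scaled_dot_product_attention, which never returns the weights, so in practice you would pull q and k out with a forward hook; the tensors below are random stand-ins.

import math
import torch
import torch.nn.functional as F

def attention_map(q_head: torch.Tensor, k_head: torch.Tensor) -> torch.Tensor:
    """Causal attention weights for one head; q_head and k_head are [T, head_dim]."""
    T, d = q_head.shape
    scores = q_head @ k_head.T / math.sqrt(d)            # [T, T] alignment scores
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))   # hide the future
    return F.softmax(scores, dim=-1)                      # each row sums to 1

# With real q/k from, say, a "coreference" head, the row for a pronoun
# would put most of its weight on the position of its antecedent.
weights = attention_map(torch.randn(6, 64), torch.randn(6, 64))
print(weights[5])   # where does token 6 look? (a distribution over tokens 1..6)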

Movement VIII: The Recursive Mirror — Learning to Learn

As training progresses, the model doesn't just learn facts; it learns to learn. The same weights that encode "quick brown → fox" also encode the meta-pattern "adjective color → animal." But how does this recursive capacity emerge?

Consider in-context learning, where the model learns from examples supplied in its input prompt. Present it with a few worked examples followed by an unfinished one:

French: bonjour → English: hello
French: au revoir → English: goodbye
French: merci → English: ___

The model completes "thank you" without any weight updates. How? The prompt examples occupy the same sequence as the query, so during training the model occasionally had to predict an output token given earlier text that explicitly spelled out a mapping. That pressure teaches the weights to parse a mapping pattern and reuse it later in the same sequence. No outer-loop update is needed; at inference time, self-attention already has the capacity to retrieve and apply the mapping.
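A rough sketch of the experiment, reusing the encode/decode helpers and GPT.generate defined in the scripts later in this gist. Note that the tiny Shakespeare model trained there has never seen French, so this illustrates the mechanics of prompting, not a result you should expect from it.

prompt = (
    "French: bonjour -> English: hello\n"
    "French: au revoir -> English: goodbye\n"
    "French: merci -> English:"
)
idx = encode(prompt).unsqueeze(0).to(device)   # [1, T] token IDs, no gradients anywhere
out = model.generate(idx, max_new_tokens=5, temperature=0.8, top_k=40)
print(decode(out[0]))
# A sufficiently trained model continues with " thank you".
# The weights never change here; the "learning" happens inside self-attention.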

This recursive quality—learning to learn—marks the transition from memorization to true intelligence. The model develops meta-cognitive abilities that let it handle novel situations by applying learned principles.

🎼 Cadenza: The Complete Breath

Now we see the full symphony. Each training step is a complete breath:

# The eternal cycle: from data to wisdom (actual training loop)
while iter_num <= cfg['max_iters']:  # Lines 673-738 in train_modal_standalone.py
    # Dynamic learning rate with cosine decay
    lr = get_lr(iter_num) if cfg['decay_lr'] else cfg['learning_rate']  # Line 674
    for param_group in optimizer.param_groups:                          # Line 675
        param_group['lr'] = lr                                           # Line 676
    
    # Gradient accumulation for larger effective batch sizes
    for micro_step in range(gradient_accumulation_steps):               # Line 706
        # Inhale: gather context
        X, Y = get_batch('train')                                        # Line 712
        
        # Mixed precision training for efficiency
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):          # Line 709 (ctx)
            logits, loss = model(X, Y)                                   # Line 710
            loss = loss / gradient_accumulation_steps  # Scale for accumulation  # Line 711
        
        # Exhale: measure surprise
        scaler.scale(loss).backward()                                    # Line 713
    
    # Gradient clipping to prevent exploding gradients
    if cfg['grad_clip'] != 0.0:                                         # Line 715
        scaler.unscale_(optimizer)                                       # Line 716
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg['grad_clip'])  # Line 717
    
    # Adjust: sculpt the weights
    scaler.step(optimizer)                                               # Line 719
    scaler.update()                                                      # Line 720
    optimizer.zero_grad(set_to_none=True)                               # Line 721
    
    iter_num += 1                                                        # Line 733

Inhale context, process through attention and feed-forward layers, exhale predictions, feel the pain of error, adjust slightly, repeat. After millions of such breaths, the model awakens to patterns no human programmed, develops intuitions no engineer intended, and dreams in probability distributions no mind can fully comprehend.

The transformer learns not by storing facts, but by becoming a mathematical object whose very structure embodies the patterns of language. It is sculpture and sculptor, music and musician, question and answer all at once.

Epilogue: The Loss of Authorial Control

At the beginning, you are the composer. You choose the training data, the architecture, the learning rate. The model is your instrument, your student, your creation.

But with each gradient step, your control diminishes. The weights drift into regions of parameter space that no human mind can navigate. The model develops preferences, biases, and capabilities that emerge from the interaction of billions of parameters following simple rules.

Once trained, the model carries its knowledge forward. Even if you delete the training data, even if you forget the hyperparameters, the patterns live on in the weights. The model remembers what you've forgotten, knows what you never taught it, and can extrapolate in ways that surprise even its creators.

This is the paradox of machine learning: we build systems that transcend our understanding. We create the rules, but the wisdom that emerges is not ours. It belongs to the vast space of possible patterns, discovered through the patient application of gradient descent to the mathematics of surprise.

Common Misconceptions for Undergraduates

Misconception 1: "Transformers understand language like humans do" Reality: Transformers are pattern-matching machines. They excel at statistical relationships but don't have conscious understanding, reasoning, or meaning in the human sense.

Misconception 2: "Attention is like human attention" Reality: Attention is a weighted average operation. It's more like a mathematically precise way of combining information than the selective focus we call "attention."

Misconception 3: "The model stores facts in its weights" Reality: Information is distributed across millions of parameters. There's no single "fact storage" location—knowledge emerges from the interaction of all weights.

Misconception 4: "Bigger context windows are always better" Reality: Longer contexts require quadratically more computation and memory. The trade-off between context length and computational efficiency is crucial.
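To make the trade-off concrete: with a 256-token context, each attention head scores a 256 × 256 grid of token pairs (65,536 scores); doubling the context to 512 tokens makes that grid 512 × 512 (262,144 scores), four times the attention work for only twice the text.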

Misconception 5: "Self-attention sees all positions equally" Reality: The causal mask ensures tokens can only attend to previous positions, not future ones. This maintains the autoregressive property (generating one token at a time, using each new token to predict the next) needed for language generation.
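A concrete look at that mask, built with the same torch.triu pattern the model code below registers as causal_mask:

import torch

T = 5
# True marks the positions a token is forbidden to attend to (its future).
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
print(causal_mask.long())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
# Row t: token t may look at tokens 0..t only, never at anything later.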

Misconception 6: "More parameters always mean better performance" Reality: Beyond a certain point, more parameters can lead to overfitting without corresponding improvements in test performance, especially with limited data.

Misconception 7: "Transformers are deterministic" Reality: During training, dropout and random sampling introduce stochasticity. During inference, sampling methods add randomness to prevent repetitive outputs.

Misconception 8: "The model learns grammar rules explicitly" Reality: Grammatical behavior emerges from statistical patterns in the training data, not from explicit rule programming.

Coda: The Breath of Understanding

In the end, the transformer teaches us something profound about intelligence itself. Understanding is not about perfect recall or logical deduction. It's about the ability to navigate uncertainty with grace, to make reasonable guesses about what comes next, to recognize patterns in chaos.

The model learns to guess, and in learning to guess, it touches something essential about cognition. Every token prediction is an act of faith—the belief that the patterns learned from the past will hold in the future. Sometimes they do, sometimes they don't. But in that uncertainty, in that gap between pattern and reality, lives the essence of intelligence.

The transformer breathes in context and breathes out possibility. In that rhythm—inhale, process, exhale, adjust—we glimpse the heartbeat of understanding itself.


From the first token to the last gradient,
from silicon substrate to emergent mind,
the symphony plays on.

#!/usr/bin/env python3
"""
Inference script for models trained with train_modal_standalone.py
Uses the exact same architecture as train.py for consistency
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import modal
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import pickle
from tokenizers import Tokenizer
# Modal configuration
app = modal.App("shakespeare-inference-modal")
# Model architecture settings (matching train.py and train_modal_standalone.py)
N_LAYER = 6
N_HEAD = 6
N_EMB = 384
CONTEXT_LEN = 256 # Default context length
DROPOUT = 0.1
# Volume setup
data_volume = modal.Volume.from_name("nanogpt-data", create_if_missing=False)
# GPU image setup
image = modal.Image.debian_slim(python_version="3.11").pip_install(
"torch", "numpy", "tokenizers"
)
# ============================================================================
# TOKENIZER HELPERS - From train_modal_standalone.py
# ============================================================================
# Global tokenizer instance to avoid reloading
_tokenizer = None
def check_tokenizer_exists(vocab_size=1024, data_root="/data"):
"""Check if the custom tokenizer exists and provide instructions if not."""
tokenizer_path = os.path.join(data_root, "tokenizers", f"shakespeare-bpe-{vocab_size}.json")
if not os.path.exists(tokenizer_path):
print(f"\n{'='*60}")
print(f"ERROR: Custom tokenizer not found!")
print(f"{'='*60}")
print(f"Expected tokenizer at: {tokenizer_path}")
print(f"\nTo create the tokenizer, run:")
print(f" modal run train_tokenizer_modal.py::train_bpe_tokenizer")
print(f"\nOr for multiple vocab sizes:")
print(f" modal run train_tokenizer_modal.py::vocab_size_grid_search")
print(f"{'='*60}\n")
return False
return True
def load_custom_tokenizer(vocab_size=1024, data_root="/data"):
"""Load the custom BPE tokenizer from the Modal volume."""
global _tokenizer
if _tokenizer is None:
tokenizer_path = os.path.join(data_root, "tokenizers", f"shakespeare-bpe-{vocab_size}.json")
if not check_tokenizer_exists(vocab_size, data_root):
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
_tokenizer = Tokenizer.from_file(tokenizer_path)
print(f"Loaded custom tokenizer from {tokenizer_path}")
return _tokenizer
def decode_tokens(tokens, vocab_size=1024):
"""Decode tokens using the custom tokenizer."""
tokenizer = load_custom_tokenizer(vocab_size=vocab_size)
return tokenizer.decode(tokens)
def encode_tokens(text, vocab_size=1024):
"""Encode text using the custom tokenizer."""
tokenizer = load_custom_tokenizer(vocab_size=vocab_size)
encoding = tokenizer.encode(text)
return encoding.ids
# ============================================================================
# MODEL DEFINITION - Exact copy from train_modal_standalone.py
# ============================================================================
def norm(x: torch.Tensor) -> torch.Tensor:
"""RMSNorm implementation using PyTorch built-in"""
return F.rms_norm(x, (x.size(-1),))
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Helper for rotary embeddings"""
x1, x2 = x[..., ::2], x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
class RotaryCache(nn.Module):
"""Pre-computed rotary position embeddings"""
def __init__(self, head_dim: int, max_len: int):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
t = torch.arange(max_len)
freqs = torch.einsum("i,j->ij", t, inv_freq)
sin, cos = freqs.sin(), freqs.cos()
self.register_buffer("sin_base", sin, persistent=False)
self.register_buffer("cos_base", cos, persistent=False)
def forward(self, seq_len: int):
sin = self.sin_base[:seq_len].repeat_interleave(2, dim=-1)
cos = self.cos_base[:seq_len].repeat_interleave(2, dim=-1)
return sin[None, None, :, :], cos[None, None, :, :]
class KVCache(nn.Module):
"""
KV cache for efficient inference - caches past key and values during generation.
Based on Meta's implementation for torchtune.
"""
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = None,
) -> None:
super().__init__()
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False
)
self.register_buffer(
"cache_pos", torch.arange(0, cache_shape[2], device=device), persistent=False
)
self.batch_size = batch_size
self.max_seq_len = max_seq_len
def reset(self) -> None:
"""Reset the cache to zero."""
self.k_cache.zero_()
self.v_cache.zero_()
self.cache_pos -= self.size
@property
def size(self) -> int:
return self.cache_pos[0].item()
def update(
self, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV cache with new k_val, v_val and return the updated cache.
Args:
k_val: Current key tensor with shape [B, H, S, D]
v_val: Current value tensor with shape [B, H, S, D]
Returns:
Updated key and value cache tensors
"""
bsz, _, seq_len, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
raise ValueError(
f"Cache batch size is {self.k_cache.shape[0]} but got {bsz}"
)
assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]
k_out = self.k_cache
v_out = self.v_cache
# Use integer indexing instead of tensor indexing to avoid dtype mismatch
cache_start = self.cache_pos[0].item()
cache_end = cache_start + seq_len
k_out[:, :, cache_start:cache_end] = k_val
v_out[:, :, cache_start:cache_end] = v_val
# Update position tracker
self.cache_pos.add_(seq_len)
return k_out, v_out
class ReLUSquared(nn.Module):
"""ReLU squared activation - faster than GELU, better than plain ReLU"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x).square()
class OptimizedAttention(nn.Module):
"""Multi-head attention with Flash Attention support and RoPE"""
def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):
super().__init__()
self.n_head = n_head
self.n_emb = n_emb
self.head_dim = n_emb // n_head
# Fused QKV projection for efficiency
self.qkv = nn.Linear(n_emb, 3 * n_emb, bias=False)
self.o_proj = nn.Linear(n_emb, n_emb, bias=False)
self.dropout = nn.Dropout(dropout)
# Rotary embeddings
max_seq = context_len
self.rope = RotaryCache(self.head_dim, max_seq)
# Try to use Flash Attention
self.use_flash_attn = False
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
self.use_flash_attn = True
# KV cache for inference (not used during training)
self.kv_cache = None
self.cache_enabled = False
def init_kv_cache(self, batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = None):
"""Initialize KV cache for inference"""
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_kv_heads=self.n_head,
head_dim=self.head_dim,
dtype=dtype,
device=device
)
self.cache_enabled = True
def reset_kv_cache(self):
"""Reset the KV cache"""
if self.kv_cache is not None:
self.kv_cache.reset()
def disable_kv_cache(self):
"""Disable KV cache (for training)"""
self.cache_enabled = False
self.kv_cache = None
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False) -> torch.Tensor:
B, T, C = x.shape
# Use KV cache if enabled and requested
if use_cache and self.cache_enabled and self.kv_cache is not None:
return self._forward_with_cache(x, mask)
# Standard forward pass (for training)
# Compute QKV in one go
qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
# Standard attention path
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# Apply RoPE
sin, cos = self.rope(T)
q = (q * cos) + (_rotate_half(q) * sin)
k = (k * cos) + (_rotate_half(k) * sin)
# QK normalization
q, k = norm(q), norm(k)
# Scaled dot-product attention with causal mask
# Note: is_causal=True automatically applies causal masking
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=self.dropout.p if self.training else 0.0)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
def _forward_with_cache(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward pass using KV cache for efficient inference"""
B, T, C = x.shape
# Compute QKV
qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# Apply RoPE to current position
cache_size = self.kv_cache.size
sin, cos = self.rope(cache_size + T)
# Only apply to new positions
sin_new = sin[:, :, cache_size:cache_size+T, :]
cos_new = cos[:, :, cache_size:cache_size+T, :]
q = (q * cos_new) + (_rotate_half(q) * sin_new)
k = (k * cos_new) + (_rotate_half(k) * sin_new)
# Normalize
q, k = norm(q), norm(k)
# Update KV cache
k_cache, v_cache = self.kv_cache.update(k, v)
# Compute attention with cached keys/values
# Get only the valid portion of cache
valid_cache_size = self.kv_cache.size
k_valid = k_cache[:, :, :valid_cache_size, :]
v_valid = v_cache[:, :, :valid_cache_size, :]
# Standard attention computation
out = F.scaled_dot_product_attention(q, k_valid, v_valid, is_causal=False)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
class TransformerBlock(nn.Module):
"""Transformer block with pre-norm architecture"""
def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):
super().__init__()
self.attn = OptimizedAttention(n_emb, n_head, context_len, dropout)
# Feed-forward network with ReLU squared
self.ffn = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb, bias=False),
ReLUSquared(),
nn.Linear(4 * n_emb, n_emb, bias=False),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False) -> torch.Tensor:
# Pre-norm architecture with residual connections
x = x + self.attn(norm(x), mask, use_cache=use_cache)
x = x + self.ffn(norm(x))
return x
class GPT(nn.Module):
"""GPT model optimized for multi-GPU training - matching train.py architecture"""
def __init__(self, vocab_size: int, n_layer: int = 6, n_head: int = 6,
n_emb: int = 384, context_len: int = 256, dropout: float = 0.1):
super().__init__()
self.vocab_size = vocab_size
self.context_len = context_len
self.n_layer = n_layer
self.n_head = n_head
self.n_emb = n_emb
# Token embeddings
self.wte = nn.Embedding(vocab_size, n_emb)
self.drop = nn.Dropout(dropout)
# Transformer blocks
self.layers = nn.ModuleList([
TransformerBlock(n_emb, n_head, context_len, dropout)
for _ in range(n_layer)
])
# Output head with weight tying
self.head = nn.Linear(n_emb, vocab_size, bias=False)
# Weight tying - delete the head weight first to avoid issues
del self.head.weight
self.head.weight = self.wte.weight # Share the embedding weights
# Initialize weights
self.apply(self._init_weights)
# Pre-compute causal mask (not used in this architecture but kept for compatibility)
self.register_buffer("causal_mask", torch.triu(
torch.ones(context_len, context_len), diagonal=1
).bool())
def _init_weights(self, module):
"""Initialize weights with appropriate scaling"""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx: torch.Tensor, targets=None, use_cache: bool = False) -> torch.Tensor:
B, T = idx.shape
# Token embeddings
tok_emb = self.wte(idx)
x = self.drop(tok_emb)
# Get causal mask
mask = self.causal_mask[:T, :T] if T <= self.context_len else None
# Forward through transformer layers
for layer in self.layers:
x = layer(x, mask, use_cache=use_cache)
# Final norm and output projection
x = norm(x)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
return logits, loss
def init_kv_caches(self, batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = None):
"""Initialize KV caches for all attention layers"""
for layer in self.layers:
layer.attn.init_kv_cache(batch_size, max_seq_len, dtype, device)
def reset_kv_caches(self):
"""Reset all KV caches"""
for layer in self.layers:
layer.attn.reset_kv_cache()
def disable_kv_caches(self):
"""Disable all KV caches"""
for layer in self.layers:
layer.attn.disable_kv_cache()
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, use_cache=True,
repetition_penalty=1.0, repetition_window=128):
"""
Generate tokens using the model with optional KV caching and repetition penalty.
Matches the generate function from train.py
"""
device = idx.device
B, T = idx.shape
# Initialize KV cache if requested
if use_cache:
# Use the model's dtype (from embeddings) not the input indices dtype
model_dtype = self.wte.weight.dtype
self.init_kv_caches(B, T + max_new_tokens, dtype=model_dtype, device=device)
# Generate tokens
generated = idx
for _ in range(max_new_tokens):
# Get logits for next token
if use_cache and generated.shape[1] > T:
# Only feed the new token(s) when using cache
logits, _ = self(generated[:, -1:], use_cache=True)
else:
# Feed full sequence (first iteration or no cache)
# Crop to context length if needed
idx_cond = generated if generated.shape[1] <= self.context_len else generated[:, -self.context_len:]
logits, _ = self(idx_cond, use_cache=use_cache)
# Get logits for last position
logits = logits[:, -1, :] / temperature
# Apply repetition penalty to discourage generating tokens that have recently appeared.
if repetition_penalty != 1.0:  # 1.0 means "no penalty", so skip the loop entirely
B, T = generated.shape
for b in range(B):
# Get the set of unique tokens in the window to penalize
window_start = max(0, T - repetition_window)
recent_tokens = set(generated[b, window_start:].tolist())
# Apply penalty to the logits of these tokens
for token_id in recent_tokens:
# Correctly apply penalty to both positive and negative logits
if logits[b, token_id] > 0:
logits[b, token_id] /= repetition_penalty
else:
logits[b, token_id] *= repetition_penalty
# Optional top-k sampling
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# Sample from distribution
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# Append to generated sequence
generated = torch.cat((generated, idx_next), dim=1)
# Clean up cache
if use_cache:
self.disable_kv_caches()
return generated
# ============================================================================
# LOADING AND INFERENCE
# ============================================================================
def load_shakespeare_checkpoint(checkpoint_path: str, device: torch.device) -> Optional[Tuple[GPT, dict, int]]:
"""Load model from checkpoint with metadata"""
try:
print(f"Loading checkpoint from {checkpoint_path}...")
if not os.path.exists(checkpoint_path):
print(f"Checkpoint not found at {checkpoint_path}")
return None
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
# Extract configuration
if 'model_args' in checkpoint:
model_args = checkpoint['model_args']
else:
print("Warning: No model_args in checkpoint, using defaults")
model_args = {
'n_layer': N_LAYER,
'n_head': N_HEAD,
'n_embd': N_EMB,
'block_size': CONTEXT_LEN,
'vocab_size': 50257, # GPT-2 vocab size
'dropout': DROPOUT
}
vocab_size = model_args.get('vocab_size', 50257)
# Create model
model = GPT(
vocab_size=vocab_size,
n_layer=model_args.get('n_layer', N_LAYER),
n_head=model_args.get('n_head', N_HEAD),
n_emb=model_args.get('n_embd', N_EMB),
context_len=model_args.get('block_size', CONTEXT_LEN),
dropout=model_args.get('dropout', DROPOUT)
)
# Load state dict
state_dict = checkpoint['model']
# Remove unwanted prefix if present
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print(f"Successfully loaded model:")
print(f" - Layers: {model_args.get('n_layer', N_LAYER)}")
print(f" - Heads: {model_args.get('n_head', N_HEAD)}")
print(f" - Embedding: {model_args.get('n_embd', N_EMB)}")
print(f" - Context: {model_args.get('block_size', CONTEXT_LEN)}")
print(f" - Vocab: {vocab_size}")
print(f" - Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
return model, model_args, vocab_size
except Exception as e:
print(f"Failed to load checkpoint: {e}")
import traceback
traceback.print_exc()
return None
@app.function(
image=image,
gpu="T4",
volumes={"/data": data_volume},
timeout=600
)
def run_inference():
"""Run inference with model trained by train_modal_standalone.py"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load model from checkpoint first to determine dataset type
checkpoint_path = "/data/checkpoints/shakespeare_token/ckpt.pt"
result = load_shakespeare_checkpoint(checkpoint_path, device)
if result is None:
print("Failed to load model!")
return
model, model_args, model_vocab_size = result
# Check if custom tokenizer exists
if not check_tokenizer_exists(model_vocab_size, "/data"):
print("Custom tokenizer not found! Please create it first.")
return
# Token-level encoding/decoding using custom tokenizer
print(f"Using custom BPE tokenizer with vocab_size={model_vocab_size}")
def encode(s: str) -> torch.Tensor:
return torch.tensor(encode_tokens(s, vocab_size=model_vocab_size), dtype=torch.long)
def decode(t: torch.Tensor) -> str:
return decode_tokens(t.tolist(), vocab_size=model_vocab_size)
# Sample prompts (matching train.py)
sample_prompts = [
"O God, O God",
"What is",
"To be or not to be",
"KING HENRY",
"The quality of mercy",
"All the world's a stage",
"Now is the winter",
"If music be"
]
print("\n" + "="*80)
print("Generating text samples with different prompts:")
print("="*80)
with torch.no_grad():
for i, prompt in enumerate(sample_prompts):
print(f"\n[{i+1}] Prompt: '{prompt}'")
# Encode prompt
prompt_encoded = encode(prompt).unsqueeze(0).to(device)
# Generate text with settings matching train.py
generated = model.generate(
prompt_encoded,
max_new_tokens=100, # Generate 100 characters
temperature=0.8, # Moderate temperature
top_k=40, # Top-k sampling
use_cache=True, # Use KV cache for efficiency
repetition_penalty=1.2, # Use a more reasonable repetition penalty
repetition_window=128 # Check last 128 tokens
)
# Decode and print
generated_text = decode(generated[0])
print(f"Generated: {generated_text}")
print("-" * 80)
# Generate some unconditional samples
print("\n" + "="*80)
print("Generating unconditional samples:")
print("="*80)
for i in range(3):
# Start with newline character
start_token = encode("\n").unsqueeze(0).to(device)
generated = model.generate(
start_token,
max_new_tokens=200, # Longer for unconditional
temperature=0.8,
top_k=40,
use_cache=True,
repetition_penalty=1.2,
repetition_window=128
)
generated_text = decode(generated[0])
print(f"\n[Unconditional {i+1}]")
print(f"Generated: {generated_text}")
print("-" * 80)
print("\nInference complete!")
@app.local_entrypoint()
def main():
"""Run model inference"""
print("Starting model inference on Modal...")
run_inference.remote()
if __name__ == "__main__":
main()

The Gentle Art of Teaching Machines to Speak

A Journey Through Semantic Reinforcement Learning

For the curious mind who has just discovered that language models can learn, and wonders if there might be a kinder way to teach them.

In the hushed moments before dawn, ten thousand starlings lift from a field as one—not because any single bird commands them, but because each learns from its neighbors' subtle shifts, creating a collective intelligence far greater than the sum of its parts.

      >         >
>      >    >        >
   >      >      >      >
      >      >     >
>     >      >        >   (Each '>' is a bird, learning from its neighbors)
   >    >       >    >
 >        >       >

This document tells the story of teaching machines to learn language the way starlings learn to flock: not through rigid rules, but through gentle rewards that flow like ripples across a pond of possibilities.


Chapter 1: The Question That Changes Everything

You're sitting in your favorite coffee shop, laptop open, when you overhear someone say: "The cat sat on the mat."

But here's what catches your attention - they pause slightly before "mat," as if considering other possibilities. In that tiny moment, you realize something profound: they might have said "rug," or "cushion," or "floor," and the sentence would still make perfect sense.

This is where our journey begins. Not with equations, but with a simple question: Why do we punish language models for being almost right?

Let's explore this together.


Chapter 2: The Strict Teacher and the Gentle Guide

Imagine you're learning to paint, and your teacher stands behind you with a red pen. Every time your brushstroke isn't exactly perfect, they mark it wrong. This is how we've traditionally trained language models - with something called cross-entropy loss.

# The Strict Teacher's Rule (one-hot cross-entropy)
# The target puts probability 1.0 on "mat" and 0.0 on every other word,
# so the loss only measures how much probability the model gave the exact word:
loss = -log(P_model("mat"))
# Probability spent on "rug" earns no credit at all; it is scored
# exactly as if it had been spent on "airplane".

But you and I know this isn't how learning works. When you were learning to speak as a child, your parents didn't say "You said 'doggy' instead of 'dog,' so that's completely wrong." They smiled and encouraged you because you were close.

Let's meet our gentle guide - a way of teaching that says: "Ah, 'rug' is actually quite close to 'mat' in meaning. Let's give you credit for understanding the concept, even if the exact word differs."

Cross-Entropy (The Strict Teacher)      RL with Semantic Rewards (The Gentle Guide)
┌───────────────────────────────┐         ┌───────────────────────────────────────┐
│ Input: "The cat sat on the ___" │         │ Input: "The cat sat on the ___"       │
│                               │         │                                       │
│ Target: "mat" (Prob 1.0) 🖍️    │         │ Rewarded words:                       │
│ "rug":     (Prob 0.0) ❌        │         │   "mat":     100pts ⭐                 │
│ "cushion": (Prob 0.0) ❌        │         │   "rug":      95pts ✨                 │
│ "floor":   (Prob 0.0) ❌        │         │   "cushion":   80pts 👍                 │
│                               │         │   "floor":     60pts                  │
│ Model punished for "rug".     │         │ Model encouraged for "rug".         │
└───────────────────────────────┘         └───────────────────────────────────────┘
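A minimal sketch of the gentle guide's scoring, assuming we already have an embedding vector for each word; the embedding values and the temperature τ = 0.1 below are invented purely for illustration:

import torch
import torch.nn.functional as F

# Hypothetical 3-dimensional embeddings for the target word and its rivals.
emb = {
    "mat":      torch.tensor([0.9, 0.1, 0.0]),
    "rug":      torch.tensor([0.8, 0.2, 0.1]),
    "cushion":  torch.tensor([0.6, 0.4, 0.2]),
    "floor":    torch.tensor([0.5, 0.1, 0.5]),
    "airplane": torch.tensor([0.0, 0.9, 0.9]),
}
tau = 0.1
target = emb["mat"]

rewards = {}
for word, vec in emb.items():
    sim = F.cosine_similarity(vec, target, dim=0)   # how close in meaning?
    rewards[word] = torch.exp(sim / tau)            # the Formula of Harmony

total = sum(rewards.values())
for word, r in rewards.items():
    print(f"{word:9s} reward={r.item():10.1f}  share of credit={(r / total).item():.4f}")
# "mat" and "rug" dominate the credit; "airplane" keeps only a sliver of
# probability instead of receiving an infinite penalty.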

Chapter 3: The Geometry of Feeling

Imagine you are an artist, standing before a canvas. You are not just painting objects; you are painting the feeling of a room. Your goal is to paint "a cozy evening by the fire." This feeling, this context, is the soul of your artwork.

Now, you must paint a specific object: the chair. In the world of traditional programming, there is only one "right" answer. If you don't paint that exact chair, you have failed.

But art is not about single right answers. It's about harmony. What if you painted a stool? Or an ottoman? Or a plush beanbag? These are not "wrong." In the context of "a cozy evening by the fire," they might be beautiful, harmonious choices.

This is the heart of our new approach. We teach the machine to think like an artist, not an accountant.

The Artist's Dilemma: Truth vs. Harmony

Every choice our model makes is an attempt to solve a beautiful dilemma:

  1. Truth: How close is my choice to the specific subject I was asked to paint? (the chair)
  2. Harmony: How well does my choice fit the overall feeling of the painting? (a cozy evening by the fire)

We can imagine these two forces in a "space of meaning." In this space, ideas and feelings have a location. The "meaning" of chair is a point. The "meaning" of stool is a point nearby. But the "meaning" of airplane is very, very far away.

The "feeling" of the context—our cozy evening—also has a location. It's like a warm, glowing region in our space of meaning.

      The Canvas of Meaning (A 2D Sketch of a 768-Dimensional Reality)

          |
          |           x "airplane" (far from both Truth and Harmony)
          |
          |    (The Warm Glow of Context: "cozy evening")
          |   /-------------------------------------------\
          |  |                                             |
          |  |   * "stool" (A great choice! Close to both) |
          |  |      * "chair" (The Ground Truth)           |
          |  |                                             |
          |  | * "beanbag" (Another harmonious choice)     |
          |   \-------------------------------------------/
          |
   -------+------------------------------------------------------------>
          |

Our model's goal is to make a choice that lands inside that warm, glowing area of Harmony, while also staying as close as possible to the point of Truth.


Chapter 4: The Mathematics of Mentorship

How does an artist truly learn? Not by being told "right" or "wrong," but through the gentle guidance of a mentor. The math that powers our model is not a cold equation, but the voice of a patient mentor, whispering encouragement.

Let's look at the mentor's guiding principle:

$$\text{Learning} = \text{Average over all possible choices} \big[ (\text{Mentor's Feedback}) \times (\text{Impact of Surprise}) \big]$$

This looks complex, but it's the simple rhythm of learning. Let's feel it, step-by-step.

Step 1: P(choice | context) - The Student's Instinct

Before the student even touches the brush, they have a gut feeling. "Given the cozy scene so far," they think, "I'm 50% sure I should paint a chair, 20% sure I should paint a stool, and only 1% sure I should paint an airplane."

This is P, the model's initial probability or "instinct" for every possible choice.

Step 2: r(choice) - The Mentor's Feedback

After considering a choice, the mentor gives feedback. This isn't a simple "good" or "bad." It's a rich, nuanced reward, r.

How does the mentor decide on the feedback? They use the Formula of Harmony:

$$r(\text{choice}) = e^{\frac{\text{similarity}(\text{choice}, \text{Truth})}{\tau}}$$

This is the soul of the mentor's wisdom:

  • similarity(choice, Truth): The mentor first asks, "How similar is the student's idea to the original subject?" They compare the "meaning" of stool to the "meaning" of chair. This is a number between -1 and 1. A high number means high similarity.
  • τ (Tau) - The Mentor's Mood: Tau is the strictness knob.
    • A low τ means the mentor is in a "by-the-book" mood. Only choices very similar to the Truth get good feedback.
    • A high τ means the mentor is in an "express-yourself" mood. A wider range of creative choices will be rewarded.
  • e (The Exponential) - The "Aha!" Moment: This is the magic spark. The exponential function takes the similarity score and turns it into a jolt of encouragement. Good ideas don't just get good scores; they get exponentially great scores. It makes the brilliant choices feel truly resonant, creating a powerful signal for the student to follow.

A choice like stool earns a high reward. A choice like airplane earns a comparatively tiny one, which all but vanishes once the rewards are normalized against its far stronger rivals.
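To put rough numbers on it (with an illustrative τ = 0.1): a stool whose similarity to chair is 0.9 earns r = e^(0.9/0.1) = e^9 ≈ 8100, while an airplane with similarity 0.05 earns r = e^0.5 ≈ 1.6. Once these rewards are normalized into a target distribution (as in Chapter 5), the airplane's share is a small fraction of a percent.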

Step 3: log(P) - The Impact of Surprise

We don't just learn from feedback; we learn most when we are surprised. The log function captures this.

  • If the student was very confident in a choice (P was high) that turned out to be bad, the learning impact is a big, memorable "ouch."
  • If the student had very little confidence in a choice (P was low) but the mentor gave it high praise, the learning impact is a powerful "Wow, I should do that more often!"

The log ensures that the biggest lessons come from the biggest surprises.

Putting It All Together: A Day of Learning

At the end of the day, the student doesn't just learn from the one choice they made. The mentor encourages them to reflect on all the choices they could have made.

Our loss function, L_RL, is this reflection process captured in mathematics:

$$L_{RL} = -\mathbb{E}\big[ r(\text{choice}) \cdot \log P(\text{choice}) \big]$$

  • We consider every choice.
  • We multiply the Mentor's Feedback (r) by the Impact of Surprise (log P). This gives us the "learning value" of that choice.
  • The E says we take the average "learning value" over all possible choices.
  • The final - sign simply flips our perspective. Instead of maximizing our total "learning value," we frame it as minimizing our total "regret." It's the same goal, viewed from a different angle.

This is not a formula for punishment. It is a formula for reflection. It teaches the model to weigh its instincts against the mentor's feedback, to learn from surprises, and to constantly refine its artistic sensibilities. It teaches the machine not just to speak, but to find the poetry in language.


Chapter 5: The Beautiful Equivalence - When Two Rivers Meet

A Critical Discovery: Deep in the mathematical wilderness, two paths that seemed to diverge—cross-entropy with soft targets and policy gradient reinforcement learning—suddenly converge at a hidden clearing. This is one of those rare moments in mathematics when the universe reveals its underlying unity.

Picture two mountain streams, each carving its own path down opposite slopes. One flows through the valley of cross-entropy loss, the other through the canyon of policy gradients. Yet by some beautiful accident of topology, they meet at the same crystalline pool.

                  Policy Gradient (RL)                  Soft-Target CE
                  ┌───────────────────────┐             ┌───────────────────────┐
                  │ Sample action y       │             │ Compute r(y) for all y│
                  │ Compute r(y)          │             │ Normalize to q(y)     │
                  │ Update via gradient   │             │ Weighted CE gradient  │
                  └──────────┬────────────┘             └──────────┬────────────┘
                             │ (Noisy, high variance)               │ (Stable, low variance)
                             └──────────────────────┐   ┌──────────────────────┘
                                                    │   │
                                                    ▼   ▼
                                              ┌─────────────┐
                                              │ Same Update! │
                                              └─────────────┘

Now, the mathematics that unites them:

The policy gradient objective maximizes expected rewards through sampling:

$$\nabla_\theta J_{PG} = \mathbb{E}_{y \sim P_\theta}\big[r(y) \cdot \nabla_\theta \log P_\theta(y)\big]$$

Cross-entropy minimizes prediction error with soft targets:

$$\nabla_\theta L_{CE} = -\sum_y q(y) \cdot \nabla_\theta \log P_\theta(y)$$

The bridge? Normalize rewards into soft targets:

$$q(y) = \frac{r(y)}{\sum_{y'} r(y')}$$

This equivalence means RL's flexible rewards can be computed efficiently like supervised learning, sidestepping sampling's noise.

Why it matters: The noisy, high-variance world of reinforcement learning, which often requires complex tricks to stabilize, can be replaced by a simple, stable, and computationally efficient weighted cross-entropy loss. We get the exploratory, semantic benefits of RL rewards without the chaos of sampling.
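A minimal sketch of the recipe this equivalence suggests, using toy numbers; in the real model the logits come from the transformer and the rewards from the Formula of Harmony:

import torch
import torch.nn.functional as F

# Toy setup: five candidate words, the model's current logits, and their rewards r(y).
logits  = torch.tensor([2.0, 1.5, 0.5, 0.2, -1.0], requires_grad=True)
rewards = torch.tensor([22000.0, 18800.0, 5200.0, 1200.0, 2.0])

# The bridge: normalize rewards into soft targets q(y) = r(y) / sum(r).
q = rewards / rewards.sum()

# Soft-target cross-entropy: L = -sum_y q(y) * log P(y). No sampling anywhere.
log_p = F.log_softmax(logits, dim=-1)
loss = -(q * log_p).sum()
loss.backward()

print(loss.item())   # the "regret" under the mentor's distribution
print(logits.grad)   # softmax(logits) - q: push the model's P toward q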

The Elegant Unity
Maximize meaning, minimize regret—with the efficiency of supervised learning.


Chapter 6: Making it Scale - Active Token Filtering

While the equivalence is powerful, computing the reward r(y) for every token in a large vocabulary is computationally prohibitive. This is where we embrace the power of search to make the problem tractable.

The core idea is to dynamically select a small, high-potential subset of the vocabulary, Y_candidate, for which we compute rewards and loss. This focuses computation on distinguishing between plausible alternatives, rather than on irrelevant tokens.

The Process

  1. Candidate Selection: For a given ground truth token y*, we use a fast approximate nearest neighbor (ANN) search on the model's own learned embeddings to find the K tokens closest in meaning to y*. This is a general search method that scales beautifully.

         ▲ meaning dim 2
         │
         │          (Candidate Set Y_candidate)
         │       ┌───────────────────┐
         │       │  "rug" , "carpet" │
         │       │    * "mat" (y*)   │
         │       │  "cushion"        │
         │       └───────────────────┘
         │
         └──────────────────────────────────► meaning dim 1
               (distant word, ignored) * "cloud"
    
  2. Restricted Loss Calculation: We then compute the soft-target cross-entropy loss only for this small set of K candidate tokens. For all other tokens, the target probability is zero.
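A rough sketch of both steps, using exact top-k cosine similarity as a stand-in for a true ANN index (at scale a library such as FAISS would supply the search); every size and tensor below is a toy value:

import torch
import torch.nn.functional as F

def restricted_soft_ce(logits, emb, y_star, K=4, tau=0.1):
    """Soft-target cross-entropy over the K tokens nearest in meaning to y_star.

    logits: [V] model scores; emb: [V, D] token embeddings (the model's own,
    or pre-trained ones during the bootstrap epoch); y_star: ground-truth id.
    """
    sims = F.cosine_similarity(emb, emb[y_star].unsqueeze(0), dim=-1)  # [V]
    cand = torch.topk(sims, K).indices              # Y_candidate: the K nearest ids

    q = torch.zeros_like(logits)
    q[cand] = torch.exp(sims[cand] / tau)           # rewards only for the candidates
    q = q / q.sum()                                 # normalize to soft targets;
                                                    # every other token stays at 0
    return -(q * F.log_softmax(logits, dim=-1)).sum()

# Toy usage: a 1024-token vocabulary, 384-dim embeddings, ground-truth token 17.
V, D = 1024, 384
loss = restricted_soft_ce(torch.randn(V), torch.randn(V, D), y_star=17)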

Bootstrapping with External Embeddings

This approach requires meaningful embeddings to work. At the very start of training, the model's embeddings are random. To solve this, we can bootstrap the process for the first epoch using pre-trained embeddings (e.g., Word2Vec or FastText) built from the same vocabulary. After the first epoch, the model's own embeddings have learned enough structure to take over.

This practical step makes the elegant theory work at scale, demonstrating how a simple, general search mechanism can unlock massive computational savings.


Chapter 7: The Path Forward - The Bitter Lesson of Scale

This journey through semantic reinforcement learning leads us to a profound conclusion, one that echoes a "bitter lesson" learned across the history of AI: general methods that leverage computation are ultimately the most effective.

We started with a simple, elegant idea: reward a model for being close in meaning, not just for being exactly right. We explored many complex, clever ways to enhance this idea—adaptive schedules, handcrafted regularizers, intricate curricula.

Yet, the most powerful and scalable path is the simplest:

  1. A simple, general reward: The "Formula of Harmony" (r(y) = exp(similarity/τ)) is not complicated. It's a general principle that relies on learned embeddings, which improve with scale.
  2. A simple, general learning algorithm: The equivalence to weighted cross-entropy allows us to sidestep the complexities of RL sampling and use a stable, efficient, and scalable supervised learning technique.
  3. A simple, general search method: Active token filtering via ANN search is a general method for dealing with large output spaces. It scales and improves as the underlying embeddings improve.

The lesson is not to build our own intelligence into the system with complex rules. The lesson is to create the right conditions for intelligence to emerge. By providing a simple, meaning-based reward signal and leveraging the immense power of computation and search, we allow the model to learn the rich, nuanced structure of language on its own.

The path forward is not more intricate hand-crafting. It is to embrace scale, to trust in simple and general learning principles, and to let the starlings teach themselves how to fly.


References

All mathematical formulations, performance claims, and technical assertions in this document are supported by peer-reviewed research.

[1] Bengio, Y., Ducharme, R., Vincent, P., & Jauvin, C. (2003). A neural probabilistic language model. Journal of Machine Learning Research, 3(Feb), 1137-1155.

"The idea is to learn a distributed representation for words which allows each training sentence to inform the model about an exponential number of semantically neighboring sentences."

[2] Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., & Wojna, Z. (2016). Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 2818-2828).

"We refer to this change in ground-truth label distribution as label-smoothing regularization (LSR)."

[3] Manning, C. D., Clark, K., Hewitt, J., Khandelwal, U., & Levy, O. (2020). Emergent linguistic structure in artificial neural networks trained by self-supervision. Proceedings of the National Academy of Sciences, 117(48), 30046-30054.

"The one-hot nature of traditional language modeling objectives ignores the semantic relationships between words, leading to brittle representations that fail to capture linguistic richness."

[4] Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4), 229-256.

"This article presents a general class of associative reinforcement learning algorithms for connectionist networks containing stochastic units."

[5] Sutton, R. S., McAllester, D. A., Singh, S. P., & Mansour, Y. (2000). Policy gradient methods for reinforcement learning with function approximation. Advances in Neural Information Processing Systems, 12, 1057-1063.


"Language is not a set of rigid rules, but a flowing river of meaning. By teaching our models to swim with the current rather than against it, we open the door to truly intelligent language understanding."

import os
import sys
import time
import math
from pathlib import Path
import subprocess
from dataclasses import dataclass
import inspect
from typing import Optional, Tuple
import numpy as np
import requests
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import modal
N_GPUS = 4
GPU_TYPE = "A100"
# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ TRAINING CONFIG │
# ├─────────────────────────────────────────────────────────────────────────────┤
# │ │
# │ Token-Level Training Configuration: │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ Data: token-level tokenization with BPE │
# │ │ Vocab: 1024 tokens (custom Shakespeare tokenizer) │
# │ │ Context: 512 tokens per sequence │
# │ │ Batch: 128 sequences per batch │
# │ │ Model: 6-layer transformer, 384 dims, 6 heads │
# │ │ Training: 2 epochs with cosine learning rate decay │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ │
# └─────────────────────────────────────────────────────────────────────────────┘
CONFIG = {
"dataset_type": "token",
"vocab_size": 1024, # Custom tokenizer vocab size
"block_size": 512,
"batch_size": 128,
"out_dir": "/data/checkpoints/shakespeare_token",
"eval_interval": 50,
"log_interval": 10,
"eval_iters": 20,
"eval_only": False,
"always_save_checkpoint": True,
"init_from": "scratch",
"wandb_log": False,
"wandb_project": "nanogpt-shakespeare",
"wandb_run_name": "shakespeare-token-1",
"dataset": "shakespeare_tokens",
"gradient_accumulation_steps": 4,
"n_layer": 6,
"n_head": 6,
"n_embd": 384,
"dropout": 0.2,
"bias": False,
"num_epochs": 2.0,
"learning_rate": 1e-3,
"max_iters": None,
"weight_decay": 1e-1,
"beta1": 0.9,
"beta2": 0.95,
"grad_clip": 1.0,
"decay_lr": True,
"warmup_iters": None,
"lr_decay_iters": None,
"min_lr": 1e-4,
"backend": "nccl",
"device": "cuda",
"dtype": "bfloat16",
"compile": True,
}
def norm(x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, (x.size(-1),))
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., ::2], x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
class RotaryCache(nn.Module):
def __init__(self, head_dim: int, max_len: int):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
t = torch.arange(max_len)
freqs = torch.einsum("i,j->ij", t, inv_freq)
sin, cos = freqs.sin(), freqs.cos()
self.register_buffer("sin_base", sin, persistent=False)
self.register_buffer("cos_base", cos, persistent=False)
def forward(self, seq_len: int):
sin = self.sin_base[:seq_len].repeat_interleave(2, dim=-1)
cos = self.cos_base[:seq_len].repeat_interleave(2, dim=-1)
return sin[None, None, :, :], cos[None, None, :, :]
class KVCache(nn.Module):
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
) -> None:
super().__init__()
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"cache_pos", torch.arange(0, cache_shape[2]), persistent=False
)
self.batch_size = batch_size
self.max_seq_len = max_seq_len
def reset(self) -> None:
self.k_cache.zero_()
self.v_cache.zero_()
self.cache_pos -= self.size
@property
def size(self) -> int:
return self.cache_pos[0].item()
def update(
self, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, _, seq_len, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
raise ValueError(
f"Cache batch size is {self.k_cache.shape[0]} but got {bsz}"
)
assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]
k_out = self.k_cache
v_out = self.v_cache
cache_start = self.cache_pos[0].item()
cache_end = cache_start + seq_len
k_out[:, :, cache_start:cache_end] = k_val
v_out[:, :, cache_start:cache_end] = v_val
self.cache_pos.add_(seq_len)
return k_out, v_out
class ReLUSquared(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x).square()
class OptimizedAttention(nn.Module):
def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):
super().__init__()
self.n_head = n_head
self.n_emb = n_emb
self.head_dim = n_emb // n_head
self.qkv = nn.Linear(n_emb, 3 * n_emb, bias=False)
self.o_proj = nn.Linear(n_emb, n_emb, bias=False)
self.dropout = nn.Dropout(dropout)
max_seq = context_len
self.rope = RotaryCache(self.head_dim, max_seq)
self.use_flash_attn = False
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
self.use_flash_attn = True
self.kv_cache = None
self.cache_enabled = False
def init_kv_cache(self, batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16):
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_kv_heads=self.n_head,
head_dim=self.head_dim,
dtype=dtype
)
self.cache_enabled = True
def reset_kv_cache(self):
if self.kv_cache is not None:
self.kv_cache.reset()
def disable_kv_cache(self):
self.cache_enabled = False
self.kv_cache = None
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False) -> torch.Tensor:
B, T, C = x.shape
if use_cache and self.cache_enabled and self.kv_cache is not None:
return self._forward_with_cache(x, mask)
qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
sin, cos = self.rope(T)
q = (q * cos) + (_rotate_half(q) * sin)
k = (k * cos) + (_rotate_half(k) * sin)
q, k = norm(q), norm(k)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=self.dropout.p if self.training else 0.0)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
def _forward_with_cache(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, T, C = x.shape
qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
cache_size = self.kv_cache.size
sin, cos = self.rope(cache_size + T)
sin_new = sin[:, :, cache_size:cache_size+T, :]
cos_new = cos[:, :, cache_size:cache_size+T, :]
q = (q * cos_new) + (_rotate_half(q) * sin_new)
k = (k * cos_new) + (_rotate_half(k) * sin_new)
q, k = norm(q), norm(k)
k_cache, v_cache = self.kv_cache.update(k, v)
valid_cache_size = self.kv_cache.size
k_valid = k_cache[:, :, :valid_cache_size, :]
v_valid = v_cache[:, :, :valid_cache_size, :]
out = F.scaled_dot_product_attention(q, k_valid, v_valid, is_causal=False)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
class TransformerBlock(nn.Module):
def __init__(self, n_emb: int, n_head: int, context_len: int, dropout: float = 0.1):
super().__init__()
self.attn = OptimizedAttention(n_emb, n_head, context_len, dropout)
self.ffn = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb, bias=False),
ReLUSquared(),
nn.Linear(4 * n_emb, n_emb, bias=False),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False) -> torch.Tensor:
x = x + self.attn(norm(x), mask, use_cache=use_cache)
x = x + self.ffn(norm(x))
return x
# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ GPT MODEL ARCHITECTURE │
# ├─────────────────────────────────────────────────────────────────────────────┤
# │ │
# │ Input: Token IDs [B, T] │
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ Token Embedding (wte): vocab_size → n_emb │
# │ │ [B, T] → [B, T, 384] │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ Dropout Layer │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ TransformerBlock #1: │
# │ │ ├─ RMS Norm → Multi-Head Attention (6 heads) → Residual │
# │ │ └─ RMS Norm → FFN (384→1536→384) → Residual │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ TransformerBlock #2-6: (same structure) │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ Final RMS Norm │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ ┌─────────────────────────────────────────────────────────────────────────┤
# │ │ Language Modeling Head: n_emb → vocab_size │
# │ │ [B, T, 384] → [B, T, 1024] (tied weights with embedding) │
# │ └─────────────────────────────────────────────────────────────────────────┤
# │ ↓ │
# │ Output: Logits over vocabulary [B, T, vocab_size] │
# │ │
# └─────────────────────────────────────────────────────────────────────────────┘
class GPT(nn.Module):
def __init__(self, vocab_size: int, n_layer: int = 6, n_head: int = 6,
n_emb: int = 384, context_len: int = 256, dropout: float = 0.1):
super().__init__()
self.vocab_size = vocab_size
self.context_len = context_len
self.n_layer = n_layer
self.n_head = n_head
self.n_emb = n_emb
self.wte = nn.Embedding(vocab_size, n_emb)
self.drop = nn.Dropout(dropout)
self.layers = nn.ModuleList([
TransformerBlock(n_emb, n_head, context_len, dropout)
for _ in range(n_layer)
])
self.head = nn.Linear(n_emb, vocab_size, bias=False)
del self.head.weight
self.head.weight = self.wte.weight
self.apply(self._init_weights)
self.register_buffer("causal_mask", torch.triu(
torch.ones(context_len, context_len), diagonal=1
).bool())
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx: torch.Tensor, targets=None, use_cache: bool = False) -> torch.Tensor:
B, T = idx.shape
tok_emb = self.wte(idx)
x = self.drop(tok_emb)
mask = self.causal_mask[:T, :T] if T <= self.context_len else None
for layer in self.layers:
x = layer(x, mask, use_cache=use_cache)
x = norm(x)
if targets is not None:
logits = self.head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
logits = self.head(x[:, [-1], :])
loss = None
return logits, loss
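# Two paths above: with targets, logits are computed for every position and scored
# with cross-entropy (ignore_index=-1 skips masked-out targets); without targets
# (generation), only the last position's prediction is needed, so the head is applied
# to x[:, [-1], :] to avoid projecting the full [B, T, vocab_size] tensor.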
def init_kv_caches(self, batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16):
for layer in self.layers:
layer.attn.init_kv_cache(batch_size, max_seq_len, dtype)
def reset_kv_caches(self):
for layer in self.layers:
layer.attn.reset_kv_cache()
def disable_kv_caches(self):
for layer in self.layers:
layer.attn.disable_kv_cache()
def get_num_params(self, non_embedding=True):
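# Note: because the lm head is weight-tied to wte, there is no separate embedding
# matrix to subtract here; the non_embedding flag is kept only for interface
# compatibility with nanoGPT-style callers.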
n_params = sum(p.numel() for p in self.parameters())
return n_params
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
param_dict = {pn: p for pn, p in self.named_parameters()}
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
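# Split rationale: parameters with dim >= 2 (weight matrices, embeddings) receive
# weight decay, while 1-D parameters (biases and norm gains, where present) do not,
# since decaying those tends to hurt rather than regularize. Fused AdamW is enabled
# only when the installed torch build exposes it and training runs on CUDA.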
def estimate_mfu(self, fwdbwd_per_iter, dt):
N = self.get_num_params()
L, H, Q, T = self.n_layer, self.n_head, self.n_emb//self.n_head, self.context_len
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
flops_achieved = flops_per_iter * (1.0/dt)
flops_promised = 312e12
mfu = flops_achieved / flops_promised
return mfu
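# Rough sanity check of the accounting above (nanoGPT-style, following the PaLM
# appendix), assuming this config has roughly N ≈ 11M parameters:
#   6*N           ≈ 66M FLOPs/token            (matmul work in all weight layers)
#   12*L*H*Q*T    = 12*6*6*64*256 ≈ 7.1M FLOPs/token (attention scores and values)
#   one fwd+bwd of a 256-token sequence ≈ 73M * 256 ≈ 19 GFLOPs
# flops_promised = 312e12 is the A100 bf16 dense peak; swap the constant for other GPUs.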
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, use_cache=True):
device = idx.device
B, T = idx.shape
if use_cache:
self.init_kv_caches(B, T + max_new_tokens, dtype=self.wte.weight.dtype)  # cache K/V in the model's float dtype; idx.dtype is an integer id dtype and cannot hold activations
generated = idx
for _ in range(max_new_tokens):
if use_cache and generated.shape[1] > T:
logits, _ = self(generated[:, -1:], use_cache=True)
else:
idx_cond = generated if generated.shape[1] <= self.context_len else generated[:, -self.context_len:]
logits, _ = self(idx_cond, use_cache=use_cache)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
generated = torch.cat((generated, idx_next), dim=1)
if use_cache:
self.disable_kv_caches()
return generated
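# --- Hedged usage sketch (illustrative only; never called by the training script). ---
# Shows how generate() is typically driven with ids from the custom BPE tokenizer.
# The prompt_ids/device/max_new_tokens arguments are assumptions for the example.
def _generate_example_sketch(model, prompt_ids, device="cuda", max_new_tokens=64):
    """Sample a continuation from a trained GPT given a list of BPE token ids."""
    model.eval()
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # shape [1, T]
    out = model.generate(idx, max_new_tokens=max_new_tokens,
                         temperature=0.8, top_k=40, use_cache=True)
    return out[0].tolist()  # flat list of token ids; pass to decode_tokens() for text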
# Global tokenizer instance to avoid reloading
_tokenizer = None
def check_tokenizer_exists(vocab_size=1024, data_root="/data"):
"""Check if the custom tokenizer exists and provide instructions if not."""
tokenizer_path = os.path.join(data_root, "tokenizers", f"shakespeare-bpe-{vocab_size}.json")
if not os.path.exists(tokenizer_path):
print(f"\n{'='*60}")
print(f"ERROR: Custom tokenizer not found!")
print(f"{'='*60}")
print(f"Expected tokenizer at: {tokenizer_path}")
print(f"\nTo create the tokenizer, run:")
print(f" modal run train_tokenizer_modal.py::train_bpe_tokenizer")
print(f"\nOr for multiple vocab sizes:")
print(f" modal run train_tokenizer_modal.py::vocab_size_grid_search")
print(f"{'='*60}\n")
return False
return True
def load_custom_tokenizer(vocab_size=1024, data_root="/data"):
"""Load the custom BPE tokenizer from the Modal volume."""
global _tokenizer
if _tokenizer is None:
from tokenizers import Tokenizer
tokenizer_path = os.path.join(data_root, "tokenizers", f"shakespeare-bpe-{vocab_size}.json")
if not check_tokenizer_exists(vocab_size, data_root):
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
_tokenizer = Tokenizer.from_file(tokenizer_path)
print(f"Loaded custom tokenizer from {tokenizer_path}")
return _tokenizer
def decode_tokens(tokens):
"""Decode tokens using the custom tokenizer."""
tokenizer = load_custom_tokenizer(vocab_size=CONFIG.get("vocab_size", 1024))
return tokenizer.decode(tokens)
def encode_tokens(text):
"""Encode text using the custom tokenizer."""
tokenizer = load_custom_tokenizer(vocab_size=CONFIG.get("vocab_size", 1024))
encoding = tokenizer.encode(text)
return encoding.ids
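# --- Hedged sketch (illustrative only; never called). Round-trips text through the ---
# custom BPE tokenizer; assumes the tokenizer json already exists on the data volume.
def _tokenizer_roundtrip_sketch(text="To be, or not to be"):
    ids = encode_tokens(text)      # list[int], each id < vocab_size
    restored = decode_tokens(ids)  # should match the input up to tokenizer normalization
    return ids, restored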
def ensure_shakespeare_data_tokens(data_root="/data"):
vocab_size = CONFIG.get("vocab_size", 1024)
data_dir = os.path.join(data_root, f'shakespeare_tokens_bpe{vocab_size}')
os.makedirs(data_dir, exist_ok=True)
# Check if tokenized data already exists
train_path = os.path.join(data_dir, 'train.bin')
val_path = os.path.join(data_dir, 'val.bin')
if os.path.exists(train_path) and os.path.exists(val_path):
print(f"Tokenized data already exists in {data_dir}")
return data_dir
input_file_path = os.path.join(data_dir, 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
response = requests.get(data_url)
with open(input_file_path, 'w') as f:
f.write(response.text)
# Load custom tokenizer
tokenizer = load_custom_tokenizer(vocab_size=vocab_size, data_root=data_root)
with open(input_file_path, 'r') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# Encode with custom tokenizer
train_encoding = tokenizer.encode(train_data)
val_encoding = tokenizer.encode(val_data)
train_ids = np.array(train_encoding.ids, dtype=np.uint16)
val_ids = np.array(val_encoding.ids, dtype=np.uint16)
train_ids.tofile(train_path)
val_ids.tofile(val_path)
print(f"Token data saved with BPE-{vocab_size}: train={len(train_ids)}, val={len(val_ids)}")
return data_dir
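# --- Hedged sketch (illustrative only; never called). The .bin files written above ---
# are plain uint16 token streams, so they can be inspected with a memory map without
# loading the whole file into RAM.
def _peek_token_bin_sketch(data_dir, n=20):
    tokens = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    first = [int(t) for t in tokens[:n]]       # first n token ids
    return first, decode_tokens(first)         # ids plus their decoded text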
def train():
cfg = CONFIG
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
# Set longer timeout for short training runs
import datetime
timeout_minutes = 30 # Increased from default 10 minutes
init_process_group(backend=cfg['backend'], timeout=datetime.timedelta(minutes=timeout_minutes))
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
assert 'cuda' in device, "this script requires a GPU to run"
master_process = ddp_rank == 0
seed_offset = ddp_rank
assert cfg['gradient_accumulation_steps'] % ddp_world_size == 0
gradient_accumulation_steps = cfg['gradient_accumulation_steps'] // ddp_world_size
else:
master_process = True
seed_offset = 0
ddp_world_size = 1
device = cfg['device']
gradient_accumulation_steps = cfg['gradient_accumulation_steps']
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * cfg['batch_size'] * cfg['block_size']
print(f"tokens per iteration will be: {tokens_per_iter:,}")
if master_process:
os.makedirs(cfg['out_dir'], exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
if device_type == 'cpu':
print("This training script requires a GPU, but is running on CPU.")
print("Exiting...")
sys.exit(1)
vocab_size = cfg.get('vocab_size', 1024)
data_dir = os.path.join("/data" if os.path.exists("/data") else "data", f"shakespeare_tokens_bpe{vocab_size}")
train_data_path = os.path.join(data_dir, 'train.bin')
if os.path.exists(train_data_path):
train_data = np.memmap(train_data_path, dtype=np.uint16, mode='r')
dataset_tokens = len(train_data)
print(f"Training dataset has {dataset_tokens:,} tokens")
# Sanity-check a sample of tokens against the vocabulary size (first 10k only, as a cheap proxy for the full file)
max_token = np.max(train_data[:min(10000, len(train_data))]) # Check first 10k tokens
if max_token >= vocab_size:
print(f"WARNING: Found token {max_token} >= vocab_size {vocab_size}")
print(f"Data may have been tokenized with a different vocabulary!")
print(f"Expected data in: {data_dir}")
raise ValueError(f"Token {max_token} exceeds vocabulary size {vocab_size}")
if cfg['num_epochs'] is not None:
iterations_per_epoch = dataset_tokens / tokens_per_iter
cfg['max_iters'] = int(math.ceil(cfg['num_epochs'] * iterations_per_epoch))
print(f"For {cfg['num_epochs']} epochs, need {cfg['max_iters']} iterations")
print(f"Each epoch is ~{iterations_per_epoch:.1f} iterations")
if cfg['warmup_iters'] is None:
cfg['warmup_iters'] = max(1, int(0.02 * cfg['max_iters']))
if cfg['lr_decay_iters'] is None:
cfg['lr_decay_iters'] = cfg['max_iters']
if cfg['max_iters'] < 20:
cfg['eval_interval'] = max(1, cfg['max_iters'] // 4)
cfg['log_interval'] = 1
cfg['eval_iters'] = min(5, cfg['eval_iters']) # Reduced from 50 to 5 for short runs
print(f"Adjusted for short run: eval_interval={cfg['eval_interval']}, log_interval={cfg['log_interval']}, eval_iters={cfg['eval_iters']}")
if cfg['max_iters'] < 10:
cfg['decay_lr'] = False
cfg['warmup_iters'] = 0
print("Disabled learning rate decay for very short run")
del train_data
else:
if cfg['max_iters'] is None:
raise ValueError("Cannot calculate max_iters: training data not found and max_iters not specified")
def get_batch(split):
batch_data_dir = data_dir
if split == 'train':
data = np.memmap(os.path.join(batch_data_dir, 'train.bin'), dtype=np.uint16, mode='r')
else:
data = np.memmap(os.path.join(batch_data_dir, 'val.bin'), dtype=np.uint16, mode='r')
ix = torch.randint(len(data) - cfg['block_size'], (cfg['batch_size'],))
x = torch.stack([torch.from_numpy((data[i:i+cfg['block_size']]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+cfg['block_size']]).astype(np.int64)) for i in ix])
if device_type == 'cuda':
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
else:
x, y = x.to(device), y.to(device)
return x, y
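# Each sample pair is offset by one token: x = tokens[i : i+block_size] and
# y = tokens[i+1 : i+1+block_size], so position t of x is trained to predict
# position t of y (the next token). e.g. tokens [5, 9, 2, 7] give
#   x = [5, 9, 2], y = [9, 2, 7]  for block_size = 3.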
iter_num = 0
best_val_loss = 1e9
meta_vocab_size = cfg['vocab_size'] # Use custom tokenizer vocab size
print(f"Using custom BPE vocab_size = {meta_vocab_size}")
model_args = dict(
n_layer=cfg['n_layer'],
n_head=cfg['n_head'],
n_embd=cfg['n_embd'],
block_size=cfg['block_size'],
bias=cfg['bias'],
vocab_size=meta_vocab_size if meta_vocab_size is not None else 50304,
dropout=cfg['dropout']
)
if cfg['init_from'] == 'scratch':
print("Initializing a new model from scratch")
model = GPT(
vocab_size=model_args['vocab_size'],
n_layer=model_args['n_layer'],
n_head=model_args['n_head'],
n_emb=model_args['n_embd'],
context_len=model_args['block_size'],
dropout=model_args['dropout']
)
elif cfg['init_from'] == 'resume':
print(f"Resuming training from {cfg['out_dir']}")
ckpt_path = os.path.join(cfg['out_dir'], 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint_model_args = checkpoint['model_args']
for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
model_args[k] = checkpoint_model_args[k]
model = GPT(
vocab_size=model_args['vocab_size'],
n_layer=model_args['n_layer'],
n_head=model_args['n_head'],
n_emb=model_args['n_embd'],
context_len=model_args['block_size'],
dropout=model_args['dropout']
)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']
model.to(device)
scaler = torch.amp.GradScaler('cuda', enabled=(cfg['dtype'] == 'float16'))
optimizer = model.configure_optimizers(cfg['weight_decay'], cfg['learning_rate'],
(cfg['beta1'], cfg['beta2']), 'cuda')
if cfg['init_from'] == 'resume' and 'optimizer' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None
if cfg['compile']:
print("compiling the model... (takes a ~minute)")
unoptimized_model = model
model = torch.compile(model)
if ddp:
model = DDP(model, device_ids=[ddp_local_rank])
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[cfg['dtype']]
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(cfg['eval_iters'])
for k in range(cfg['eval_iters']):
X, Y = get_batch(split)
with ctx:
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out
def get_lr(it):
if it < cfg['warmup_iters']:
return cfg['learning_rate'] * (it + 1) / (cfg['warmup_iters'] + 1)
if it > cfg['lr_decay_iters']:
return cfg['min_lr']
decay_ratio = (it - cfg['warmup_iters']) / (cfg['lr_decay_iters'] - cfg['warmup_iters'])
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return cfg['min_lr'] + coeff * (cfg['learning_rate'] - cfg['min_lr'])
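# Schedule shape: linear warmup to the peak LR, then cosine decay down to min_lr.
# Example values (illustrative, not the actual CONFIG): learning_rate=6e-4,
# min_lr=6e-5, warmup_iters=100, lr_decay_iters=1000 gives
#   it=0    -> ~6e-6   (start of warmup)
#   it=100  -> 6e-4    (peak)
#   it=550  -> ~3.3e-4 (halfway through the cosine)
#   it=1000 -> 6e-5    (floor)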
if cfg['wandb_log'] and master_process:
import wandb
wandb.init(project=cfg['wandb_project'], name=cfg['wandb_run_name'], config=cfg)
X, Y = get_batch('train')
t0 = time.time()
local_iter_num = 0
raw_model = model.module if ddp else model
running_mfu = -1.0
while True:
lr = get_lr(iter_num) if cfg['decay_lr'] else cfg['learning_rate']
for param_group in optimizer.param_groups:
param_group['lr'] = lr
if iter_num % cfg['eval_interval'] == 0 and master_process:
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
if cfg['wandb_log']:
wandb.log({
"iter": iter_num,
"train/loss": losses['train'],
"val/loss": losses['val'],
"lr": lr,
"mfu": running_mfu*100,
})
if losses['val'] < best_val_loss or cfg['always_save_checkpoint']:
best_val_loss = losses['val']
if iter_num > 0:
checkpoint = {
'model': raw_model.state_dict(),
'optimizer': optimizer.state_dict(),
'model_args': model_args,
'iter_num': iter_num,
'best_val_loss': best_val_loss,
'config': cfg,
}
print(f"saving checkpoint to {cfg['out_dir']}")
torch.save(checkpoint, os.path.join(cfg['out_dir'], 'ckpt.pt'))
if iter_num == 0 and cfg['eval_only']:
break
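# Gradient accumulation: each micro-batch loss is divided by gradient_accumulation_steps
# so the summed backward() calls equal the gradient of one large batch; with DDP,
# require_backward_grad_sync defers the all-reduce to the final micro-step, and the
# next batch is fetched on the CPU while the GPU is still busy. The GradScaler
# unscale -> clip -> step -> update ordering below is the standard mixed-precision recipe.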
for micro_step in range(gradient_accumulation_steps):
if ddp:
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
with ctx:
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps
X, Y = get_batch('train')
scaler.scale(loss).backward()
if cfg['grad_clip'] != 0.0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg['grad_clip'])
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
t1 = time.time()
dt = t1 - t0
t0 = t1
if iter_num % cfg['log_interval'] == 0 and master_process:
lossf = loss.item() * gradient_accumulation_steps
if local_iter_num >= 5:
mfu = raw_model.estimate_mfu(cfg['batch_size'] * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
iter_num += 1
local_iter_num += 1
if iter_num > cfg['max_iters']:
break
if ddp:
destroy_process_group()
app = modal.App("nanogpt-training-2")
image = (
modal.Image.debian_slim(python_version="3.11")
.pip_install(
"numpy",
"torch",
"wandb",
"requests",
"tokenizers"
)
)
volume = modal.Volume.from_name("nanogpt-data", create_if_missing=True)
@app.function(
gpu=f"{GPU_TYPE}:{N_GPUS}",
volumes={"/data": volume},
timeout=60 * 60 * 6,
image=image,
secrets=[modal.Secret.from_name("wandb-secret")] if CONFIG.get("wandb_log", False) else [],
)
def train_modal():
cfg = CONFIG
print(f"Starting Modal training with {N_GPUS} {GPU_TYPE} GPUs")
print(f"Dataset type: {cfg['dataset_type']}")
# Check if custom tokenizer exists before proceeding
if not check_tokenizer_exists(cfg['vocab_size'], "/data"):
raise RuntimeError("Cannot proceed without tokenizer. Please create it first.")
ensure_shakespeare_data_tokens("/data")
script_path = Path(__file__)
script_content = script_path.read_text()
temp_script = "/tmp/train_modal.py"
Path(temp_script).write_text(script_content)
cmd = [
"torchrun",
f"--nproc-per-node={N_GPUS}",
temp_script,
]
print(f"Running command: {' '.join(cmd)}")
os.chdir("/tmp")
subprocess.run(cmd, check=True)
print("Training completed successfully!")
return "Training completed"
if __name__ == "__main__":
if "RANK" in os.environ:
train()
else:
print("This script should be run with torchrun or through Modal")
print("Examples:")
print(" Local: torchrun --nproc-per-node=4 train_modal_standalone.py")
print(" Modal: modal run train_modal_standalone.py::train_modal")
sys.exit(1)