Intuition of Multi Head Fused Attention with print statements
"""
All of the code is from: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
"""
import torch
from torch import nn
from colorama import Fore, Style, init as color_init
import numpy as np
# np.set_printoptions()
# ---------- terminal styling helpers ----------
color_init(autoreset=True)
BOLD, ITALIC, RESET = Style.BRIGHT, '\033[3m', Style.RESET_ALL
def explain(label, tensor, why="", show=False):
    """Print a tensor's shape and a short explanation; optionally preview the values."""
    head = f"{Fore.YELLOW}{label}:{RESET} "
    shape = f"{ITALIC}{Fore.CYAN}{tuple(tensor.shape)}{RESET} "
    print(f"{head}{shape} {Fore.GREEN} <- {why}{RESET} \n")
    if show:
        with torch.no_grad():
            # preview the tensor values, rounded to 2 decimals for readability
            preview = tensor.cpu().numpy().round(2)
            torch.set_printoptions(precision=2, linewidth=150, sci_mode=False)
            np.set_printoptions(precision=2, linewidth=150)
            print(f"Tensor looks like:\n{Fore.MAGENTA}{preview}{RESET}\n\n")
# ---------- multi-head attention ----------
class MHAPyTorchScaledDotProduct(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.h, self.d_h = num_heads, d_out // num_heads
        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)  # 1 matmul -> fused Q, K, V
        self.proj = nn.Linear(d_out, d_out)                    # output projection W_O
        self.dropout = dropout
    def forward(self, x):
        B, N, _ = x.shape
        explain("Input shape", x,
                "[BATCH, num_tokens, embed_dim] | These are the raw input embeddings or the output of the previous block", False)
        qkv = self.qkv(x)
        explain("Output of 'self.qkv(INPUT)' shape", qkv,
                "We multiplied INPUT by ONE fused weight matrix that stacks Q, K and V -> [BATCH, num_tokens, 3 * embed_dim]. The '3' is one slice each for Q, K and V", True)
        # separate the 3 blocks and split into heads
        qkv = qkv.view(B, N, 3, self.h, self.d_h)
        explain("After reshaping", qkv,
                "[BATCH, num_tokens, 3, num_heads, head_dim] | Each operation (reshape or permute) has its own node in the backprop graph.\n\t\t\t\t\tA linear-algebra property lets one big head act as horizontally stacked small heads: W = [W1 <CONCAT> W2]", False)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        explain("After rearranging dimensions", qkv,
                "Moved the '3' to the front so we can unpack it into Q, K, V individually. It acts as if 3 different Linear layers had been used")
        Q, K, V = qkv  # unpack
        explain("After unpacking (only 'Q' shown here)", Q,
                "Each of Q | K | V looks exactly like this: [BATCH, num_heads, num_tokens, head_dim]", True)
        ctx = nn.functional.scaled_dot_product_attention(
            Q, K, V, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        explain("After attention", ctx,
                "Used Q, K and V to compute 'softmax(Q @ K_transpose / sqrt(d_h)) @ V' per head. During backprop they receive gradients as if they were INDIVIDUAL matrices")
        ctx = ctx.transpose(1, 2).reshape(B, N, self.h * self.d_h)
        explain("After concat", ctx,
                "Reversed the earlier split by merging the heads back into the model dim (d_model)")
        out = self.proj(ctx)
        explain("Final block output", out,
                "Applied the final Linear matrix (W_O). The output has the same shape as the input: [BATCH, num_tokens, embed_dim]", True)
        return out
# --------------- quick demo ----------------
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 5, 4)
    mha = MHAPyTorchScaledDotProduct(4, 4, num_heads=2)
    y = mha(x)
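    # Extra sanity check (a minimal sketch, not required for the demo above): verify that
    # slicing the fused projection reproduces a standalone Q projection built from the
    # same rows of the fused weight matrix.
    with torch.no_grad():
        d_out = mha.h * mha.d_h
        fused = mha.qkv(x)                      # [BATCH, num_tokens, 3 * d_out]
        W_q = mha.qkv.weight[:d_out]            # first d_out rows act as W_Q
        q_separate = x @ W_q.T                  # what a separate nn.Linear(d_in, d_out) would give
        print("Fused Q slice == separate Q projection:",
              torch.allclose(fused[..., :d_out], q_separate, atol=1e-6))
        print("Output keeps the input shape:", y.shape == x.shape)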