Intuition of Multi-Head Fused Attention with print statements
""" | |
All of the code is from: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb | |
""" | |
import torch, sys | |
from torch import nn | |
from colorama import Fore, Style, init as color_init | |
import numpy as np | |
# np.set_printoptions() | |
# ---------- terminal styling helpers ---------- | |
color_init(autoreset=True) | |
BOLD, ITALIC, RESET = Style.BRIGHT, '\033[3m', Style.RESET_ALL | |
def explain(label, tensor, why="", show=False):
    """Print a tensor's shape, a short explanation of where it came from, and (optionally) its values."""
    head = f"{Fore.YELLOW}{label}:{RESET} "
    shape = f"{ITALIC}{Fore.CYAN}{tuple(tensor.shape)}{RESET} "
    print(f"{head}{shape} {Fore.GREEN} <- {why}{RESET} \n")
    if show:
        with torch.no_grad():
            # detach() so intermediate tensors that carry grad history can be converted to NumPy
            preview = tensor.detach().cpu().numpy().round(2)
            torch.set_printoptions(precision=2, linewidth=150, sci_mode=False)
            np.set_printoptions(precision=2, linewidth=150)
            print(f"Tensor looks like:\n{Fore.MAGENTA}{preview}{RESET}\n\n")
# ---------- multi-head attention ----------
class MHAPyTorchScaledDotProduct(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.h, self.d_h = num_heads, d_out // num_heads
        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)  # 1 matmul → fused Q, K, V
        self.proj = nn.Linear(d_out, d_out)                   # W_O
        self.dropout = dropout
    def forward(self, x):
        B, N, _ = x.shape
        explain("Input shape", x,
                "[BATCH, num_tokens, embed_dim] | These are the raw input embeddings or the output of the previous block", False)

        qkv = self.qkv(x)
        explain("Output of 'self.qkv(INPUT)' shape", qkv,
                "We multiplied INPUT by the single FUSED Q-K-V matrix -> [BATCH, num_tokens, 3 * embed_dim]. The '3' is because there is one block each for Q, K and V", True)

        # separate the 3 blocks and split into heads
        qkv = qkv.view(B, N, 3, self.h, self.d_h)
        explain("After reshaping", qkv,
                "[BATCH, num_tokens, 3, num_heads, head_dim] | Each operation (reshape or permute) gets its own node in the backprop graph.\n\t\t\t\t\tA linear-algebra property lets one big projection act as horizontally stacked per-head projections: W = [W1 <CONCAT> W2]", False)

        qkv = qkv.permute(2, 0, 3, 1, 4)
        explain("After rearranging dimensions", qkv,
                "Moved the '3' to the front so we can unpack Q, K, V individually. It acts as if 3 different Linear layers had been used")

        Q, K, V = qkv  # unpack
        explain("After unpacking (only 'Q' shown here)", Q,
                "Each of Q, K, V looks exactly like this: [BATCH, num_heads, num_tokens, head_dim]", True)
        ctx = nn.functional.scaled_dot_product_attention(
            Q, K, V, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        explain("After attention shape", ctx,
                "Used all of Q, K, V above to compute 'softmax(Q @ K_transpose / sqrt(d_h)) @ V' per head. During backprop, each projection is updated as if it were an INDIVIDUAL matrix")

        ctx = ctx.transpose(1, 2).reshape(B, N, self.h * self.d_h)
        explain("After concat shape", ctx,
                "Reversed the earlier split by merging the heads back into the model dim (d_model)")

        out = self.proj(ctx)
        explain("Final block output shape", out,
                "Applied the final Linear matrix (W_O). The output shape matches the input: [BATCH, num_tokens, embed_dim]", True)
        return out
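
    # -----------------------------------------------------------------
    # Illustrative sketch (not part of the original notebook): an unfused
    # reference that re-uses this module's own weights but spells out the
    # textbook formula softmax(Q @ K_transpose / sqrt(d_h)) @ V per head,
    # with an explicit causal mask. The method name 'forward_unfused' is
    # ours; it assumes qkv_bias=False and ignores dropout. It is only meant
    # to show that the fused path above computes the same thing.
    def forward_unfused(self, x):
        B, N, _ = x.shape
        d_out = self.h * self.d_h
        # slice the fused weight rows into the three logical matrices W_Q, W_K, W_V
        w_q, w_k, w_v = self.qkv.weight.split(d_out, dim=0)
        Q = (x @ w_q.T).view(B, N, self.h, self.d_h).transpose(1, 2)  # [B, num_heads, N, head_dim]
        K = (x @ w_k.T).view(B, N, self.h, self.d_h).transpose(1, 2)
        V = (x @ w_v.T).view(B, N, self.h, self.d_h).transpose(1, 2)
        scores = Q @ K.transpose(-2, -1) / (self.d_h ** 0.5)          # [B, num_heads, N, N]
        causal = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(causal, float("-inf"))            # hide future tokens
        ctx = torch.softmax(scores, dim=-1) @ V                       # [B, num_heads, N, head_dim]
        ctx = ctx.transpose(1, 2).reshape(B, N, d_out)                # merge heads back
        return self.proj(ctx)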
# --------------- quick demo ----------------
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 5, 4)
    mha = MHAPyTorchScaledDotProduct(4, 4, num_heads=2)
    y = mha(x)
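
    # ---- optional sanity check (illustrative sketch, not from the original gist) ----
    # The unfused reference sketch above should give the same result as the fused
    # forward pass, since both use exactly the same weights; 'forward_unfused' is
    # our own helper, not part of the original notebook.
    with torch.no_grad():
        y_ref = mha.forward_unfused(x)
        print("Fused output matches unfused reference:", torch.allclose(y, y_ref, atol=1e-6))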