An example of extracting attention weights from a transformer encoder that works with PyTorch 2.6, with some explanatory comments.
import torch
import torch.nn as nn
from collections import defaultdict

# --- Hook Implementation ---
# Dictionary to store attention weights during the forward pass.
# Structure: {layer_index: attention_weights_tensor}
# The tensor shape will be (batch_size, num_heads, seq_len, seq_len).
# This dictionary is captured in the definition below.
attention_weights_store = defaultdict(list)

# This function is a bit involved. We want to register a hook with a component
# (module) of a transformer model. A hook is a function that is called whenever
# the "forward" or the "backward" method of that component is called; here we
# ask for our hook to be called on the forward pass.
# A hook must accept the module, its input, and its output as arguments; this
# is a convention defined by PyTorch. If we wanted to add a single hook, we
# could define a single function called hook, select a module from a model,
# and call the .register_forward_hook method on this module (see below).
# However, here we want to add hooks to all attention layers in a transformer,
# and we may not know in advance how many of those we will have.
# In order to do this automatically, we define a "factory function": a function
# that returns a function on demand. The functions that our factory function,
# save_attention_weights_hook_generator, returns are indexed by layer ids.
# Each generated function receives its module, which in this case is the
# self-attention sublayer, together with that module's input and output, and
# dumps the attention weights from the output into the dictionary we defined
# above, attention_weights_store, which acts as a global variable.
# (We could also ask save_attention_weights_hook_generator to create and return
# a separate dictionary for storing the weights of each layer.) These on-demand
# functions are associated with the model's layers. They are called
# automatically on the forward pass, and we never use them directly; however,
# we can access their outputs through attention_weights_store.
def save_attention_weights_hook_generator(layer_idx):
    """
    Factory function to create a hook that saves attention weights.
    Captures the layer index via a closure.
    """
    def hook(module, input, output):
        # The output of MultiheadAttention when need_weights=True is a tuple:
        # (attn_output, attn_output_weights).
        # We configured average_attn_weights=False, so attn_output_weights
        # has shape (batch_size, num_heads, seq_len, seq_len).
        if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
            # Store the detached weights on the CPU
            attention_weights_store[layer_idx] = output[1].detach().cpu()
        else:
            raise ValueError(
                f'Could not find attention weights in the output of the '
                f'self-attention module of layer {layer_idx}!'
            )
    return hook
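
# For illustration only (not part of the pipeline below): the single-hook
# variant mentioned in the comment above would look roughly like this, where
# `some_encoder_layer` is a hypothetical, already-built encoder layer whose
# self-attention module returns attention weights:
#
#     def hook(module, input, output):
#         attention_weights_store[0] = output[1].detach().cpu()
#
#     some_encoder_layer.self_attn.register_forward_hook(hook)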

# --- Configuration ---
D_MODEL = 128            # Embedding dimension
NHEAD = 2                # Number of attention heads
NUM_ENCODER_LAYERS = 2   # Number of Transformer encoder layers
DIM_FEEDFORWARD = 512    # Dimension of the feedforward network
DROPOUT = 0.1            # Dropout rate
SEQ_LEN = 10             # Example sequence length
BATCH_SIZE = 4           # Example batch size

# --- Model Definition ---
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """
    We need to subclass the standard transformer encoder layer in order to
    force the attention sublayer to return attention weights. In recent
    PyTorch versions the standard layer calls self-attention with
    need_weights=False hard-coded, so simply configuring the standard layer
    no longer works.
    """
    def _sa_block(self, x, attn_mask, key_padding_mask, **kwargs):
        x, attn_weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            # The crucial bit:
            need_weights=True,
            average_attn_weights=False,  # Ensures per-head weights
            **kwargs
        )
        return self.dropout1(x)

class TransformerEncoderWithHooks(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        # Create the encoder layers
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            layer = MyTransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True  # Make input/output (batch, seq, feature)
            )
            # Access the MultiheadAttention module within the layer
            mha_module = layer.self_attn
            # In older versions of PyTorch, we could just set these parameters
            # on the standard encoder layer, but they are now hard-coded for
            # efficiency:
            # Ask the module to return attention weights
            # Set average_attn_weights to False to get per-head weights
            # mha_module.need_weights = True
            # mha_module.average_attn_weights = False
            # Register the forward hook on the MultiheadAttention module.
            # Use the factory to create a hook on the fly and capture the
            # current layer index 'i'.
            mha_module.register_forward_hook(
                save_attention_weights_hook_generator(i))
            self.layers.append(layer)
        self.norm = nn.LayerNorm(d_model)  # Final layer norm, optional

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Forward pass for the encoder.

        Args:
            src (Tensor): Input sequence, shape (batch_size, seq_len, d_model)
                due to batch_first=True.
            src_mask (Tensor, optional): Additive mask for self-attention.
                Shape (seq_len, seq_len).
            src_key_padding_mask (Tensor, optional): Mask indicating padding tokens.
                Shape (batch_size, seq_len).

        Returns:
            Tensor: Output sequence, shape (batch_size, seq_len, d_model).
        """
        output = src
        # Clear the global store before the forward pass
        global attention_weights_store
        attention_weights_store.clear()
        for layer in self.layers:
            output = layer(output, src_mask=src_mask,
                           src_key_padding_mask=src_key_padding_mask)
        output = self.norm(output)
        return output

# --- Example Usage ---
# Instantiate the model
model = TransformerEncoderWithHooks(
    d_model=D_MODEL,
    nhead=NHEAD,
    num_layers=NUM_ENCODER_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT
)

# Set the model to evaluation mode (disables dropout, etc.)
model.eval()

# Create dummy input data
# Shape: (batch_size, seq_len, d_model) because batch_first=True
dummy_input = torch.rand(BATCH_SIZE, SEQ_LEN, D_MODEL)
print(f"Input shape: {dummy_input.shape}")

# Perform the forward pass
with torch.no_grad():  # No need to track gradients for inference
    output = model(dummy_input)
print(f"Output shape: {output.shape}")

# Check the stored attention weights
print("\n--- Stored Attention Weights ---")
if not attention_weights_store:
    print("No attention weights were captured.")
else:
    for layer_idx, weights in attention_weights_store.items():
        # Expected shape: (batch_size, num_heads, seq_len, seq_len)
        print(f"Layer {layer_idx}: Weights = {weights.shape}")
        # Check that the attention weight rows for the first batch element sum to 1.
        print(f"Weight row sums = {weights[0].sum(dim=-1).flatten()}")