A token-by-token matrix visualization of the attention weights (how much importance each token receives) in the multi-headed attention layers of transformers like GPT-2.
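For orientation before the full script: the visualization relies on the per-layer attention tensors that Hugging Face models return when called with output_attentions=True; each element of outputs.attentions has shape (batch, num_heads, seq_len, seq_len). A minimal sketch of that shape check, using the same "gpt2" checkpoint as the gist (the probe sentence is just an illustrative example):

import torch
from transformers import GPT2Tokenizer, GPT2Model

# Minimal check of the attention tensors the visualization below is built on.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

input_ids = tokenizer.encode("The cat sat on the mat", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)

# One attention tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[-1].shape)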
import torch
from transformers import GPT2Tokenizer, GPT2Model
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation


def get_model_info(model):
    num_layers = model.config.n_layer
    num_heads = model.config.n_head
    hidden_dim = model.config.n_embd
    head_size = hidden_dim // num_heads
    return num_layers, num_heads, hidden_dim, head_size


model_name = "gpt2"  # Replace with your desired GPT-2 model name
model = GPT2Model.from_pretrained(model_name)
print(get_model_info(model))


def get_activated_neurons_with_weights(model, tokenizer, input_text):
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
    # One attention tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
    attentions = outputs.attentions

    # Get the attention weights for the last layer: (num_heads, seq_len, seq_len)
    last_layer_attention = attentions[-1][0]

    # Average over heads, then rank token positions by attention weight (descending)
    average_attention = last_layer_attention.mean(dim=0)
    activated_neurons = torch.argsort(average_attention, descending=True)

    # Gather the per-head attention weights for the ranked positions
    activated_neuron_weights = last_layer_attention[:, activated_neurons]

    return activated_neurons, activated_neuron_weights, tokenizer.decode(input_ids[0])


# Example usage
model_name = "gpt2"  # Replace with your desired GPT-2 model name
input_text = "I dont know?"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)

# Run the model on the input text and extract the attention-based activations
activated_neurons, neuron_weights, generated_text = get_activated_neurons_with_weights(model, tokenizer, input_text)

# Store the decoded text and neuron activations
memory = {
    'generated_text': generated_text,
    'activated_neurons': activated_neurons,
    'neuron_weights': neuron_weights
}

# Retrieve the stored memory
retrieved_generated_text = memory['generated_text']
retrieved_activated_neurons = memory['activated_neurons']
retrieved_neuron_weights = memory['neuron_weights']

# Create a figure to display the animation
fig = plt.figure()


# Define a function to update the animation
def update(frame):
    # Clear the previous plot
    plt.clf()

    # Plot the current frame as a token-by-token heat map
    im = plt.imshow(np.array(retrieved_neuron_weights[0][frame]), cmap='hot', interpolation='nearest')
    plt.colorbar(im)

    # Add labels for each token
    tokens = tokenizer.tokenize(retrieved_generated_text)
    plt.xticks(range(len(tokens)), tokens, rotation='vertical')
    plt.yticks(range(len(tokens)), tokens)


# Create the animation
anim = FuncAnimation(fig, update, frames=len(retrieved_neuron_weights[0]), interval=500, repeat=True)

# Display the animation
plt.show()
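If you want to keep the animation rather than only display it interactively, FuncAnimation objects can also be written to disk. A minimal sketch using matplotlib's Pillow writer (requires the Pillow package; the filename and frame rate are arbitrary examples, not part of the original gist):

# Save the animation as a GIF instead of (or in addition to) showing it
anim.save("attention.gif", writer="pillow", fps=2)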