import torch
import torch.nn.functional as F
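# Fix the seed so runs are reproducible (an addition; not in the original gist).
torch.manual_seed(0)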
# Create random inputs for testing
batch_size = 128
seq_length = 512
embed_dim = 64
enable_math = False  # disallow the math fallback so a fused SDPA kernel must handle the mask
query = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
key = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
value = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
# Create a 3D attention mask.
# Start from a full mask (every token attends to every other token), then, for
# the second half of the batch, block queries in the first half of the sequence
# from attending to keys in the second half, mimicking two packed sequences.
# Adjust this mask to match your own packing logic.
attn_mask = torch.ones(batch_size, seq_length, seq_length, device="cuda", dtype=torch.bool)
half_batch = batch_size // 2
for i in range(half_batch, batch_size):
    attn_mask[i, :seq_length//2, seq_length//2:] = False
# Creating the causal mask
causal_mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool, device="cuda"))
# Combine the causal mask with attn_mask using a logical AND
combined_mask = attn_mask & causal_mask
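# Sanity check (an addition, not part of the original gist): a query row whose
# keys are all masked would make softmax produce NaNs, so verify every row
# keeps at least one visible key before calling the kernel.
assert combined_mask.any(dim=-1).all(), "some query positions have no visible keys"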
with torch.backends.cuda.sdp_kernel(enable_math=enable_math):
    output = F.scaled_dot_product_attention(query, key, value, attn_mask=combined_mask)
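# Optional cross-check (an addition, assuming SDPA's default scaling of
# 1/sqrt(embed_dim)): recompute attention with plain PyTorch ops and compare
# against the fused kernel's output; they should agree to floating-point tolerance.
with torch.no_grad():
    scale = embed_dim ** -0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    scores = scores.masked_fill(~combined_mask, float("-inf"))
    reference = torch.softmax(scores, dim=-1) @ value
    print(f"Max abs diff vs. plain PyTorch reference: {(output - reference).abs().max().item():.2e}")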
# Recover the mask empirically: backpropagate from each output position and
# record which value rows receive a non-zero gradient.
empirical_mask = torch.zeros_like(attn_mask, dtype=torch.bool)
for i in range(batch_size):
    for j in range(seq_length):
        if output.requires_grad:
            # Zero out gradients from previous iterations
            query.grad = None
            key.grad = None
            value.grad = None
        output[i, j].sum().backward(retain_graph=True)
        # If the gradient w.r.t. a value row is non-zero, that position was attended to
        empirical_mask[i, j] = (value.grad[i].abs() > 1e-6).any(dim=-1)
# Check if the empirical mask matches the expected mask
is_same = torch.all(combined_mask == empirical_mask)
print(f"The empirical mask matches the expected mask: {is_same.item()}")