import torch
import torch.nn.functional as F

# Create random inputs for testing
batch_size = 128
seq_length = 512
embed_dim = 64
enable_math = False

query = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
key = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
value = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)

# Creating a 3D attention mask
# Here, I'll create a simple 3D mask where every packed sequence has a full 2D mask
# Meaning, every token attends to every other token in its sequence
# You can adjust this mask based on your packing logic
attn_mask = torch.ones(batch_size, seq_length, seq_length, device="cuda", dtype=torch.bool)

# For the second half of the batch, block the first half of the sequence from
# attending to the second half, mimicking a packing boundary
half_batch = batch_size // 2
for i in range(half_batch, batch_size):
    attn_mask[i, :seq_length // 2, seq_length // 2:] = False

# Creating the causal mask
causal_mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool, device="cuda"))

# Combine the causal mask with the attn_mask using a logical AND operation
combined_mask = attn_mask & causal_mask
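# Sanity check: the AND broadcasts causal_mask (seq_length, seq_length) across
# attn_mask (batch_size, seq_length, seq_length), so each batch element gets its
# own causal, packed mask
assert combined_mask.shape == (batch_size, seq_length, seq_length)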
with torch.backends.cuda.sdp_kernel(enable_math=enable_math):
    output = F.scaled_dot_product_attention(query, key, value, attn_mask=combined_mask)

empirical_mask = torch.zeros_like(attn_mask, dtype=torch.bool)

# Calculating the empirical mask: for each query position, backpropagate from the
# output and mark the key positions whose value gradients are non-zero
for i in range(batch_size):
    for j in range(seq_length):
        if output.requires_grad:
            # Zero out gradients from previous iterations
            query.grad = None
            key.grad = None
            value.grad = None

            output[i, j].sum().backward(retain_graph=True)

            # If the gradient for value is non-zero, that location was attended to
            empirical_mask[i, j] = (value.grad[i].abs() > 1e-6).any(dim=-1)

# Check if the empirical mask matches the expected mask
is_same = torch.all(combined_mask == empirical_mask)
print(f"The empirical mask matches the expected mask: {is_same.item()}")