Attention Calculation Methods
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
# check dependencies
assert torch.cuda.is_available(), 'CUDA is expected.'
cu_major, cu_minor = torch.version.cuda.split('.')[:2]
# assert (int(cu_major), int(cu_minor)) >= (11, 6), f'Expected CUDA 11.6 or above, but found {cu_major}.{cu_minor}'
pt_major, pt_minor = str(torch.__version__).split('.')[:2]
assert int(pt_major) >= 2, 'Expected PyTorch 2.0 or above.'
# for reproducible results (may not be fully deterministic)
torch.manual_seed(37)
torch.cuda.manual_seed_all(37)
device = "cuda"
# dummy input: (batch, heads, sequence length, dimension per head)
b_size = 32
seq_len = 1024
num_heads = 32
embd_dim = 32  # dimension per attention head
dtype = torch.float16
query = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
key = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
value = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
print(f"flash_sdp_enabled:\t\t{torch.backends.cuda.flash_sdp_enabled()}") | |
print(f"mem_efficient_sdp_enabled:\t{torch.backends.cuda.mem_efficient_sdp_enabled()}") | |
print(f"math_sdp_enabled:\t\t{torch.backends.cuda.math_sdp_enabled()}") | |
print("_" * 80) | |
# Checking different attention calculation methods
# -------------------------------------------------

# 1) Math backend: the naive reference implementation defined in C++
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    res_1 = F.scaled_dot_product_attention(query, key, value)
    # measure runtime inside the context manager so the backend restriction applies
    t1 = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(query, key, value)",
        setup="import torch.nn.functional as F",
        globals={"query": query, "key": key, "value": value},
    )
    print(f"PyTorch naive (math) attention took:\t{round(t1.blocked_autorange().mean * 1e6)} microseconds")
# 2) Memory-Efficient Attention
#    "Self-attention Does Not Need O(n^2) Memory" [https://arxiv.org/abs/2112.05682]
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    res_2 = F.scaled_dot_product_attention(query, key, value)
    # measure runtime
    t2 = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(query, key, value)",
        setup="import torch.nn.functional as F",
        globals={"query": query, "key": key, "value": value},
    )
    print(f"Memory-Efficient attention took:\t{round(t2.blocked_autorange().mean * 1e6)} microseconds")
# 3) FlashAttention
#    "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
#    [https://arxiv.org/abs/2205.14135]
res_3 = None
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        res_3 = F.scaled_dot_product_attention(query, key, value)
        # measure runtime
        t3 = benchmark.Timer(
            stmt="F.scaled_dot_product_attention(query, key, value)",
            setup="import torch.nn.functional as F",
            globals={"query": query, "key": key, "value": value},
        )
        print(f"Flash attention took:\t\t\t{round(t3.blocked_autorange().mean * 1e6)} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported on this device")
print('\nRESULT:')
print(f"PyTorch naive vs Memory-Efficient: is same? --> {torch.allclose(res_1, res_2, rtol=1e-3, atol=1e-7)}")
if res_3 is not None:
    print(f"PyTorch naive vs Flash: is same? --> {torch.allclose(res_1, res_3, rtol=1e-3, atol=1e-7)}")