pytorch profiler
import torch
import torch.nn.functional as F
import torch.profiler


def benchmark_forward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the forward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        _ = F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()  # ensure warm-up kernels complete

    start_event.record()
    for _ in range(num_timed_runs):
        _ = F.scaled_dot_product_attention(q, k, v)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


def benchmark_backward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the backward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    grad_output = torch.randn_like(q)

    # single forward pass; the same graph is reused for every backward call below,
    # so backward() must be called with retain_graph=True
    output = F.scaled_dot_product_attention(q, k, v)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        # clear gradients from previous iterations
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    torch.cuda.synchronize()  # ensure all warm-up backward kernels complete

    start_event.record()
    for _ in range(num_timed_runs):
        # clear gradients for this specific iteration
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


if __name__ == "__main__":
    batch_size = 1
    num_heads = 24
    seq_len = 4096 + 512
    head_dim = 128
    device = torch.device("cuda")

    q_fwd, k_fwd, v_fwd = tuple(
        torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        for _ in range(3)
    )

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # record_shapes=True,
        # profile_memory=True,
        # with_stack=True,
        # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiles/sdpa_benchmark'),
    ) as prof:
        with torch.profiler.record_function("forward_pass_benchmark"):
            avg_forward_time = benchmark_forward_pass(q_fwd, k_fwd, v_fwd, num_warmup=10, num_timed_runs=20)

        q_bwd, k_bwd, v_bwd = tuple(
            torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
            for _ in range(3)
        )
        with torch.profiler.record_function("backward_pass_benchmark"):
            avg_backward_time = benchmark_backward_pass(q_bwd, k_bwd, v_bwd, num_warmup=10, num_timed_runs=20)

    print(f"Average forward pass time: {avg_forward_time:.3f} ms")
    print(f"Average backward pass time: {avg_backward_time:.3f} ms")
    print(prof.key_averages().table(sort_by="cuda_time_total"))
some notes

torch.profiler.record_function('label') is an effective way to quickly check how long all of the kernels launched under a range take to execute. For example, according to the profiler output above, the SDPA forward kernel has a latency of ~0.77 ms and is called 30 times (10 warm-up + 20 timed runs) under the forward_pass_benchmark label, so 30 calls should take ~21 ms. The forward_pass_benchmark label has a CUDA time of 23 ms, which is close; the extra ~2 ms is probably because there are two device syncs under the label. These labels do not seem to be displayed accurately in the trace visualization, at least for this example.
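
As a rough sketch, the same label totals can also be pulled out of the profiler programmatically instead of read off the printed table, and the ranges can be exported to a trace for chrome://tracing or Perfetto. This assumes the same prof object from the script above; cuda_time_total is reported in microseconds, and the exact attribute naming may vary slightly across PyTorch versions.

# sketch: inspect the record_function labels programmatically
# (assumes the `prof` object produced by the script above)
for evt in prof.key_averages():
    if evt.key in ("forward_pass_benchmark", "backward_pass_benchmark"):
        # cuda_time_total is in microseconds
        print(f"{evt.key}: {evt.cuda_time_total / 1000:.2f} ms total over {evt.count} call(s)")

# export a trace that can be opened in chrome://tracing or Perfetto
prof.export_chrome_trace("sdpa_benchmark_trace.json")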