@alexarmbr
Created May 14, 2025 21:13
pytorch profiler
import torch
import torch.nn.functional as F
import torch.profiler


def benchmark_forward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the forward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        _ = F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()  # ensure warm-up kernels complete

    start_event.record()
    for _ in range(num_timed_runs):
        _ = F.scaled_dot_product_attention(q, k, v)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


def benchmark_backward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the backward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    grad_output = torch.randn_like(q)
    # single forward pass; the same graph is reused for every backward call below,
    # which is why retain_graph=True is required
    output = F.scaled_dot_product_attention(q, k, v)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        # clear gradients from previous iterations
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    torch.cuda.synchronize()  # ensure all warm-up backward kernels are complete

    start_event.record()
    for _ in range(num_timed_runs):
        # clear gradients for this specific iteration
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


if __name__ == "__main__":
    batch_size = 1
    num_heads = 24
    seq_len = 4096 + 512
    head_dim = 128
    device = torch.device("cuda")

    q_fwd, k_fwd, v_fwd = tuple(
        torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        for _ in range(3)
    )

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # record_shapes=True,
        # profile_memory=True,
        # with_stack=True,
        # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiles/sdpa_benchmark'),
    ) as prof:
        with torch.profiler.record_function("forward_pass_benchmark"):
            avg_forward_time = benchmark_forward_pass(q_fwd, k_fwd, v_fwd, num_warmup=10, num_timed_runs=20)

        q_bwd, k_bwd, v_bwd = tuple(
            torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
            for _ in range(3)
        )
        with torch.profiler.record_function("backward_pass_benchmark"):
            avg_backward_time = benchmark_backward_pass(q_bwd, k_bwd, v_bwd, num_warmup=10, num_timed_runs=20)

    print(f"Average forward pass time: {avg_forward_time:.3f} ms")
    print(f"Average backward pass time: {avg_backward_time:.3f} ms")
    print(prof.key_averages().table(sort_by='cuda_time_total'))
@alexarmbr (Author) commented May 14, 2025

Some notes:

  • self cpu/gpu time - time spent in this layer of the call stack only, not including descendants
  • cpu/gpu time - time spent at this layer of the call stack and everything below it (see the sketch after this list)

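As a quick illustration of the difference, here is a minimal sketch that reuses the prof object from the script above: sorting by self CUDA time surfaces the individual kernels, while sorting by total CUDA time surfaces the enclosing ranges whose totals include everything launched beneath them.

# minimal sketch, reusing `prof` from the script above
# kernels/ops themselves: time spent in each op, excluding children
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
# enclosing ranges and ops: time including everything launched beneath them
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
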
torch.profiler.record_function('label') is an effective way to quickly check how long all of the kernels launched under a range take to execute. For example, according to the table above, the SDPA forward kernel has a latency of ~0.77 ms and is called 30 times under the forward_pass_benchmark label, so those calls should take roughly 21 ms in total. The forward_pass_benchmark label reports a CUDA time of 23 ms, which is close; the extra ~2 ms is probably due to the two device syncs under the label.
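The same check can be done programmatically instead of eyeballing the table. A minimal sketch, assuming the averaged event's key matches the record_function label exactly and that cuda_time_total (reported in microseconds; newer PyTorch releases also expose it as device_time_total) is available on the averaged events:

# rough sketch: compare the labelled range against the SDPA ops recorded underneath it
events = prof.key_averages()
fwd_range = next(e for e in events if e.key == "forward_pass_benchmark")
sdpa_events = [e for e in events if "scaled_dot_product" in e.key]
print(f"forward_pass_benchmark cuda time: {fwd_range.cuda_time_total / 1e3:.2f} ms")
for e in sdpa_events:
    print(f"{e.key}: {e.count} calls, {e.cuda_time_total / 1e3:.2f} ms total")
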

These labels do not seem to be displayed accurately in the trace visualization, at least for this example.
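
One way to double check is to export the trace and inspect it directly; a minimal sketch using export_chrome_trace (called after the profiler context has exited), which writes a JSON trace that can be opened in chrome://tracing or https://ui.perfetto.dev:

prof.export_chrome_trace("sdpa_benchmark_trace.json")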
