
@alexarmbr
Created May 14, 2025 21:13
pytorch profiler
import torch
import torch.nn.functional as F
import torch.profiler


def benchmark_forward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the forward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        _ = F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()  # ensure warm-up kernels complete

    start_event.record()
    for _ in range(num_timed_runs):
        _ = F.scaled_dot_product_attention(q, k, v)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


def benchmark_backward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the backward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    grad_output = torch.randn_like(q)
    # run the forward pass once to build the autograd graph; every backward call
    # below reuses this graph, so retain_graph=True is required
    output = F.scaled_dot_product_attention(q, k, v)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        # clear gradients from previous iterations
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    torch.cuda.synchronize()  # ensure all warm-up backward kernels complete

    start_event.record()
    for _ in range(num_timed_runs):
        # clear gradients for this specific iteration
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_timed_runs


if __name__ == "__main__":
    batch_size = 1
    num_heads = 24
    seq_len = 4096 + 512
    head_dim = 128
    device = torch.device("cuda")

    q_fwd, k_fwd, v_fwd = tuple(
        torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        for _ in range(3)
    )

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # record_shapes=True,
        # profile_memory=True,
        # with_stack=True,
        # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiles/sdpa_benchmark'),
    ) as prof:
        with torch.profiler.record_function("forward_pass_benchmark"):
            avg_forward_time = benchmark_forward_pass(q_fwd, k_fwd, v_fwd, num_warmup=10, num_timed_runs=20)

        q_bwd, k_bwd, v_bwd = tuple(
            torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
            for _ in range(3)
        )
        with torch.profiler.record_function("backward_pass_benchmark"):
            avg_backward_time = benchmark_backward_pass(q_bwd, k_bwd, v_bwd, num_warmup=10, num_timed_runs=20)

    print(f"Average forward pass time: {avg_forward_time:.3f} ms")
    print(f"Average backward pass time: {avg_backward_time:.3f} ms")
    print(prof.key_averages().table(sort_by='cuda_time_total'))
@alexarmbr

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: ScaledDotProduc...         0.03%     191.202us         3.67%      20.488ms     682.921us       0.000us         0.00%      71.281ms       2.376ms            30  
                ScaledDotProductFlashAttentionBackward0         0.05%     292.994us         3.64%      20.296ms     676.548us       0.000us         0.00%      71.281ms       2.376ms            30  
     aten::_scaled_dot_product_flash_attention_backward         0.08%     441.323us         3.58%      20.003ms     666.781us       0.000us         0.00%      71.281ms       2.376ms            30  
                        aten::_flash_attention_backward         0.09%     526.104us         3.45%      19.267ms     642.219us      68.925ms        72.17%      71.281ms       2.376ms            30  
void pytorch_flash::flash_bwd_dq_dk_dv_loop_seqk_par...         0.00%       0.000us         0.00%       0.000us       0.000us      67.183ms        70.35%      67.183ms       2.239ms            30  
                                 forward_pass_benchmark         0.34%       1.884ms         5.45%      30.428ms      30.428ms       0.000us         0.00%      27.944ms      27.944ms             1  
                     aten::scaled_dot_product_attention         0.05%     259.845us         1.22%       6.801ms     219.379us       0.000us         0.00%      26.430ms     852.569us            31  
              aten::_scaled_dot_product_flash_attention         0.06%     319.763us         0.82%       4.600ms     148.393us       0.000us         0.00%      26.430ms     852.569us            31  
                         aten::_flash_attention_forward         0.05%     297.125us         0.73%       4.085ms     131.782us      24.075ms        25.21%      26.430ms     852.569us            31  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us      24.075ms        25.21%      24.075ms     776.621us            31  
                                 forward_pass_benchmark         0.00%       0.000us         0.00%       0.000us       0.000us      23.488ms        24.59%      23.488ms      23.488ms             1  
                                       aten::contiguous         0.02%      90.192us         2.81%      15.672ms     261.206us       0.000us         0.00%       2.356ms      39.261us            60  
                                            aten::clone         0.03%     141.941us         2.79%      15.582ms     259.702us       0.000us         0.00%       2.356ms      39.261us            60  
                                            aten::copy_         0.07%     363.206us         2.56%      14.288ms     238.134us       2.356ms         2.47%       2.356ms      39.261us            60  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.356ms         2.47%       2.356ms      39.261us            60  
                                   cudaFuncSetAttribute         0.02%      99.671us         0.48%       2.673ms      43.821us       0.000us         0.00%       2.354ms      38.597us            61  
                                   cudaEventElapsedTime         0.00%       7.320us         0.00%       7.320us       3.660us       2.297ms         2.41%       2.297ms       1.148ms             2  
                       Runtime Triggered Module Loading         3.17%      17.699ms         3.17%      17.699ms       2.950ms       1.570ms         1.64%       1.570ms     261.600us             6  
                                backward_pass_benchmark         0.00%       0.000us         0.00%       0.000us       0.000us       1.055ms         1.10%       1.055ms       1.055ms             1  
void pytorch_flash::flash_bwd_convert_dq_kernel<Flas...         0.00%       0.000us         0.00%       0.000us       0.000us     876.766us         0.92%     876.766us      29.226us            30  
void pytorch_flash::flash_bwd_dot_do_o_kernel<true, ...         0.00%       0.000us         0.00%       0.000us       0.000us     865.254us         0.91%     865.254us      28.842us            30  
                                backward_pass_benchmark        76.09%     424.763ms        90.70%     506.357ms     506.357ms       0.000us         0.00%     820.640us     820.640us             1  
                                  Lazy Function Loading         0.02%     135.503us         0.02%     135.503us      27.101us     784.801us         0.82%     784.801us     156.960us             5  
                                          aten::normal_         0.01%      53.182us         0.02%     108.612us      27.153us     148.735us         0.16%     148.735us      37.184us             4  
void at::native::(anonymous namespace)::distribution...         0.00%       0.000us         0.00%       0.000us       0.000us     148.735us         0.16%     148.735us      37.184us             4  
                                            aten::randn         0.02%      87.702us         0.10%     585.340us     195.113us       0.000us         0.00%     111.103us      37.034us             3  
                                       aten::randn_like         0.00%       4.370us         0.03%     193.723us     193.723us       0.000us         0.00%      37.632us      37.632us             1  
                                           Unrecognized         0.35%       1.941ms         0.35%       1.941ms       1.941ms       0.000us         0.00%       0.000us       0.000us             1  
                                        aten::transpose         0.06%     324.438us         0.09%     490.722us       1.348us       0.000us         0.00%       0.000us       0.000us           364  
                                       aten::as_strided         0.03%     166.284us         0.03%     166.284us       0.457us       0.000us         0.00%       0.000us       0.000us           364  
                                       aten::empty_like         0.03%     167.011us         0.48%       2.679ms      14.721us       0.000us         0.00%       0.000us       0.000us           182  
                                    aten::empty_strided         0.07%     412.196us         0.25%       1.414ms      11.590us       0.000us         0.00%       0.000us       0.000us           122  
                                  cudaStreamIsCapturing         0.01%      36.280us         0.01%      36.280us       1.814us       0.000us         0.00%       0.000us       0.000us            20  
                                             cudaMalloc         0.43%       2.428ms         0.43%       2.428ms     202.344us       0.000us         0.00%       0.000us       0.000us            12  
                                            aten::empty         0.13%     727.249us         0.39%       2.185ms       8.847us       0.000us         0.00%       0.000us       0.000us           247  
                                 cudaDeviceGetAttribute         0.00%      22.820us         0.00%      22.820us       0.248us       0.000us         0.00%       0.000us       0.000us            92  
                                       cudaLaunchKernel         2.88%      16.100ms         2.88%      16.100ms      87.029us       0.000us         0.00%       0.000us       0.000us           185  
                                  cudaDeviceSynchronize        15.74%      87.849ms        15.74%      87.849ms      17.570ms       0.000us         0.00%       0.000us       0.000us             5  
                                        cudaEventRecord         0.01%      31.311us         0.01%      31.311us       7.828us       0.000us         0.00%       0.000us       0.000us             4  
autograd::engine::evaluate_function: torch::autograd...         0.03%     178.343us         0.07%     401.077us       4.456us       0.000us         0.00%       0.000us       0.000us            90  
                        torch::autograd::AccumulateGrad         0.02%     116.403us         0.04%     222.734us       2.475us       0.000us         0.00%       0.000us       0.000us            90  
                                           aten::detach         0.01%      46.261us         0.02%     106.331us       1.181us       0.000us         0.00%       0.000us       0.000us            90  
                                                 detach         0.01%      60.070us         0.01%      60.070us       0.667us       0.000us         0.00%       0.000us       0.000us            90  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------ 


alexarmbr commented May 14, 2025

some notes

  • Self CPU / Self CUDA - time spent in this layer of the call stack only, not including descendants
  • CPU total / CUDA total - time spent at this layer of the call stack plus everything called below it (see the sketch after this list)
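
for example, here is a minimal sketch of reading those two columns programmatically from the prof object in the script above (the attribute names self_cuda_time_total and cuda_time_total come from torch.autograd.profiler's aggregated events, are reported in microseconds, and may differ slightly across PyTorch versions):

# assumes `prof` is the torch.profiler.profile object from the script above
for evt in prof.key_averages():
    # self_cuda_time_total: time spent in this op's own kernels only
    # cuda_time_total: time including everything launched below it
    print(
        f"{evt.key[:55]:55s} "
        f"self CUDA: {evt.self_cuda_time_total / 1e3:9.3f} ms  "
        f"CUDA total: {evt.cuda_time_total / 1e3:9.3f} ms  "
        f"calls: {evt.count}"
    )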

the torch.profiler.record_function('label') context manager is an effective way to quickly check how long all of the kernels launched under a range take to execute. For example, according to the table above, the SDPA forward kernel (flash_fwd_kernel) has an average latency of ~0.78ms. It is called 30 times under the forward_pass_benchmark label (the 31st call in the table is the single forward pass that benchmark_backward_pass runs to build the autograd graph), so those 30 calls should take about 23.3ms. The forward_pass_benchmark label has a CUDA time of ~23.5ms, which is close; the small remaining gap is probably from the two device syncs and other housekeeping under the label.
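
a rough sketch of doing that cross-check programmatically (it assumes prof is still in scope and that the label and kernel names match the rows in the table above):

# print the aggregated CUDA time of the record_function label next to the
# flash attention forward kernel it wraps (profiler times are in microseconds)
for evt in prof.key_averages():
    if evt.key == "forward_pass_benchmark" or "flash_fwd_kernel" in evt.key:
        print(f"{evt.key[:55]:55s} CUDA total: {evt.cuda_time_total / 1e3:9.3f} ms  calls: {evt.count}")

note that record_function labels can show up in more than one row of key_averages (which is presumably why forward_pass_benchmark appears twice in the table above), so this prints every matching row rather than picking one.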

these labels do not seem to be displayed accurately in the trace visualization, at least for this example
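
one way to check is to export the trace and open it in chrome://tracing or https://ui.perfetto.dev (a minimal sketch; the output filename is arbitrary):

# call this after the `with torch.profiler.profile(...) as prof:` block exits
prof.export_chrome_trace("sdpa_benchmark_trace.json")
# then load the JSON and look for the forward_pass_benchmark /
# backward_pass_benchmark ranges in the timeline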
