@alexarmbr
Created May 14, 2025 22:09
context parallel attention benchmarking
# torchrun --nproc-per-node 1 benchmark_sdpa.py
# torchrun --nproc-per-node 2 benchmark_sdpa.py
# torchrun --nproc-per-node 4 benchmark_sdpa.py
# torchrun --nproc-per-node 8 benchmark_sdpa.py
import torch
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options
import torch.distributed as dist
_cp_options.enable_load_balance = False
def benchmark_forward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
"""
Benchmarks the forward pass of torch.nn.functional.scaled_dot_product_attention.
"""
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# warm-up steps
for _ in range(num_warmup):
_ = F.scaled_dot_product_attention(q, k, v)
# torch.cuda.synchronize() # Ensure warm-up kernels complete
dist.barrier()
start_event.record()
for _ in range(num_timed_runs):
_ = F.scaled_dot_product_attention(q, k, v)
dist.barrier()
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_timed_runs
def benchmark_backward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
"""
Benchmarks the backward pass of torch.nn.functional.scaled_dot_product_attention.
"""
grad_output = torch.randn_like(q)
output = F.scaled_dot_product_attention(q, k, v)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Warm-up steps
for _ in range(num_warmup):
# Clear gradients from previous iterations
q.grad = None
k.grad = None
v.grad = None
# Perform the forward pass to set up the graph for backward
output.backward(grad_output, retain_graph=True) # No need to retain graph if output is fresh
# torch.cuda.synchronize() # Ensure all warm-up operations (including backward) are complete
dist.barrier()
start_event.record()
for _ in range(num_timed_runs):
# Clear gradients for this specific iteration
q.grad = None
k.grad = None
v.grad = None
output.backward(grad_output, retain_graph=True)
dist.barrier()
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_timed_runs
def benchmark_qkv_gemm(q, k, v, num_warmup=10, num_timed_runs=20):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    dim = q.shape[-1]
    weight = torch.randn(dim, 3 * dim, device=q.device, dtype=torch.bfloat16)
    for _ in range(num_warmup):
        _ = torch.matmul(q, weight)
    dist.barrier()
    start_event.record()
    for _ in range(num_timed_runs):
        _ = torch.matmul(q, weight)
    dist.barrier()
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_timed_runs

if __name__ == "__main__":
    batch_size = 1
    num_heads = 24
    seq_len = 4096 + 512
    # seq_len = 100_000
    head_dim = 128

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    device_mesh = dist.init_device_mesh("cuda", (world_size,))

    q_fwd, k_fwd, v_fwd = tuple(torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16) for _ in range(3))
    # inside the context_parallel context, F.scaled_dot_product_attention is dispatched to
    # the context-parallel (ring attention) implementation across the device mesh
    with context_parallel(mesh=device_mesh) as cp:
        q_fwd_chunk = torch.chunk(q_fwd, world_size, dim=2)
        k_fwd_chunk = torch.chunk(k_fwd, world_size, dim=2)
        v_fwd_chunk = torch.chunk(v_fwd, world_size, dim=2)
        avg_forward_time = benchmark_forward_pass(q_fwd_chunk[rank], k_fwd_chunk[rank], v_fwd_chunk[rank], num_warmup=10, num_timed_runs=20)

    q_bwd, k_bwd, v_bwd = tuple(torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True) for _ in range(3))
    with context_parallel(mesh=device_mesh) as cp:
        q_bwd_chunk = torch.chunk(q_bwd, world_size, dim=2)
        k_bwd_chunk = torch.chunk(k_bwd, world_size, dim=2)
        v_bwd_chunk = torch.chunk(v_bwd, world_size, dim=2)
        avg_backward_time = benchmark_backward_pass(q_bwd_chunk[rank], k_bwd_chunk[rank], v_bwd_chunk[rank], num_warmup=10, num_timed_runs=20)

    q_fwd_chunk = torch.chunk(q_fwd, world_size, dim=2)
    k_fwd_chunk = torch.chunk(k_fwd, world_size, dim=2)
    v_fwd_chunk = torch.chunk(v_fwd, world_size, dim=2)
    qkv_gemm_time = benchmark_qkv_gemm(q_fwd_chunk[rank], k_fwd_chunk[rank], v_fwd_chunk[rank], num_warmup=10, num_timed_runs=20)

    # print results one rank at a time
    for r in range(world_size):
        if rank == r:
            print(f"--- Rank {r} ---")
            print(f"Average forward pass time: {avg_forward_time:.3f} ms")
            print(f"Average backward pass time: {avg_backward_time:.3f} ms")
            print(f"Average qkv gemm time: {qkv_gemm_time:.3f} ms")
            print("-----------------")
        dist.barrier()
    dist.destroy_process_group()
@alexarmbr (Author)

This script benchmarks

  • SDPA forward
  • SDPA backward
  • the qkv projection that happens before attention (a gemm with M,N,K = sequence length, model dim * 3, model dim)

at a sequence length of 4096, which is what we have when training Flux. According to this benchmark, context parallelism provides no speedup at all for any of these operations individually. I think this is because the tensors are small enough that all the kernels are limited by memory latency; in that regime adding more GPUs does not help, because every GPU experiences the same memory latency.
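
For a rough sense of scale, here is a back-of-the-envelope sketch (my own arithmetic, not output from the script; the ~1 PFLOP/s bf16 peak is an assumption for an H100-class GPU) of how little math SDPA actually does at this shape:

# SDPA forward is roughly two batched matmuls, ~4 * B * H * S^2 * D FLOPs.
B, H, S, D = 1, 24, 4096, 128
flops = 4 * B * H * S * S * D        # ~2.1e11 FLOPs in total
qkv_bytes = 3 * B * H * S * D * 2    # ~75 MB of bf16 Q/K/V
peak_flops = 1e15                    # assumed ~1 PFLOP/s dense bf16 (H100-class)
print(f"{flops / 1e9:.0f} GFLOPs -> ~{flops / peak_flops * 1e3:.2f} ms of math at peak")
print(f"{qkv_bytes / 1e6:.0f} MB of Q/K/V")
# roughly 0.2 ms of math in total, so kernel launch and memory latency make up a large
# fraction of the runtime, and dividing the math across 2-8 ranks barely moves the needle.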

When you increase the sequence length to 100,000 you get near-linear speedups from context parallelism (a bit less than linear for SDPA backward). I think this is because the tensors are large enough that the kernels are compute bound, so parallelizing helps: in this regime, the more tensor cores you throw at the problem, the faster it gets done.
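
Continuing the same back-of-the-envelope arithmetic (again my own estimate under the same assumed ~1 PFLOP/s peak, not a measurement), attention FLOPs grow quadratically with sequence length, so at 100,000 tokens the per-rank math dwarfs the fixed overheads:

B, H, D = 1, 24, 128
peak_flops = 1e15  # assumed ~1 PFLOP/s dense bf16 (H100-class)
for S in (4096, 100_000):
    flops = 4 * B * H * S * S * D
    for n_gpus in (1, 8):
        ms = flops / n_gpus / peak_flops * 1e3
        print(f"S={S:>7}, {n_gpus} GPU(s): ~{ms:7.2f} ms of math per rank at peak")
# at S=100k each rank has ~15-120 ms of math, so launch latency and ring send/recv
# are negligible and the speedup is close to linear; at S=4096 the whole problem is
# ~0.2 ms and those fixed overheads dominate.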

In training we see that context parallelism at a sequence length of 4096 provides a 1.66x speedup when going from 1 to 2 devices, but marginal speedups after that. Given these profiling results, it is an open question why going from 1 to 2 GPUs gives us any speedup at all.
