context parallel attention benchmarking
# torchrun --nproc-per-node 1 benchmark_sdpa.py
# torchrun --nproc-per-node 2 benchmark_sdpa.py
# torchrun --nproc-per-node 4 benchmark_sdpa.py
# torchrun --nproc-per-node 8 benchmark_sdpa.py
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options

_cp_options.enable_load_balance = False


def benchmark_forward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the forward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        _ = F.scaled_dot_product_attention(q, k, v)

    # torch.cuda.synchronize()  # ensure warm-up kernels complete
    dist.barrier()
    start_event.record()
    for _ in range(num_timed_runs):
        _ = F.scaled_dot_product_attention(q, k, v)
    dist.barrier()
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_timed_runs


def benchmark_backward_pass(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the backward pass of torch.nn.functional.scaled_dot_product_attention.
    """
    grad_output = torch.randn_like(q)
    output = F.scaled_dot_product_attention(q, k, v)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm-up steps
    for _ in range(num_warmup):
        # clear gradients from previous iterations
        q.grad = None
        k.grad = None
        v.grad = None
        # reuse the graph built by the forward pass above; retain_graph=True
        # keeps it alive so backward can be replayed every iteration
        output.backward(grad_output, retain_graph=True)

    # torch.cuda.synchronize()  # ensure all warm-up work (including backward) is complete
    dist.barrier()
    start_event.record()
    for _ in range(num_timed_runs):
        # clear gradients for this specific iteration
        q.grad = None
        k.grad = None
        v.grad = None
        output.backward(grad_output, retain_graph=True)
    dist.barrier()
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_timed_runs


def benchmark_qkv_gemm(q, k, v, num_warmup=10, num_timed_runs=20):
    """
    Benchmarks the qkv projection GEMM that precedes attention.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    dim = q.shape[-1]
    weight = torch.randn(dim, 3 * dim, device=q.device, dtype=torch.bfloat16)

    # warm-up steps
    for _ in range(num_warmup):
        _ = torch.matmul(q, weight)

    dist.barrier()
    start_event.record()
    for _ in range(num_timed_runs):
        _ = torch.matmul(q, weight)
    dist.barrier()
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_timed_runs


if __name__ == "__main__":
    batch_size = 1
    num_heads = 24
    seq_len = 4096 + 512
    # seq_len = 100_000
    head_dim = 128

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    device_mesh = dist.init_device_mesh("cuda", (world_size,))

    # SDPA forward under context parallelism: each rank gets its shard of the
    # sequence dimension
    q_fwd, k_fwd, v_fwd = tuple(
        torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        for _ in range(3)
    )
    with context_parallel(mesh=device_mesh) as cp:
        q_fwd_chunk = torch.chunk(q_fwd, world_size, dim=2)
        k_fwd_chunk = torch.chunk(k_fwd, world_size, dim=2)
        v_fwd_chunk = torch.chunk(v_fwd, world_size, dim=2)
        avg_forward_time = benchmark_forward_pass(
            q_fwd_chunk[rank], k_fwd_chunk[rank], v_fwd_chunk[rank], num_warmup=10, num_timed_runs=20
        )

    # SDPA backward: same setup, but with requires_grad=True so backward can run
    q_bwd, k_bwd, v_bwd = tuple(
        torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
        for _ in range(3)
    )
    with context_parallel(mesh=device_mesh) as cp:
        q_bwd_chunk = torch.chunk(q_bwd, world_size, dim=2)
        k_bwd_chunk = torch.chunk(k_bwd, world_size, dim=2)
        v_bwd_chunk = torch.chunk(v_bwd, world_size, dim=2)
        avg_backward_time = benchmark_backward_pass(
            q_bwd_chunk[rank], k_bwd_chunk[rank], v_bwd_chunk[rank], num_warmup=10, num_timed_runs=20
        )

    # qkv projection GEMM on this rank's shard; no communication is involved,
    # sharding the sequence just shrinks the GEMM
    q_fwd_chunk = torch.chunk(q_fwd, world_size, dim=2)
    k_fwd_chunk = torch.chunk(k_fwd, world_size, dim=2)
    v_fwd_chunk = torch.chunk(v_fwd, world_size, dim=2)
    qkv_gemm_time = benchmark_qkv_gemm(
        q_fwd_chunk[rank], k_fwd_chunk[rank], v_fwd_chunk[rank], num_warmup=10, num_timed_runs=20
    )

    # print results one rank at a time
    for r in range(world_size):
        if rank == r:
            print(f"--- Rank {r} ---")
            print(f"Average forward pass time: {avg_forward_time:.3f} ms")
            print(f"Average backward pass time: {avg_backward_time:.3f} ms")
            print(f"Average qkv gemm time: {qkv_gemm_time:.3f} ms")
            print("-----------------")
        dist.barrier()

    dist.destroy_process_group()
This script benchmarks the SDPA forward pass, the SDPA backward pass, and the qkv projection GEMM, with the sequence dimension sharded across GPUs via context parallelism.

At a sequence length of 4096, which is what we have when training Flux, this benchmark shows that context parallelism provides no speedup at all for any of these operations individually. I think this is because the tensors are small enough that all of the kernels are limited by memory latency, and in that regime adding more GPUs does not help, because every GPU experiences the same memory latency.

When you increase the sequence length to 100,000, you get near-linear speedups from context parallelism (a bit less than linear for the SDPA backward pass). I think this is because the tensors are now large enough that the kernels are compute bound, so parallelizing helps: in this regime, the more tensor cores you throw at the problem, the faster it gets done.
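To put rough numbers on that intuition, here is a back-of-the-envelope sketch (my own addition, not part of the gist) of the per-rank SDPA forward FLOPs under context parallelism, where the queries are sharded but each rank still attends over the full key/value sequence. The ~990 TFLOP/s bf16 peak is an assumed H100-class figure, used only to set the scale.

    # per-rank SDPA forward FLOPs: QK^T and P @ V are each ~2 * B * H * S_local * S_global * D
    B, H, D = 1, 24, 128      # matches the benchmark config above
    PEAK_TFLOPS = 990         # assumed bf16 peak of an H100-class GPU

    for S in (4096 + 512, 100_000):
        for world_size in (1, 2, 4, 8):
            s_local = S // world_size
            tflop = 4 * B * H * s_local * S * D / 1e12
            ideal_ms = tflop / PEAK_TFLOPS * 1e3
            print(f"S={S:>7} world_size={world_size}: {tflop:6.2f} TFLOP per rank, ~{ideal_ms:6.3f} ms at peak")

At ~4.6k tokens the entire single-GPU forward is only about a quarter of a TFLOP, a fraction of a millisecond at peak, so splitting it eight ways just shrinks an already tiny kernel whose runtime is dominated by fixed launch and memory-latency costs, plus the added ring communication. At 100k tokens each rank is still doing roughly 15 TFLOP even on 8 GPUs, i.e. on the order of 15 ms of math at peak, which is consistent with the near-linear speedups.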
In training, we see that context parallelism at a sequence length of 4096 provides a 1.66x speedup when going from 1 to 2 devices, but only marginal speedups beyond that. Given these profiling results, it is an open question why going from 1 to 2 GPUs gives us any speedup at all.
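One way to dig into that would be to convert the measured times into achieved TFLOP/s and compare against the hardware peak: kernels far below peak are overhead- or memory-bound, kernels near peak are compute bound. A minimal helper along those lines is sketched below; it is not part of the gist, achieved_sdpa_tflops is a hypothetical name, and it assumes non-causal attention where each rank's query shard attends to the full sequence.

    def achieved_sdpa_tflops(avg_ms, batch_size, num_heads, local_seq_len, global_seq_len, head_dim):
        """Rough achieved TFLOP/s of one rank's SDPA forward (ignores softmax FLOPs)."""
        flops = 4 * batch_size * num_heads * local_seq_len * global_seq_len * head_dim
        return flops / (avg_ms * 1e-3) / 1e12

    # e.g. after benchmark_forward_pass(...) in the script above:
    # tflops = achieved_sdpa_tflops(avg_forward_time, batch_size, num_heads,
    #                               seq_len // world_size, seq_len, head_dim)
    # print(f"SDPA forward achieved {tflops:.1f} TFLOP/s")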