

@Bowser1704
Last active July 31, 2025 09:57
nccl_test
import os

import torch
import torch.distributed as dist


def benchmark_all_reduce():
    # Read the distributed environment set by the launcher (e.g. torchrun)
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize the NCCL backend (automatically uses RDMA/InfiniBand when available)
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    # Benchmark parameters
    data_sizes = [2**i for i in range(20, 31)]  # 1M to 1G float32 elements (4 MB to 4 GB)
    num_iters = 100
    warmup_iters = 10

    if rank == 0:
        print(f"{'Size (MB)':<12} {'Bandwidth (GB/s)':<15}", flush=True)

    # Pre-allocate one tensor at the maximum size and benchmark slices of it
    max_size = max(data_sizes)
    data = torch.rand(max_size, dtype=torch.float32, device="cuda")
    element_size = data.element_size()
    for size in data_sizes:
        data_view = data[:size]

        # Warmup phase
        for _ in range(warmup_iters):
            dist.all_reduce(data_view, op=dist.ReduceOp.SUM)
        torch.cuda.synchronize()

        # Accurate GPU timing with CUDA events
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize()
        start_event.record()
        for _ in range(num_iters):
            dist.all_reduce(data_view, op=dist.ReduceOp.SUM)
        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        elapsed_time = elapsed_time_ms / 1000.0  # seconds
        # Ring all-reduce bus bandwidth: each rank moves
        # 2 * (world_size - 1) / world_size times the buffer size per iteration
        total_bytes = size * element_size
        total_data_transferred = (
            2 * (world_size - 1) / world_size * total_bytes * num_iters
        )
        bandwidth = total_data_transferred / elapsed_time / (1024**3)  # GB/s

        if rank == 0:
            size_mb = total_bytes / (1024**2)
            print(f"{size_mb:<12.2f} {bandwidth:<15.2f}", flush=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    benchmark_all_reduce()
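
To run the benchmark, launch one process per GPU with torchrun (a minimal sketch: the filename nccl_test.py, the hostname, port, node count, and GPU count below are placeholders to adapt to your cluster), for example:

torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 --master_addr=<MASTER_HOST> --master_port=29500 nccl_test.py

Setting NCCL_DEBUG=INFO in the environment makes NCCL log which transport it actually selected (IB/RDMA vs. TCP sockets), which helps confirm whether RDMA is in use.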