nccl_test
import os

import torch
import torch.distributed as dist


def benchmark_all_reduce():
    # Read the distributed environment set by the launcher (e.g. torchrun)
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize the NCCL backend (automatically uses RDMA when available)
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Benchmark parameters: 2**20 to 2**30 elements (4 MB to 4 GB of float32)
    data_sizes = [2**i for i in range(20, 31)]
    num_iters = 100
    warmup_iters = 10

    # Print the header only once, from global rank 0
    if rank == 0:
        print(f"{'Size (MB)':<12} {'Bandwidth (GB/s)':<15}", flush=True)

    # Pre-allocate one tensor at the maximum size and benchmark views of it
    max_size = max(data_sizes)
    data = torch.rand(max_size, dtype=torch.float32, device="cuda")
    element_size = data.element_size()

    for size in data_sizes:
        data_view = data[:size]

        # Warmup phase
        for _ in range(warmup_iters):
            dist.all_reduce(data_view, op=dist.ReduceOp.SUM)
        torch.cuda.synchronize()

        # Accurate GPU timing with CUDA events
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()
        for _ in range(num_iters):
            dist.all_reduce(data_view, op=dist.ReduceOp.SUM)
        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        elapsed_time = elapsed_time_ms / 1000.0  # seconds

        # Bus bandwidth, following the nccl-tests convention:
        # an all-reduce moves 2 * (N - 1) / N times the buffer size per rank
        total_bytes = size * element_size
        total_data_transferred = (
            2 * (world_size - 1) / world_size * total_bytes * num_iters
        )
        bandwidth = total_data_transferred / elapsed_time / (1024**3)  # GB/s

        if rank == 0:
            size_mb = total_bytes / (1024**2)
            print(f"{size_mb:<12.2f} {bandwidth:<15.2f}", flush=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    benchmark_all_reduce()
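The reported number follows the "bus bandwidth" convention used by NVIDIA's nccl-tests: a ring all-reduce makes each rank send and receive 2 * (N - 1) / N times the buffer size, so that factor is applied to the payload before dividing by the elapsed time. Below is a minimal worked example of the same arithmetic; the world size, buffer size, and elapsed time are made-up illustrative values, not output from this script.

# Worked example of the bus-bandwidth formula used above.
# All numbers here are illustrative assumptions, not measured results.
world_size = 8                      # assumed number of GPUs
buffer_bytes = 2**30 * 4            # 2**30 float32 elements = 4 GiB
num_iters = 100
elapsed_seconds = 2.0               # assumed total time for all iterations

# nccl-tests "busbw": each rank moves 2 * (N - 1) / N times the buffer size
total_data = 2 * (world_size - 1) / world_size * buffer_bytes * num_iters
bus_bandwidth = total_data / elapsed_seconds / (1024**3)  # GiB/s

print(f"bus bandwidth: {bus_bandwidth:.2f} GiB/s")  # ~350.00 GiB/s for these numbers

To run the benchmark itself, the script expects the RANK, LOCAL_RANK, and WORLD_SIZE environment variables that torchrun sets, e.g. torchrun --nproc_per_node=8 nccl_test.py on a single node (the filename and GPU count here are assumptions).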