|
""" |
|
Minimal reproducer showing that delaying the replicate-dimension all-reduce to |
|
only the final microbatch improves throughput under HSDP with inhomogeneous |
|
microbatch compute costs. Uses variable batch sizes as a stand-in for more realistic |
|
cases, e.g. varying attention costs in packed sequences. |
|
|
|
Usage: |
|
|
|
# HSDP |
|
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4 |
|
# HSDP with delayed all-reduce |
|
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4 --delay-all-reduce |
|
# FSDP over the world |
|
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4 --fsdp-only |
|
""" |
|
|
|
import argparse |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
from torch.distributed._composable.fsdp import fully_shard |
|
from torch.distributed.device_mesh import init_device_mesh |
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated feed-forward block built from three bias-free linear projections.

    The output is ``down_proj(silu(gate_proj(x)) * up_proj(x))``.
    """

    def __init__(self, hidden_dim: int, intermediate_dim: int):
        """Create the gate/up projections (hidden -> intermediate) and the
        down projection (intermediate -> hidden), all without bias."""
        super().__init__()
        # Creation order is kept as gate, up, down so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map it back to ``hidden_dim``."""
        gate = nn.functional.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
|
|
|
|
|
class FFNStack(nn.Module):
    """A sequence of ``FeedForward`` blocks, each applied with a residual add."""

    def __init__(self, num_layers: int, hidden_dim: int, intermediate_dim: int):
        """Build ``num_layers`` identical-shaped ``FeedForward`` blocks."""
        super().__init__()
        self.layers = nn.ModuleList(
            FeedForward(hidden_dim, intermediate_dim) for _ in range(num_layers)
        )

    def forward(self, x):
        """Run ``x`` through every block in order, adding each block's output
        back onto its input."""
        hidden = x
        for block in self.layers:
            hidden = hidden + block(hidden)
        return hidden
|
|
|
|
|
def get_batch_size(
    microbatch_idx: int, group_idx: int, *, sync: bool, ratio: int
) -> int:
    """Return the batch size for one microbatch: either ``1`` or ``ratio``.

    Sizes alternate with the parity of the microbatch index. When ``sync`` is
    true every group follows the same schedule (even index -> 1, odd index ->
    ``ratio``). Otherwise the schedule is shifted by ``group_idx``, so groups
    of different parity see opposite sizes on the same microbatch.
    """
    phase = microbatch_idx if sync else group_idx + microbatch_idx
    if phase % 2:
        return ratio
    return 1
|
|
|
|
|
def main():
    """Run the gradient-accumulation benchmark and print throughput on rank 0.

    Steps:
      1. Parse command-line flags and validate them.
      2. Initialise the NCCL process group and build a device mesh. The mesh
         shapes are hard-coded for exactly 4 ranks: a 1-D ``(4,)`` "shard"
         mesh with ``--fsdp-only``, otherwise a ``(2, 2)``
         "replicate" x "shard" mesh.
      3. Wrap every layer, then the whole model, with ``fully_shard``.
      4. Run ``warmup_steps + measured_steps`` steps. Each step zeroes the
         gradients and runs ``accum_steps`` forward/backward microbatches on
         random input whose batch size comes from ``get_batch_size``. With
         ``--delay-all-reduce``, ``set_requires_all_reduce`` is turned on only
         for the last microbatch of each step.
      5. Time the measured steps with CUDA events, sum the processed token
         count across ranks, and print mean/std step time and tokens/sec/GPU
         on rank 0.

    Raises:
        ValueError: if ``--delay-all-reduce`` is combined with ``--fsdp-only``,
            if ``--measured-steps`` is less than 1, or if ``--warmup-steps``
            is negative.
    """
    import os  # local import: only needed to read torchrun's LOCAL_RANK

    parser = argparse.ArgumentParser()
    parser.add_argument("--accum-steps", type=int, default=2)
    parser.add_argument("--batch-size-ratio", type=int, default=2)
    parser.add_argument("--delay-all-reduce", action="store_true")
    parser.add_argument("--sync-batch-sizes", action="store_true")
    parser.add_argument("--fsdp-only", action="store_true")
    parser.add_argument("--seqlen", type=int, default=4096)
    parser.add_argument("--num-layers", type=int, default=8)
    parser.add_argument("--warmup-steps", type=int, default=5)
    parser.add_argument("--measured-steps", type=int, default=10)
    # Llama 8B defaults:
    parser.add_argument("--hidden-dim", type=int, default=4096)
    parser.add_argument("--intermediate-dim", type=int, default=14336)
    args = parser.parse_args()

    if args.delay_all_reduce and args.fsdp_only:
        raise ValueError("--delay-all-reduce is incompatible with --fsdp-only")
    # With no measured steps the total measured time is zero and the
    # throughput division fails; reject it before any distributed setup.
    if args.measured_steps < 1:
        raise ValueError("--measured-steps must be at least 1")
    # A negative warmup count would shift measure_idx so that some timing
    # events are never recorded.
    if args.warmup_steps < 0:
        raise ValueError("--warmup-steps must be non-negative")

    try:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        # Select the GPU by the per-node rank that torchrun exports; the
        # global rank is only a valid device index on a single node.
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))

        if args.fsdp_only:
            mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("shard",))
            # Mirror the replicate-group index of the (2, 2) HSDP mesh so the
            # batch-size schedule is the same in both modes.
            group_idx = rank // 2
        else:
            mesh = init_device_mesh(
                "cuda", (2, 2), mesh_dim_names=("replicate", "shard")
            )
            group_idx = mesh.get_local_rank("replicate")

        model = FFNStack(args.num_layers, args.hidden_dim, args.intermediate_dim).cuda()
        for layer in model.layers:
            fully_shard(layer, mesh=mesh)
        fully_shard(model, mesh=mesh)

        warmup_steps = args.warmup_steps
        measured_steps = args.measured_steps
        total_steps = warmup_steps + measured_steps

        start_events = [
            torch.cuda.Event(enable_timing=True) for _ in range(measured_steps)
        ]
        end_events = [
            torch.cuda.Event(enable_timing=True) for _ in range(measured_steps)
        ]

        total_examples = 0

        for step in range(total_steps):
            # Negative while warming up; indexes the timing events afterwards.
            measure_idx = step - warmup_steps
            if measure_idx >= 0:
                start_events[measure_idx].record()

            model.zero_grad()
            for i in range(args.accum_steps):
                if args.delay_all_reduce:
                    # Enable the all-reduce only on the final microbatch.
                    model.set_requires_all_reduce(i == args.accum_steps - 1)

                bs = get_batch_size(
                    i,
                    group_idx,
                    sync=args.sync_batch_sizes,
                    ratio=args.batch_size_ratio,
                )
                if measure_idx >= 0:
                    total_examples += bs
                x = torch.randn(bs, args.seqlen, args.hidden_dim, device="cuda")
                model(x).sum().backward()

            if measure_idx >= 0:
                end_events[measure_idx].record()

        # All recorded events must have completed before reading elapsed times.
        torch.cuda.synchronize()

        step_times_ms = torch.tensor(
            [
                start.elapsed_time(end)
                for start, end in zip(start_events, end_events)
            ]
        )
        mean_ms = step_times_ms.mean().item()
        std_ms = step_times_ms.std().item()

        total_time_sec = step_times_ms.sum().item() / 1000.0
        total_tokens = torch.tensor(
            total_examples * args.seqlen, device="cuda", dtype=torch.long
        )
        # Sum the token counts of all ranks (default all_reduce op).
        dist.all_reduce(total_tokens)
        tok_per_sec_per_gpu = (
            total_tokens.item() / total_time_sec / dist.get_world_size()
        )

        if rank == 0:
            if args.fsdp_only:
                mode = "FSDP"
            elif args.delay_all_reduce:
                mode = "HSDP(delayed AR)"
            else:
                mode = "HSDP"
            print(
                f"{mode} | "
                f"accum_steps: {args.accum_steps}, batch_size_ratio: {args.batch_size_ratio} | "
                f"step: {mean_ms:.1f} ± {std_ms:.1f} ms | "
                f"tok/sec/gpu: {tok_per_sec_per_gpu:.0f} | "
            )

    finally:
        # Only tear down a group that was actually created; otherwise a failed
        # init_process_group would be masked by an error from the teardown.
        if dist.is_initialized():
            dist.destroy_process_group()
|
|
|
|
|
# Run the benchmark only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()