Skip to content

Instantly share code, notes, and snippets.

@garrett361
Last active May 27, 2026 14:35
Show Gist options
  • Select an option

  • Save garrett361/8cf4f612eec730fe66881a21fd9449d2 to your computer and use it in GitHub Desktop.

Select an option

Save garrett361/8cf4f612eec730fe66881a21fd9449d2 to your computer and use it in GitHub Desktop.
HSDP for inhomogeneity
"""
Minimal reproducer showing that delaying the replicate-dimension all-reduce to
only the final microbatch improves throughput under HSDP with inhomogeneous
microbatch compute costs. Uses variable batch sizes as a stand-in for more realistic
cases, e.g. varying attention costs in packed sequences.
Usage:
# HSDP
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4
# HSDP with delayed all-reduce
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4 --delay-all-reduce
# FSDP over the world
torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 4 --fsdp-only
"""
import argparse
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
class FeedForward(nn.Module):
    """Gated feed-forward block built from three bias-free linear layers.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Args:
        hidden_dim: Size of the input and output feature dimension.
        intermediate_dim: Size of the inner (gated) feature dimension.
    """

    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        # Layers are created in gate -> up -> down order so that parameter
        # registration (and therefore state-dict key order) stays the same.
        self.gate_proj = nn.Linear(
            in_features=hidden_dim, out_features=intermediate_dim, bias=False
        )
        self.up_proj = nn.Linear(
            in_features=hidden_dim, out_features=intermediate_dim, bias=False
        )
        self.down_proj = nn.Linear(
            in_features=intermediate_dim, out_features=hidden_dim, bias=False
        )

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to ``hidden_dim``."""
        gate = nn.functional.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class FFNStack(nn.Module):
    """A sequence of ``FeedForward`` blocks, each wrapped in a residual add.

    Args:
        num_layers: Number of ``FeedForward`` blocks to stack.
        hidden_dim: Feature dimension passed to every block.
        intermediate_dim: Inner dimension passed to every block.
    """

    def __init__(self, num_layers: int, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(FeedForward(hidden_dim, intermediate_dim))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``x`` after applying ``h = h + layer(h)`` for each layer in order."""
        hidden = x
        for block in self.layers:
            # Out-of-place add: the caller's tensor must not be mutated.
            hidden = hidden + block(hidden)
        return hidden
def get_batch_size(
    microbatch_idx: int, group_idx: int, *, sync: bool, ratio: int
) -> int:
    """Pick the batch size (``1`` or ``ratio``) for one microbatch.

    Sizes alternate between ``1`` and ``ratio`` across microbatches. With
    ``sync`` set, every group follows the same pattern (even microbatch
    index -> ``1``). Otherwise the pattern is shifted by ``group_idx``, so
    groups of differing parity get different sizes on the same microbatch.

    Args:
        microbatch_idx: Index of the microbatch within the accumulation loop.
        group_idx: Index of the group this rank belongs to.
        sync: If true, ignore ``group_idx`` so all groups use equal sizes.
        ratio: Batch size used for the "large" microbatches.

    Returns:
        ``1`` when the (possibly shifted) index is even, else ``ratio``.
    """
    shift = 0 if sync else group_idx
    is_small = (microbatch_idx + shift) % 2 == 0
    if is_small:
        return 1
    return ratio
def main():
    """Benchmark gradient accumulation with uneven microbatches under FSDP/HSDP.

    Parses the command-line flags, builds an ``FFNStack`` sharded with
    ``fully_shard`` over either a 1-D ``("shard",)`` mesh (``--fsdp-only``) or
    a 2-D ``("replicate", "shard")`` mesh with a replicate dimension of 2, and
    times ``--measured-steps`` accumulation steps (after ``--warmup-steps``
    untimed ones) with CUDA events. Rank 0 prints the mean/std step time and
    the tokens/sec/GPU figure.

    Raises:
        ValueError: If ``--delay-all-reduce`` is combined with ``--fsdp-only``,
            if a step-count flag is out of range, or if the world size is not
            an even number >= 2.
    """
    # Function-scope import keeps this block self-contained.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--accum-steps", type=int, default=2)
    parser.add_argument("--batch-size-ratio", type=int, default=2)
    parser.add_argument("--delay-all-reduce", action="store_true")
    parser.add_argument("--sync-batch-sizes", action="store_true")
    parser.add_argument("--fsdp-only", action="store_true")
    parser.add_argument("--seqlen", type=int, default=4096)
    parser.add_argument("--num-layers", type=int, default=8)
    parser.add_argument("--warmup-steps", type=int, default=5)
    parser.add_argument("--measured-steps", type=int, default=10)
    # Llama 8B defaults:
    parser.add_argument("--hidden-dim", type=int, default=4096)
    parser.add_argument("--intermediate-dim", type=int, default=14336)
    args = parser.parse_args()

    # Validate everything before touching the process group so bad flags fail
    # fast instead of after a full (or partial) benchmark run.
    if args.delay_all_reduce and args.fsdp_only:
        raise ValueError("--delay-all-reduce is incompatible with --fsdp-only")
    if args.accum_steps < 1:
        raise ValueError("--accum-steps must be >= 1")
    if args.warmup_steps < 0:
        raise ValueError("--warmup-steps must be >= 0")
    if args.measured_steps < 1:
        # Zero measured steps would give a zero total time and a division by
        # zero in the throughput computation below.
        raise ValueError("--measured-steps must be >= 1")

    # Initialise outside the try block: if this fails there is no process
    # group to destroy, and calling destroy_process_group() in ``finally``
    # would mask the original error.
    dist.init_process_group("nccl")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # torchrun exports LOCAL_RANK; the global rank is only a valid device
        # index on a single node, so it is used purely as a fallback.
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        torch.cuda.set_device(local_rank)

        if world_size < 2 or world_size % 2 != 0:
            raise ValueError(
                f"world size must be an even number >= 2, got {world_size}"
            )
        # Two groups of ``shard_size`` consecutive ranks each; with 4 ranks
        # this is the original (4,) / (2, 2) layout.
        shard_size = world_size // 2
        if args.fsdp_only:
            mesh = init_device_mesh(
                "cuda", (world_size,), mesh_dim_names=("shard",)
            )
            # Mirror the HSDP grouping so batch sizes match between modes.
            group_idx = rank // shard_size
        else:
            mesh = init_device_mesh(
                "cuda", (2, shard_size), mesh_dim_names=("replicate", "shard")
            )
            group_idx = mesh.get_local_rank("replicate")

        model = FFNStack(
            args.num_layers, args.hidden_dim, args.intermediate_dim
        ).cuda()
        for layer in model.layers:
            fully_shard(layer, mesh=mesh)
        fully_shard(model, mesh=mesh)

        warmup_steps = args.warmup_steps
        measured_steps = args.measured_steps
        total_steps = warmup_steps + measured_steps
        # One start/end CUDA event pair per measured step.
        start_events = [
            torch.cuda.Event(enable_timing=True) for _ in range(measured_steps)
        ]
        end_events = [
            torch.cuda.Event(enable_timing=True) for _ in range(measured_steps)
        ]
        total_examples = 0
        for step in range(total_steps):
            # Negative while warming up; otherwise the index into the events.
            measure_idx = step - warmup_steps
            if measure_idx >= 0:
                start_events[measure_idx].record()
            model.zero_grad()
            for i in range(args.accum_steps):
                if args.delay_all_reduce:
                    # Only the last microbatch requests the all-reduce.
                    model.set_requires_all_reduce(i == args.accum_steps - 1)
                bs = get_batch_size(
                    i,
                    group_idx,
                    sync=args.sync_batch_sizes,
                    ratio=args.batch_size_ratio,
                )
                if measure_idx >= 0:
                    total_examples += bs
                x = torch.randn(bs, args.seqlen, args.hidden_dim, device="cuda")
                model(x).sum().backward()
            if measure_idx >= 0:
                end_events[measure_idx].record()

        # Events must have completed before elapsed_time() can be read.
        torch.cuda.synchronize()
        step_times_ms = torch.tensor(
            [
                start.elapsed_time(end)
                for start, end in zip(start_events, end_events)
            ]
        )
        mean_ms = step_times_ms.mean().item()
        std_ms = step_times_ms.std().item()
        total_time_sec = step_times_ms.sum().item() / 1000.0

        # Token counts are summed over ranks; the time is this rank's own.
        total_tokens = torch.tensor(
            total_examples * args.seqlen, device="cuda", dtype=torch.long
        )
        dist.all_reduce(total_tokens)
        tok_per_sec_per_gpu = total_tokens.item() / total_time_sec / world_size

        if rank == 0:
            if args.fsdp_only:
                mode = "FSDP"
            elif args.delay_all_reduce:
                mode = "HSDP(delayed AR)"
            else:
                mode = "HSDP"
            print(
                f"{mode} | "
                f"accum_steps: {args.accum_steps}, batch_size_ratio: {args.batch_size_ratio} | "
                f"step: {mean_ms:.1f} ± {std_ms:.1f} ms | "
                f"tok/sec/gpu: {tok_per_sec_per_gpu:.0f} | "
            )
    finally:
        dist.destroy_process_group()
# Script entry point: run the benchmark only when this file is executed
# directly (e.g. through torchrun), not when it is imported.
if __name__ == "__main__":
    main()

Measured on an 8xH100 node (each run below uses 4 of the 8 GPUs via --nproc-per-node=4):

❯ torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 2 --batch-size-ratio 2 --delay-all-reduce
HSDP(delayed AR) | accum_steps: 2, batch_size_ratio: 2 | step: 2145.8 ± 2.5 ms | tok/sec/gpu: 5727 |

❯ torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 2 --batch-size-ratio 2
HSDP | accum_steps: 2, batch_size_ratio: 2 | step: 2698.9 ± 1.2 ms | tok/sec/gpu: 4553 |

❯ torchrun --nproc-per-node=4 scripts/bench_delay_all_reduce.py --accum-steps 2 --batch-size-ratio 2 --fsdp-only
FSDP | accum_steps: 2, batch_size_ratio: 2 | step: 2693.5 ± 1.6 ms | tok/sec/gpu: 4562 |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment