@alexarmbr
Created May 9, 2025 23:04
A toy example of one way to use torch.distributed.tensor.experimental.context_parallel.

import os
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.distributed._functional_collectives import all_gather_tensor_autograd
torch.manual_seed(0)
# this is required for non-causal attention to work
_cp_options.enable_load_balance = False
TRAINING_STEPS = 100
LEARNING_RATE = 0.01
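# B = batch size, H = number of attention heads, N = sequence length, D = head dimension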
B, H, N, D = 2, 12, 100_000, 64
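# round the sequence length down to a multiple of 4 so it shards evenly across 4 ranks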
N = (N // 4) * 4

class SimpleAttentionLayer(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.q_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.k_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.v_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.o_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)

        # initialize all weights as random uniform
        torch.nn.init.uniform_(self.q_proj.weight)
        torch.nn.init.uniform_(self.k_proj.weight)
        torch.nn.init.uniform_(self.v_proj.weight)
        torch.nn.init.uniform_(self.o_proj.weight)
        torch.nn.init.uniform_(self.q_proj.bias)
        torch.nn.init.uniform_(self.k_proj.bias)
        torch.nn.init.uniform_(self.v_proj.bias)
        torch.nn.init.uniform_(self.o_proj.bias)

    def forward(self, dist_x):
        dist_q = self.q_proj(dist_x)
        dist_k = self.k_proj(dist_x)
        dist_v = self.v_proj(dist_x)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            dist_o = F.scaled_dot_product_attention(dist_q, dist_k, dist_v)
        dist_o = self.o_proj(dist_o)
        return dist_o

def ctx_parallel_worker():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    device_mesh = dist.init_device_mesh("cuda", (world_size,))
    assert N % world_size == 0

    model = SimpleAttentionLayer(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # create the input and target on rank 0 and broadcast them, so every rank starts from identical data
    if rank == 0:
        input = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
        target = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
    else:
        input = torch.zeros(B, H, N, D, dtype=torch.bfloat16, device=device)
        target = torch.zeros(B, H, N, D, dtype=torch.bfloat16, device=device)
    dist.broadcast(input, src=0)
    dist.broadcast(target, src=0)

    loss_curve = []
    # each rank keeps only its shard of the sequence dimension
    local_input = torch.chunk(input, world_size, dim=2)[rank]

    # wrap the range with tqdm and assign it to a variable
    progress_bar = tqdm(range(TRAINING_STEPS), bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_inv_fmt}{postfix}]')

    # train a single attention layer to map input to target
    # the attention layer is replicated across ranks,
    # the sequence dimension of the input is sharded across ranks,
    # and context_parallel is used to call a parallel implementation of the attention layer
    for i in progress_bar:
        # this context manager installs hooks that intercept calls to F.scaled_dot_product_attention
        # and replace them with a ring attention implementation that operates on DTensor objects
        # (see the sketch after this script for a variant where context_parallel shards the buffers itself)
        with context_parallel(device_mesh):
            optimizer.zero_grad()
            local_output = model(local_input)

            # gather the output, so each rank has the full output
            # this all-gather is compatible with autograd; regular torch.distributed collectives are not
            output = all_gather_tensor_autograd(local_output, gather_dim=2, group=device_mesh)

            # compute the loss; since target and output match across all ranks, the loss will match across ranks
            loss = torch.nn.functional.mse_loss(output, target)

            # compute the gradient of each weight w.r.t. the loss on each rank
            # the loss is the same across all ranks, but the gradients will differ,
            # since each rank ran forward with a different slice of the input
            loss.backward()

            # all-reduce the gradients across all ranks; this would probably be inefficient for a larger model
            for param in model.parameters():
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)

            # now all gradients match, update the weights
            optimizer.step()

        progress_bar.set_postfix(loss=loss.item())
        loss_curve.append((i, loss.item()))

    if dist.is_initialized():
        dist.destroy_process_group()

    if rank == 0:
        with open("loss_curve_ctx_parallel.csv", "w") as f:
            f.write("step,loss\n")
            for step, loss in loss_curve:
                f.write(f"{step},{loss}\n")

def single_gpu_worker():
    device = torch.device("cuda")
    model = SimpleAttentionLayer(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    input = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
    target = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)

    loss_curve = []
    progress_bar = tqdm(range(TRAINING_STEPS), bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_inv_fmt}{postfix}]')

    # train a single attention layer to map input to target
    for i in progress_bar:
        optimizer.zero_grad()
        output = model(input)
        loss = torch.nn.functional.mse_loss(output, target)
        loss_curve.append((i, loss.item()))
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=loss.item())

    with open("loss_curve_no_ctx_parallel.csv", "w") as f:
        f.write("step,loss\n")
        for i, loss in loss_curve:
            f.write(f"{i},{loss}\n")

if __name__ == "__main__":
    # loss curves should match exactly between the context parallel and single gpu training
    if "RANK" in os.environ:
        # torchrun --nproc-per-node 4 context_parallel_example.py
        # on 4 x H200 this runs at 0.23s / iter
        ctx_parallel_worker()
    else:
        # python context_parallel_example.py
        # on 1 x H200 this runs at 0.77s / iter
        single_gpu_worker()
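
For reference, context_parallel can also do the sequence sharding for you: instead of chunking the input manually with torch.chunk, the full tensors can be passed via the buffers / buffer_seq_dims keyword arguments, and the context manager shards them in place along the given sequence dimensions and restores them on exit. Below is a minimal sketch, assuming those keyword arguments as documented for recent PyTorch releases (the filename in the launch comment is just a placeholder); it is not part of the training script above.

# minimal sketch of the buffers= variant; launch with e.g.
#   torchrun --nproc-per-node 4 cp_buffers_sketch.py
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options
from torch.nn.attention import SDPBackend, sdpa_kernel

_cp_options.enable_load_balance = False  # required for non-causal attention, as above

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
device_mesh = dist.init_device_mesh("cuda", (world_size,))

B, H, N, D = 2, 12, 8192, 64
q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")

# buffers= / buffer_seq_dims= ask the context manager to shard q, k, v in place
# along dim 2 (the sequence dim) for the duration of the context, and to restore
# the full tensors when the context exits
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    with context_parallel(device_mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        # inside the context q, k, v hold only this rank's shard of the sequence,
        # and so does the output
        local_out = F.scaled_dot_product_attention(q, k, v)

dist.destroy_process_group()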
@alexarmbr

To plot the loss curves:

import pandas as pd
import matplotlib.pyplot as plt

df_ctx = pd.read_csv('loss_curve_ctx_parallel.csv')
df_no_ctx = pd.read_csv('loss_curve_no_ctx_parallel.csv')

plt.figure(figsize=(10, 6))
plt.plot(df_ctx['step'], df_ctx['loss'], label='Loss (Context Parallel)')
plt.plot(df_no_ctx['step'], df_no_ctx['loss'], label='Loss (No Context Parallel)')
plt.title('Comparison of Loss Curves')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig('loss_comparison.png')

@alexarmbr

loss_comparison.png (plot comparing the two loss curves, produced by the snippet above)