a toy example of one way to use torch.distributed.tensor.experimental.context_parallel
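The idea: every rank holds a full copy of the attention layer, each rank keeps only its slice of the sequence dimension, and any F.scaled_dot_product_attention call made under the context_parallel context manager is dispatched to a ring attention implementation that cooperates across the device mesh. Below is a condensed sketch of just that call pattern, using the same API calls as the full training script that follows (context_parallel is experimental, so treat this as a sketch against recent torch versions rather than a stable recipe):

# condensed sketch of the call pattern used by the full script below, smaller sequence length
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options
from torch.nn.attention import SDPBackend, sdpa_kernel

_cp_options.enable_load_balance = False  # required for non-causal attention

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank}")
mesh = dist.init_device_mesh("cuda", (world_size,))

torch.manual_seed(0)  # same seed on every rank, so every rank builds the same full q, k, v
                      # (the full script below broadcasts from rank 0 instead, which is more robust)
B, H, N, D = 2, 12, 4096, 64  # batch, heads, sequence length, head dim
q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)

# each rank keeps only its shard of the sequence dimension
q, k, v = (torch.chunk(t, world_size, dim=2)[rank] for t in (q, k, v))

# under context_parallel, the SDPA call is routed to a ring attention implementation
# that exchanges K/V shards between ranks, so each rank's output attends over the full sequence
with context_parallel(mesh):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out_local = F.scaled_dot_product_attention(q, k, v)  # (B, H, N // world_size, D)

dist.destroy_process_group()

The full script trains a single attention layer to regress a random target, once with the sequence sharded across ranks via context_parallel and once on a single GPU, and writes both loss curves to CSV so they can be compared.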
import os
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.distributed._functional_collectives import all_gather_tensor_autograd

torch.manual_seed(0)

# this is required for non-causal attention to work
_cp_options.enable_load_balance = False

TRAINING_STEPS = 100
LEARNING_RATE = 0.01

# batch size, number of heads, sequence length, head dimension
B, H, N, D = 2, 12, 100_000, 64
# round the sequence length down to a multiple of 4 so it can be sharded evenly across 4 ranks
N = (N // 4) * 4
class SimpleAttentionLayer(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.q_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.k_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.v_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)
        self.o_proj = torch.nn.Linear(D, D, dtype=torch.bfloat16, device=device)

        # initialize all weights and biases as random uniform
        torch.nn.init.uniform_(self.q_proj.weight)
        torch.nn.init.uniform_(self.k_proj.weight)
        torch.nn.init.uniform_(self.v_proj.weight)
        torch.nn.init.uniform_(self.o_proj.weight)
        torch.nn.init.uniform_(self.q_proj.bias)
        torch.nn.init.uniform_(self.k_proj.bias)
        torch.nn.init.uniform_(self.v_proj.bias)
        torch.nn.init.uniform_(self.o_proj.bias)

    def forward(self, dist_x):
        dist_q = self.q_proj(dist_x)
        dist_k = self.k_proj(dist_x)
        dist_v = self.v_proj(dist_x)
        # when forward runs under the context_parallel context manager, this SDPA call
        # is intercepted and replaced with a ring attention implementation
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            dist_o = F.scaled_dot_product_attention(dist_q, dist_k, dist_v)
        dist_o = self.o_proj(dist_o)
        return dist_o
def ctx_parallel_worker():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    device_mesh = dist.init_device_mesh("cuda", (world_size,))
    assert N % world_size == 0

    model = SimpleAttentionLayer(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # rank 0 generates the full input and target, then broadcasts them so every rank
    # holds identical copies
    if rank == 0:
        input = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
        target = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
    else:
        input = torch.zeros(B, H, N, D, dtype=torch.bfloat16, device=device)
        target = torch.zeros(B, H, N, D, dtype=torch.bfloat16, device=device)
    dist.broadcast(input, src=0)
    dist.broadcast(target, src=0)

    loss_curve = []
    # each rank keeps only its shard of the sequence dimension
    local_input = torch.chunk(input, world_size, dim=2)[rank]

    # wrap the training-step range with tqdm so the loss can be shown in the progress bar postfix
    progress_bar = tqdm(range(TRAINING_STEPS), bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_inv_fmt}{postfix}]')
    # train a single attention layer to map input to target:
    # the attention layer is replicated across ranks,
    # the sequence dimension of the input is sharded across ranks,
    # and context_parallel is used to call a parallel implementation of the attention layer
    for i in progress_bar:
        # this context manager installs some hooks into the DTensor class to intercept calls to
        # F.scaled_dot_product_attention and replace them with calls to a ring attention
        # implementation that operates on DTensor objects
        with context_parallel(device_mesh):
            optimizer.zero_grad()
            local_output = model(local_input)

            # gather the output along the sequence dimension, so each rank has the full output;
            # this all-gather is compatible with autograd, regular torch.distributed collectives are not
            output = all_gather_tensor_autograd(local_output, gather_dim=2, group=device_mesh)

            # compute the loss; since target and output match across all ranks, the loss will match across ranks
            loss = torch.nn.functional.mse_loss(output, target)

            # compute the gradient of each weight w.r.t. the loss on each rank;
            # the loss is the same across all ranks, but the gradients will be different
            # since each rank ran forward with a different slice of the input
            loss.backward()

            # all-reduce the gradients across all ranks; this would probably be inefficient for a larger model
            for param in model.parameters():
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)

            # now all gradients should match, update the weights
            optimizer.step()

        progress_bar.set_postfix(loss=loss.item())
        loss_curve.append((i, loss.item()))

    if dist.is_initialized():
        dist.destroy_process_group()
    if rank == 0:
        with open("loss_curve_ctx_parallel.csv", "w") as f:
            f.write("step,loss\n")
            # write one "step,loss" row per training step, in the same format as the single GPU script
            for step, loss in loss_curve:
                f.write(f"{step},{loss}\n")
def single_gpu_worker():
    device = torch.device("cuda")
    model = SimpleAttentionLayer(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    input = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)
    target = torch.randn(B, H, N, D, dtype=torch.bfloat16, device=device)

    loss_curve = []
    progress_bar = tqdm(range(TRAINING_STEPS), bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_inv_fmt}{postfix}]')

    # train a single attention layer to map input to target
    for i in progress_bar:
        optimizer.zero_grad()
        output = model(input)
        loss = torch.nn.functional.mse_loss(output, target)
        loss_curve.append((i, loss.item()))
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=loss.item())

    with open("loss_curve_no_ctx_parallel.csv", "w") as f:
        f.write("step,loss\n")
        for i, loss in loss_curve:
            f.write(f"{i},{loss}\n")
if __name__ == "__main__":
    # loss curves should match exactly between the context parallel and single gpu training
    if "RANK" in os.environ:
        # torchrun --nproc-per-node 4 context_parallel_example.py
        # on 4 x H200 this runs at 0.23s / iter
        ctx_parallel_worker()
    else:
        # python context_parallel_example.py
        # on 1 x H200 this runs at 0.77s / iter
        single_gpu_worker()
To plot the loss curves, read the two CSV files the script writes and overlay them. A minimal sketch of such a plotting helper (not part of the original gist; the matplotlib usage is my own):
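# plot_loss_curves.py: hypothetical helper, not part of the original gist
import csv
import matplotlib.pyplot as plt

def read_curve(path):
    # returns (steps, losses) from a csv with a "step,loss" header
    with open(path) as f:
        rows = [(float(r["step"]), float(r["loss"])) for r in csv.DictReader(f)]
    steps, losses = zip(*rows)
    return steps, losses

steps_cp, losses_cp = read_curve("loss_curve_ctx_parallel.csv")
steps_single, losses_single = read_curve("loss_curve_no_ctx_parallel.csv")

plt.plot(steps_cp, losses_cp, label="context parallel (4 GPUs)")
plt.plot(steps_single, losses_single, label="single GPU", linestyle="--")
plt.xlabel("training step")
plt.ylabel("MSE loss")
plt.legend()
plt.savefig("loss_curves.png")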