#!/usr/bin/env python
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

INTERVAL = 1  # seconds between memory-stat dumps on rank 0
COMM_SIZE = (10,)  # shape of the tensors we communicate


def run(rank, size):
    torch.cuda.set_device(rank)
    pg = torch.distributed.new_group(list(range(size)), backend="nccl")
    if rank == 0:
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        dist.barrier()
        torch.cuda.synchronize()

        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # Do an allgather to warm up comms. Somehow the first
        # all-gather we do isn't actually async and waits for the comm
        # to complete.
        pg._allgather_base(outputs, mine).wait()

        with torch.cuda.stream(s1):
            # Allocate a tensor whose backing block comes from stream
            # `s1`.
            mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")

            # When we execute the _allgather_base,
            # ProcessGroupNCCL::collective calls `recordStream` to
            # record `mine` as having been used on the NCCL comms
            # stream.
            handle = pg._allgather_base(outputs, mine)

            # Now we free `mine`. This ends up in
            # DeviceCachingAllocator::free, which notices that the
            # block has non-empty stream uses, and queues an event on
            # the NCCL comms stream.
            #
            # Note that the bug wouldn't show up with point-to-point
            # comms, because they hold on to their input or output
            # tensors in WorkNCCL::outputs_, and so the tensor would
            # not actually be freed here.
            mine = None

        print("[0] Queued the all-gather.")
        t = time.time()

        # Now we do some concurrent work while the comms happen in the
        # background.
        while not handle.is_completed():
            # We allocate a tensor, and then we `record_stream` to
            # make the allocator record it as having stream_uses. This
            # is the simplest demo for a reproducer; in real code this
            # can happen in autograd, by other comms, or a handful of
            # other ways.
            data = torch.randn((1024,), device="cuda")
            data.record_stream(s2)

            # Now we free `data`. Since it has `stream_uses`, the
            # allocator enqueues an event and marks the underlying
            # buffer for later free.
            #
            # However, `process_events` will walk the event list in
            # order, and stop at the first event which isn't ready.
            # Since we queued an event on the NCCL comms stream up
            # above, it will always stop there, and no memory will be
            # released until the comms complete.
            data = None

            now = time.time()
            if (now - t) > INTERVAL:
                # Dump memory stats every second
                t = now
                print(torch.cuda.memory_summary(abbreviated=True))
        handle.wait()
    else:
        dist.barrier()
        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        pg._allgather_base(outputs, mine)

        # On rank 1, we just sleep 10s and then do an all-gather, to
        # achieve the effect of a long-running op on the NCCL stream
        # in rank 0.
        print("[1] Sleeping...")
        time.sleep(10)
        pg._allgather_base(outputs, mine)
        print("[1] Sent a tensor")


def init_process(rank, size, fn, backend="nccl"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
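# A minimal, CUDA-free sketch of the allocator behavior the comments above
# describe. This is an illustrative toy, not PyTorch's actual
# DeviceCachingAllocator: every name below (ToyCachingAllocator, free,
# process_events) is invented for the sketch. The one thing it mirrors from
# the comments is the FIFO event walk: a freed block whose event is not yet
# ready blocks every later-freed block from returning to the pool, even if
# their own events completed long ago.
from collections import deque


class ToyCachingAllocator:
    def __init__(self):
        self.pending = deque()  # (is_ready_fn, block), in the order freed
        self.pool = []          # blocks available for reuse

    def free(self, block, is_ready):
        # A block with outstanding stream uses is not released right away;
        # it sits behind an event until process_events sees that event fire.
        self.pending.append((is_ready, block))

    def process_events(self):
        # Walk the events in order and stop at the first one that isn't
        # ready -- the head-of-line blocking the reproducer relies on.
        while self.pending:
            is_ready, block = self.pending[0]
            if not is_ready():
                return
            self.pending.popleft()
            self.pool.append(block)


alloc = ToyCachingAllocator()
alloc.free("nccl input block", is_ready=lambda: False)  # slow comm still running
alloc.free("loop block 1", is_ready=lambda: True)       # already done
alloc.free("loop block 2", is_ready=lambda: True)       # already done
alloc.process_events()
print(alloc.pool)  # [] -- the finished blocks are stuck behind the slow event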
#!/usr/bin/env python
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

INTERVAL = 1  # seconds between memory-stat dumps on rank 0
COMM_SIZE = (10,)  # shape of the tensors we communicate


def run(rank, size):
    torch.cuda.set_device(rank)
    if rank == 0:
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        dist.barrier()
        torch.cuda.synchronize()

        buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")
        # Do a warmup comm; the first comm seems to block until
        # completion whether or not we do it async.
        dist.irecv(buf, src=1).wait()

        with torch.cuda.stream(s1):
            # Allocate a tensor whose backing block comes from stream
            # `s1`.
            buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")

            # Now use it in an NCCL comm.
            # ProcessGroupNCCL::pointToPoint will call `recordStream`
            # to record `buf` as having been used on the NCCL comms
            # stream.
            #
            # This comm will be fast since rank 1 sends promptly.
            dist.irecv(buf, src=1).wait()

            # Now we start a long-running comm. Rank 1 will sleep
            # before sending this tensor, so this results in a
            # long-running op on the NCCL CUDA stream.
            handle = dist.irecv(
                torch.empty(COMM_SIZE, dtype=torch.float, device="cuda"), src=1
            )

            # Now we free `buf`. This eventually ends up in
            # DeviceCachingAllocator::free; it notices that the block
            # has non-empty stream uses, and so queues an event on the
            # NCCL comms stream to make sure the tensor is actually
            # done being used before it is actually released to CUDA.
            buf = None

        print("[0] Queued the receive.")
        t = time.time()

        # Now we do some concurrent work while the comms happen in the
        # background.
        while not handle.is_completed():
            # We allocate a tensor, and then we `record_stream` to
            # make the allocator record it as having stream_uses. This
            # is the simplest demo for a reproducer; in real code this
            # can happen in autograd, by other comms, or a handful of
            # other ways.
            data = torch.randn((1024,), device="cuda")
            data.record_stream(s2)

            # Now we free `data`. Since it has `stream_uses`, the
            # allocator enqueues an event and marks the underlying
            # buffer for later free.
            #
            # However, `process_events` will walk the event list in
            # order, and stop at the first event which isn't ready.
            # Since we queued an event on the NCCL comms stream up
            # above, it will always stop there, and no memory will be
            # released until the comms complete.
            data = None

            now = time.time()
            if (now - t) > INTERVAL:
                # Dump memory stats every second
                t = now
                print(torch.cuda.memory_summary(abbreviated=True))
        handle.wait()
    else:
        dist.barrier()
        buf = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # One warm-up send
        dist.isend(buf, dst=0).wait()
        # One send for the first (fast) `irecv` on rank 0
        dist.isend(buf, dst=0).wait()
        # Now we sleep 10s and then do a final isend, to cause the
        # final `irecv` on rank 0 to be long-running.
        print("[1] Sleeping...")
        time.sleep(10)
        dist.isend(buf, dst=0).wait()
        print("[1] Sent a tensor")


def init_process(rank, size, fn, backend="nccl"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
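# A hedged observation helper, not part of either reproducer above: if the
# full torch.cuda.memory_summary() output is too noisy, the busy-wait loop on
# rank 0 could log just a couple of counters instead. `log_memory` is a
# hypothetical name invented for this sketch; the two torch.cuda calls it
# wraps are standard counter APIs. While the NCCL event sits unfinished at the
# head of the allocator's event queue, the reserved count keeps climbing even
# though every `data` tensor in the loop has already been dropped on the
# Python side.
import torch


def log_memory(tag: str) -> None:
    # Bytes the caching allocator has handed out to live blocks, and bytes it
    # has reserved from CUDA in total.
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(f"[{tag}] allocated={allocated} reserved={reserved}")


# For example, rank 0's loop could call log_memory("0") in place of the
# memory_summary() print.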