Last active
September 23, 2019 21:13
-
-
Save yaroslavvb/e53c83c40c8385cd90cdc15c7c61fa63 to your computer and use it in GitHub Desktop.
Example of Python multi-threading giving a mix of .grad from different backward calls
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import threading | |
import torch | |
import torch.nn as nn | |
def simple_model(d, n):
    """Create a stack of `n` bias-free d-by-d linear layers, each initialized to identity.

    Args:
        d: input/output dimensionality of every layer.
        n: number of stacked linear layers.

    Returns:
        torch.nn.Sequential that acts as the identity map on (batch, d) inputs.
    """
    layers = []
    for _ in range(n):
        layer = nn.Linear(d, d, bias=False)
        # Initialize to identity without recording the copy in autograd;
        # torch.no_grad() is the supported idiom (mutating .data is deprecated).
        with torch.no_grad():
            layer.weight.copy_(torch.eye(d))
        layers.append(layer)
    return torch.nn.Sequential(*layers)
def propagate(output, gradient, sleep_before1, sleep_before2, label, net=None):
    """Build (but do not start) a thread that backprops `gradient` through `output`.

    The thread sleeps, calls `output.backward(gradient, retain_graph=True)`,
    reads layer 0's weight gradient, sleeps again, then reads layer 1's weight
    gradient and prints both. The two sleeps let another thread's backward()
    land between the two reads — demonstrating that `.grad` accumulation is
    shared mutable state, not isolated per thread.

    Args:
        output: tensor to backprop through (graph must outlive repeated calls).
        gradient: upstream gradient passed to backward().
        sleep_before1: seconds to sleep before calling backward().
        sleep_before2: seconds to sleep between the two .grad reads.
        label: prefix for the printed report.
        net: model whose layer-0/layer-1 weight grads are read; defaults to the
            module-level `model` for backward compatibility with the original.

    Returns:
        An unstarted threading.Thread.
    """
    def f():
        # Fall back to the module-level demo model when none is injected.
        m = net if net is not None else model
        time.sleep(sleep_before1)
        # retain_graph=True so several threads can reuse the same graph.
        output.backward(gradient, retain_graph=True)
        # Snapshot layer 0's grad now; layer 1's grad is read only after the
        # second sleep, so a concurrent backward() may have accumulated into it.
        grad1 = m[0].weight.grad.detach().clone()
        time.sleep(sleep_before2)
        grad2 = m[1].weight.grad
        print(f"{label} observed gradients ", grad1[0, 0].item(), grad2[0, 0].item())
    return threading.Thread(target=f, args=())
# Demo: two backward passes race on one shared model — thread2 propagates an
# all-ones upstream gradient, thread1 an all-zeros one.
model = simple_model(2, 2)
x = torch.ones(1, 2)
y = model(x)

# thread2 should observe gradients (1, 1).
propagate2 = propagate(y, x, sleep_before1=0.5, sleep_before2=0, label="thread2")
# thread1 should observe (0, 0), but it reads layer 1's .grad only after
# thread2's backward has accumulated into it, so it sees a mix (0, 1).
propagate1 = propagate(y, x - x, sleep_before1=0, sleep_before2=1, label="thread1")

propagate1.start()
propagate2.start()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment