Last active
February 19, 2025 15:09
-
Star
(111)
You must be signed in to star a gist -
Fork
(15)
You must be signed in to fork a gist
-
-
Save apaszke/f93a377244be9bfcb96d3547b9bc424d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from graphviz import Digraph | |
import torch | |
from torch.autograd import Variable, Function | |
def iter_graph(root, callback):
    """Visit every node reachable from ``root`` exactly once.

    The traversal is depth-first, driven by an explicit stack, and follows
    each node's ``next_functions`` edges.

    Args:
        root: Node to start from. Every visited node must be hashable and
            expose ``next_functions``, an iterable of ``(node, _)`` pairs
            whose first element may be ``None``. When ``root`` itself is
            ``None`` there is nothing to walk and the function returns
            without invoking ``callback``.
        callback: Called with each node, once, after that node's successors
            have been pushed onto the stack.
    """
    # A missing root used to blow up on ``None.next_functions``.
    if root is None:
        return
    stack = [root]
    seen = set()
    while stack:
        fn = stack.pop()
        # A node reachable through several paths may be pushed more than
        # once; only the first pop is processed.
        if fn in seen:
            continue
        seen.add(fn)
        stack.extend(
            next_fn for next_fn, _ in fn.next_functions if next_fn is not None
        )
        callback(fn)
def register_hooks(var):
    """Attach gradient-checking hooks to the autograd graph behind ``var``.

    A hook is registered on every node reachable from ``var.grad_fn``. When
    a hook fires it records whether any of the node's input gradients looked
    bad (NaN, or magnitude above 1e6). Only that boolean is kept, not the
    gradients themselves, so the bookkeeping does not hold on to gradient
    tensors.

    Args:
        var: Object whose ``grad_fn`` is the root of the graph to inspect.

    Returns:
        A zero-argument function ``make_dot``. Call it after the hooks have
        fired (i.e. after running backward) to obtain a ``Digraph`` of the
        graph in which nodes with bad gradients are filled red.
    """
    # node -> True when at least one of its input gradients was bad.
    has_bad_grad = {}

    def is_bad_grad(grad):
        """Return True if ``grad`` contains NaN or entries with |value| > 1e6."""
        # Hooks may receive None for inputs that get no gradient.
        if grad is None:
            return False
        grad = grad.detach()
        # abs() so that large negative values and -inf are flagged as well.
        return bool(torch.isnan(grad).any() or grad.abs().gt(1e6).any())

    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            # Reduce to a verdict right away instead of storing the tensors.
            has_bad_grad[fn] = any(is_bad_grad(gi) for gi in grad_input)

        fn.register_hook(register_grad)

    iter_graph(var.grad_fn, hook_cb)

    def make_dot():
        """Build the graph drawing from the verdicts recorded by the hooks."""
        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            return '(' + ', '.join(map(str, size)) + ')'

        def build_graph(fn):
            if hasattr(fn, 'variable'):  # if GradAccumulator
                u = fn.variable
                node_name = 'Variable\n ' + size_to_str(u.size())
                dot.node(str(id(u)), node_name, fillcolor='lightblue')
            else:
                # Every other node must have had its hook fired already.
                assert fn in has_bad_grad, fn
                fillcolor = 'red' if has_bad_grad[fn] else 'white'
                dot.node(str(id(fn)), str(type(fn).__name__),
                         fillcolor=fillcolor)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    # Nodes carrying a variable are keyed by that variable's
                    # id, matching the dot.node call above.
                    next_id = id(getattr(next_fn, 'variable', next_fn))
                    dot.edge(str(next_id), str(id(fn)))

        iter_graph(var.grad_fn, build_graph)
        return dot

    return make_dot
if __name__ == '__main__':
    # Small demo: build a graph, hook it, run backward, dump the drawing.
    x = Variable(torch.randn(10, 10), requires_grad=True)
    y = Variable(torch.randn(10, 10), requires_grad=True)

    # The denominator is all zeros; presumably deliberate, so that the
    # backward pass yields non-finite gradients for the demo to flag.
    zero_denominator = y * 0
    z = x / zero_denominator
    z = z.sum() * 2

    # Hooks have to be in place before backward() runs.
    get_dot = register_hooks(z)
    z.backward()

    dot = get_dot()
    dot.save('tmp.dot')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I run out of RAM using this code, and if I try running it only once after n iterations, it crashes.
However, I got an interesting graph before the RAM ran out. Does anyone have any suggestions on how to approach the problem? The graph is very, very large, and there are red nodes everywhere, but this is the end of it:
My model is not very complicated (apart from the transformer itself):