Last active
September 28, 2018 07:07
-
-
Save oraoto/798b3989d4984eed0daf595d4f9e5360 to your computer and use it in GitHub Desktop.
nnabla graph visualization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import nnabla as nn | |
import nnabla.functions as F | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import graphviz as gv | |
def draw_graph(v, hide_params=True, op_as_edge=False):
    """Build a graphviz ``Digraph`` of the computation graph reachable from ``v``.

    Args:
        v: Root of the graph. Either a single object exposing
            ``visit(callback)`` (used here with nnabla variables) or a
            list/tuple of such roots, which are first joined with ``F.sink``.
        hide_params: When true, variables that appear in
            ``nn.get_parameters(grad_only=False)`` are not drawn.
        op_as_edge: When true, functions are drawn as labelled edges from each
            of their inputs to each of their outputs instead of as nodes.
            Setting this forces ``hide_params`` on.

    Returns:
        The populated ``gv.Digraph``; the caller decides how to render it.
    """
    graph = gv.Digraph()
    # Snapshot the registered parameters once; membership is tested per variable.
    param_vars = list(nn.get_parameters(grad_only=False).values())
    layer_count = {}  # function name -> number of times it has been visited
    variables = {}    # node name -> variable already assigned that name

    # Accept any list/tuple of roots (not only an exact ``list``).
    if isinstance(v, (list, tuple)):
        v = F.sink(*v)
    if op_as_edge:
        hide_params = True

    def find_name(n):
        """Return the node name already assigned to ``n``, or None."""
        # Variables are matched by equality rather than by ``id``: the name
        # embeds ``id(n)`` of the first wrapper seen, and later visits may
        # hand us a different wrapper object for the same variable.
        for name, seen in variables.items():
            if seen is n or seen == n:
                return name
        return None

    def add_variable(n, prefix):
        """Draw variable ``n`` and return its node name, or False if hidden."""
        n_name = find_name(n)
        if n_name is None:
            n_name = prefix + str(id(n))
            variables[n_name] = n

        is_param = n in param_vars
        if hide_params and is_param:
            return False

        attrs = {
            'label': str(n.shape),
            'style': 'filled',
            'shape': 'box',
            'align': 'center',
        }
        attrs['fillcolor'] = '#f8baff' if n.need_grad else '#cbffba'
        if is_param:
            # Parameters get a translucent fill to set them apart.
            attrs['size'] = ''
            attrs['fillcolor'] = '#f8baff75'
        graph.node(n_name, **attrs)
        return n_name

    def visit(f):
        """Callback for ``v.visit``: draw function ``f`` and its variables."""
        # The counter is bumped before the Sink check, as in the original, so
        # numbering of other functions is unaffected by this rewrite.
        layer_count[f.name] = layer_count.get(f.name, 0) + 1
        if f.info.type_name == 'Sink':
            # The sink only exists to join multiple roots; do not draw it.
            return

        f_name = f.name + '_' + str(layer_count[f.name])
        if not op_as_edge:
            graph.node(f_name)

        inputs = []
        outputs = []
        for inp in f.inputs:
            n_id = add_variable(inp, f_name + '_Input')
            if not n_id:
                continue  # hidden parameter: no node, so no edge
            inputs.append(n_id)
            if not op_as_edge:
                graph.edge(n_id, f_name)
        for oup in f.outputs:
            n_id = add_variable(oup, f_name + '_Output')
            if not n_id:
                # Never draw an edge to a hidden variable (mirrors the
                # guard on the input side).
                continue
            outputs.append(n_id)
            if not op_as_edge:
                graph.edge(f_name, n_id)

        if op_as_edge:
            for i in inputs:
                for o in outputs:
                    graph.edge(i, o, label=f_name)

    v.visit(visit)
    return graph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% | |
import nnabla as nn | |
import nnabla.functions as F | |
import nnabla.parametric_functions as PF | |
from graph import draw_graph | |
def rnn(xs, h0, hidden=32):
    """Unroll a tanh recurrent cell over ``xs`` and classify the final state.

    Args:
        xs: Iterable of per-step input variables.
        h0: Initial hidden state.
        hidden: Output size passed to both affine layers of the cell.

    Returns:
        A ``(y, hs)`` pair: ``y`` is the 10-way affine output computed from the
        last hidden state, ``hs`` the list of hidden states, one per step.
    """
    states = []
    with nn.parameter_scope("rnn"):
        state = h0
        for step_input in xs:
            # Every step enters the same "x2h" / "h2h" scope names.
            with nn.parameter_scope("x2h"):
                from_input = PF.affine(step_input, hidden, with_bias=False)
            with nn.parameter_scope("h2h"):
                from_state = PF.affine(state, hidden)
            state = F.tanh(from_input + from_state)
            states.append(state)
        # NOTE(review): nesting of "classifier" inside "rnn" was reconstructed
        # from a listing with flattened indentation -- confirm.
        with nn.parameter_scope("classifier"):
            logits = PF.affine(state, 10)
    return logits, states
# Three input variables, one per time step.
seq_x = [nn.Variable([28, 28]) for _ in range(3)]
h0 = nn.Variable((28, 32))
y, hs = rnn(seq_x, h0, 32)

# First view: draw the graph with parameter variables included.
g = draw_graph(y, hide_params=False)
g.view()

#%%
# Second view: same graph with the default hide_params=True.
g = draw_graph(y)
g.view()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment