Last active
September 18, 2021 04:33
-
-
Save usstq/81ed82be3b53bd0e6767b77c8cad40b0 to your computer and use it in GitHub Desktop.
improved net_drawer.py for onnx model visualize
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SPDX-License-Identifier: Apache-2.0 | |
# A library and utility for drawing ONNX nets. Most of this implementation has | |
# been borrowed from the caffe2 implementation | |
# https://github.com/caffe2/caffe2/blob/master/caffe2/python/net_drawer.py | |
# | |
# The script takes two required arguments:
#   --input: a path to a serialized ModelProto .pb file.
#   --output: a path to write the rendered graph to; the output format is
#       taken from the file extension (for example .svg or .dot).
#
# Given a dot file representation, you can - for example - export it to svg
# with the graphviz `dot` utility, like so:
#
#   $ dot -Tsvg my_output.dot -o my_output.svg
# | |
# Improved by [email protected]: | |
# - directly exported to svg | |
# - remove value node | |
# - simplified node label | |
# - use tooltip feature | |
# - show shape infer result | |
# | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from __future__ import unicode_literals | |
import argparse | |
from collections import defaultdict | |
import json | |
from os import path | |
import numpy | |
import onnx.numpy_helper | |
from onnx import shape_inference | |
from onnx import ModelProto, GraphProto, NodeProto, TensorProto | |
import pydot # type: ignore | |
from typing import Text, Any, Callable, Optional, Dict | |
# Graphviz attributes applied to ordinary operator nodes.
OP_STYLE = dict(
    shape='box',
    color='lightblue',
    style='filled',
    margin="0.1,0.1",
    height="0.3",
)

# Graphviz attributes applied to "Constant" operator nodes.
CONST_STYLE = dict(
    style='dashed',
    shape='Mrecord',
    margin="0.03,0.03",
    height="0.1",
)

# Graphviz attributes applied to graph input and graph output nodes.
INPUT_STYLE = dict(
    shape='box',
    color='gray1',
    style='dotted',
)

# Graphviz attributes applied to initializer nodes and unresolved values.
INIT_STYLE = dict(shape='box')

# Graphviz attributes for blob nodes.
BLOB_STYLE = dict(shape='octagon')
# Signature of a node producer: called with an ONNX node and an integer id,
# it returns the pydot node used to draw that operator.
_NodeProducer = Callable[[NodeProto, int], pydot.Node]
def _escape_label(name): # type: (Text) -> Text | |
# json.dumps is poor man's escaping | |
return json.dumps(name) | |
def _form_and_sanitize_docstring(s):  # type: (Text) -> Text
    """Build a ``javascript:alert(...)`` URL that displays the doc string *s*.

    The text is escaped with ``_escape_label``, angle brackets are removed and
    the surrounding double quotes are turned into apostrophes.

    Args:
        s: the doc string to embed.

    Returns:
        The URL string ``javascript:alert('<escaped text>')``.
    """
    literal = _escape_label(s)
    # Escape apostrophes first: the quotes are converted to apostrophes below,
    # so an unescaped one inside the text would end the JavaScript string
    # literal early (breaking the alert and allowing script injection).
    literal = literal.replace("'", "\\'")
    literal = literal.replace('"', '\'').replace('<', '').replace('>', '')
    return 'javascript:alert(' + literal + ')'
# Maps a value name to the value-info proto describing it; filled in by main()
# from the shape-inference result and the graph inputs/outputs.
value_map = {}

# Maps a TensorProto.DataType enum number to its name (e.g. 1 -> "FLOAT").
# Built from the enum itself: scanning dir(TensorProto) for integers would also
# collect the *_FIELD_NUMBER constants and DataLocation values, whose numbers
# collide with (and overwrite) the real element-type names.
elem_type_map = {
    number: type_name for type_name, number in TensorProto.DataType.items()
}
def get_shape(value_name):
    """Describe the element type and shape of a value as label text.

    The value is looked up in the module-level ``value_map``.

    Args:
        value_name: name of the value to describe.

    Returns:
        "?" when the value is not in ``value_map``. Otherwise an optional
        element-type name followed by a literal backslash-n, then the
        dimensions in parentheses, e.g. ``FLOAT\\n(1,3,224,224)``. Each
        dimension is its integer value, its symbolic name, or "?" when it has
        neither; the whole shape is "(?)" when no shape is recorded.
    """
    if value_name not in value_map:
        return "?"

    tensor_type = value_map[value_name].type.tensor_type

    prefix = ""
    if tensor_type.HasField("elem_type") and tensor_type.elem_type in elem_type_map:
        # "\\n" is the two characters backslash + n (not a real newline).
        prefix = elem_type_map[tensor_type.elem_type] + "\\n"

    if not tensor_type.HasField("shape"):
        return prefix + "(?)"

    dims = []
    for d in tensor_type.shape.dim:
        # A dimension may have a definite (integer) value, a symbolic
        # identifier, or neither.
        if d.HasField("dim_value"):
            dims.append(str(d.dim_value))
        elif d.HasField("dim_param"):
            dims.append(str(d.dim_param))  # unknown dimension with symbolic name
        else:
            dims.append("?")  # unknown dimension with no name

    # join() places the separators independently of the dimension text; the
    # previous "last character is not '('" test lost a comma after an empty
    # symbolic name or one ending in "(".
    return prefix + "(" + ",".join(dims) + ")"
# | |
# https://github.com/sassoftware/python-dlpy/blob/master/dlpy/model_conversion/onnx_graph.py | |
# | |
def _convert_onnx_attribute_proto(attr_proto): | |
''' | |
Convert ONNX AttributeProto into Python object | |
''' | |
if attr_proto.HasField('f'): | |
return attr_proto.f | |
elif attr_proto.HasField('i'): | |
return attr_proto.i | |
elif attr_proto.HasField('s'): | |
return str(attr_proto.s, 'utf-8') | |
elif attr_proto.HasField('t'): | |
return attr_proto.t # this is a proto! | |
elif attr_proto.floats: | |
return list(attr_proto.floats) | |
elif attr_proto.ints: | |
return list(attr_proto.ints) | |
elif attr_proto.strings: | |
str_list = list(attr_proto.strings) | |
str_list = list(map(lambda x: str(x, 'utf-8'), str_list)) | |
return str_list | |
else: | |
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto)) | |
def GetOpNodeProducer(embed_docstring=False, **kwargs):  # type: (bool, **Any) -> _NodeProducer
    """Create a function that turns an ONNX node into a pydot node.

    Args:
        embed_docstring: when true, the node's doc string is attached to the
            pydot node as a ``javascript:alert`` URL.
        **kwargs: pydot node attributes used for every operator except
            "Constant", which is drawn with ``CONST_STYLE``.

    Returns:
        A callable ``(op, op_id) -> pydot.Node``.
    """
    def ReallyGetOpNode(op, op_id):  # type: (NodeProto, int) -> pydot.Node
        node_name = '%s_%d' % (op.op_type, op_id)

        # Tooltip: node name followed by one line per input, output, attribute.
        tooltip_parts = [node_name]
        for i, input_name in enumerate(op.input):
            tooltip_parts.append('\n input' + str(i) + ' ' + input_name)
        for i, output_name in enumerate(op.output):
            tooltip_parts.append('\n output' + str(i) + ' ' + output_name)
        for attr in op.attribute:
            tooltip_parts.append(f"\n {attr.name}:{_convert_onnx_attribute_proto(attr)}")
        tooltip = ''.join(tooltip_parts)

        # Label: the node's own name when it has one, with the operator type
        # appended unless the name already contains it.
        label = op.op_type
        if op.name:
            label = op.name
            if op.op_type not in op.name:
                label += "/" + op.op_type

        op_style = kwargs
        if op.op_type == "Constant":
            op_style = CONST_STYLE
            for attr in op.attribute:
                if attr.name == 'value':
                    value = onnx.numpy_helper.to_array(attr.t)
                elif attr.name.startswith('value_'):
                    # Constant may also carry its payload in value_float,
                    # value_int, value_ints, ... instead of a tensor.
                    value = _convert_onnx_attribute_proto(attr)
                else:
                    continue
                # For constants the value itself is shown: in full as the
                # tooltip and, truncated to 10 characters, as the label.
                tooltip = str(value).replace("9223372036854775807", "max")
                if tooltip:
                    if len(tooltip) < 10:
                        label = tooltip
                    else:
                        label = tooltip[:10] + "..."

        node = pydot.Node(node_name, label=label, tooltip=tooltip, **op_style)
        if embed_docstring:
            url = _form_and_sanitize_docstring(op.doc_string)
            node.set_URL(url)
        return node
    return ReallyGetOpNode
def GetPydotGraph(
    graph,  # type: GraphProto
    name=None,  # type: Optional[Text]
    rankdir='LR',  # type: Text
    node_producer=None,  # type: Optional[_NodeProducer]
    embed_docstring=False,  # type: bool
    title="",
):  # type: (...) -> pydot.Dot
    """Build a pydot graph for an ONNX graph.

    Operators, graph inputs, graph outputs, initializers and unresolved input
    names become nodes. Each operator input becomes an edge labelled with the
    value's shape and its "from->to" slot indices.

    Args:
        graph: the ONNX graph to draw.
        name: name given to the pydot graph.
        rankdir: Graphviz rank direction.
        node_producer: callable creating the pydot node for an operator;
            defaults to ``GetOpNodeProducer`` with ``OP_STYLE``.
        embed_docstring: forwarded to the default node producer.
        title: text set as the graph label.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)
    pydot_graph = pydot.Dot(name, rankdir=rankdir)
    pydot_graph.set("labelloc", "t")
    pydot_graph.set("labelfontsize", 30)
    pydot_graph.set("label", title)

    # Per value name: the node edges start from, plus slot/shape bookkeeping.
    pydot_nodes = {}  # type: Dict[Text, Dict[Text, Any]]
    op2node = {}  # type: Dict[int, pydot.Node]
    initializers = {t.name: t for t in graph.initializer}

    # Rank every operator by the sum of the input-slot indices its outputs
    # feed; operators are added to the graph in ascending rank order.
    oploc_rank = [0] * len(graph.node)
    value2op = {}
    for op_id, op in enumerate(graph.node):
        for value_name in op.output:
            value2op[value_name] = op_id
    for op in graph.node:
        for index, value_name in enumerate(op.input):
            if value_name in value2op:
                oploc_rank[value2op[value_name]] += index

    # add op node by rank
    for op_id in numpy.argsort(oploc_rank):
        # argsort yields numpy integers; use a plain int for proto indexing
        # and as the op2node key.
        op_id = int(op_id)
        op = graph.node[op_id]
        op_node = node_producer(op, op_id)
        pydot_graph.add_node(op_node)
        op2node[op_id] = op_node
        for index, value_name in enumerate(op.output):
            if value_name not in pydot_nodes:
                pydot_nodes[value_name] = {
                    "name": value_name,
                    "to": index,
                    "from": index,
                    "op_node": op_node,
                    "shape": get_shape(value_name),
                    "consumer_cnt": 0,
                }

    for v in graph.input:
        in_node = pydot.Node(v.name, label=f'"input:{v.name}"', tooltip="input value", **INPUT_STYLE)
        pydot_graph.add_node(in_node)
        pydot_nodes[v.name] = {
            "name": v.name,
            "to": 0,
            "from": 0,
            "op_node": in_node,
            # Look up this input's own shape (not a leftover loop variable).
            "shape": get_shape(v.name),
            "consumer_cnt": 0,
        }

    for v in graph.output:
        out_node = pydot.Node(v.name, label=f'"output:{v.name}"', tooltip="output value", **INPUT_STYLE)
        pydot_graph.add_node(out_node)
        # This replaces any entry recorded for the producing operator above.
        pydot_nodes[v.name] = {
            "name": v.name,
            "to": 0,
            "from": 0,
            "op_node": out_node,
            "shape": get_shape(v.name),
            "consumer_cnt": 0,
        }

    output_names = {v.name for v in graph.output}

    for op_id, op in enumerate(graph.node):
        for index, value_name in enumerate(op.input):
            if value_name not in pydot_nodes:
                # Not produced by an operator and not a graph input/output:
                # either an initializer or an unresolved name.
                if value_name in initializers:
                    tensor = onnx.numpy_helper.to_array(initializers[value_name])
                    node_label = str(tensor)
                    tooltip = value_name
                    if len(node_label) > 10:
                        # Long values go to the tooltip; the label shows the name.
                        tooltip = node_label
                        node_label = value_name
                    if len(tooltip) > 128:
                        tooltip = tooltip[:128] + "..."
                    src_node = pydot.Node(value_name, label=f'"{node_label}"', tooltip=tooltip, **INIT_STYLE)
                    shape_text = f'{tensor.shape}'
                else:
                    # An empty input name still needs a non-empty node id.
                    node_name = value_name if value_name != "" else "?"
                    src_node = pydot.Node(node_name, label=f'"?:{value_name}"', tooltip="unknown", **INIT_STYLE)
                    shape_text = get_shape(value_name)
                pydot_graph.add_node(src_node)
                pydot_nodes[value_name] = {
                    "name": value_name,
                    "to": 0,
                    "from": 0,
                    "op_node": src_node,
                    "shape": shape_text,
                    "consumer_cnt": 0,
                }

            pydot_n = pydot_nodes[value_name]
            pydot_n["to"] = index
            pydot_n["consumer_cnt"] += 1
            if pydot_n["op_node"]:
                pydot_graph.add_edge(pydot.Edge(
                    pydot_n["op_node"].get_name(),
                    op2node[op_id].get_name() + ":" + str(pydot_n["to"]),
                    label=pydot_n["shape"] + "\n" + str(pydot_n["from"]) + "->" + str(pydot_n["to"]),
                    tooltip="\"" + value_name + "\"",
                ))

        # to output
        for value_name in op.output:
            if value_name in output_names:
                pydot_n = pydot_nodes[value_name]
                pydot_graph.add_edge(pydot.Edge(
                    op2node[op_id],
                    pydot_n["op_node"],
                    label=get_shape(value_name) + "\n" + str(pydot_n["from"]) + "->" + str(pydot_n["to"]),
                ))
    return pydot_graph
def main():  # type: () -> None
    """Command-line entry point: load a model, draw its graph, write the result.

    Reads the serialized ModelProto given by ``--input``, runs shape inference
    to fill the module-level ``value_map``, builds the pydot graph and writes
    it to ``--output`` in the format named by that file's extension.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=Text, required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=Text, required=True,
        help="The output graph file; its format is taken from the file extension (e.g. .svg).",
    )
    parser.add_argument(
        "--rankdir", type=Text, default='TD',
        help="The rank direction of the pydot graph.",
    )
    parser.add_argument(
        "--embed_docstring", action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args()

    model = ModelProto()
    with open(args.input, 'rb') as fid:
        model.ParseFromString(fid.read())

    # Record type/shape information for intermediate values (from shape
    # inference) and for the graph's own inputs and outputs.
    inferred_model = shape_inference.infer_shapes(model)
    for v in inferred_model.graph.value_info:
        value_map[v.name] = v
    for v in model.graph.input:
        value_map[v.name] = v
    for v in model.graph.output:
        value_map[v.name] = v

    opset_import = [str(x).strip('\n') for x in model.opset_import]
    title = f"ir_version:{model.ir_version} model_version:{model.model_version} opset_import:{opset_import} producer:{model.producer_name} {model.producer_version}"
    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        node_producer=GetOpNodeProducer(
            embed_docstring=args.embed_docstring,
            **OP_STYLE
        ),
        title=title,
    )

    # Take the format from the real file extension: splitting the whole path
    # on "." misfires for extension-less names and for dotted directories.
    # Without an extension, fall back to writing the raw dot source.
    output_format = path.splitext(args.output)[1].lstrip(".").lower() or "raw"
    pydot_graph.write(args.output, format=output_format)
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment