Last active
December 28, 2020 05:14
-
-
Save vishwanath79/0db14419c2841c843d09bf5420a1bdd0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"from tensorflow import keras\n", | |
"import onnxruntime" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"outputs": [], | |
"source": [ | |
"model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])\n", | |
"# A single Dense layer with one unit that takes a one-element input" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model: \"sequential_2\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"dense_2 (Dense) (None, 1) 2 \n", | |
"=================================================================\n", | |
"Total params: 2\n", | |
"Trainable params: 2\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.compile(optimizer='sgd',loss='mean_squared_error')\n", | |
"model.summary()\n" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"outputs": [], | |
"source": [ | |
"# Training data sampled from y = 3x + 1, the function the network should learn\n", | |
"xs = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=float)\n", | |
"ys = np.array([4.0, 7.0, 10.0, 13.0, 16.0, 19.0], dtype=float)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "<tensorflow.python.keras.callbacks.History at 0x7f96ba8d43d0>" | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train the model for 500 epochs with progress output suppressed\n", | |
"model.fit(xs,ys,epochs=500,verbose=False)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "array([[31.02069]], dtype=float32)" | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"original = model.predict([10.0])\n", | |
"\n", | |
"original" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: tf/saved/assets\n" | |
] | |
} | |
], | |
"source": [ | |
"# save model architecture, weights, and training configuration in a single file/folder\n", | |
"model.save('tf/saved')" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2020-12-27 12:44:43,732 - WARNING - '--tag' not specified for saved_model. Using --tag serve\r\n", | |
"2020-12-27 12:44:43,800 - INFO - Signatures found in model: [serving_default].\r\n", | |
"2020-12-27 12:44:43,800 - WARNING - '--signature_def' not specified, using first signature: serving_default\r\n", | |
"WARNING:tensorflow:From /Users/vishwanath/opt/miniconda3/envs/onnx_virt/lib/python3.8/site-packages/tf2onnx/tf_loader.py:416: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\r\n", | |
"Instructions for updating:\r\n", | |
"Use `tf.compat.v1.graph_util.extract_sub_graph`\r\n", | |
"2020-12-27 12:44:43,816 - WARNING - From /Users/vishwanath/opt/miniconda3/envs/onnx_virt/lib/python3.8/site-packages/tf2onnx/tf_loader.py:416: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\r\n", | |
"Instructions for updating:\r\n", | |
"Use `tf.compat.v1.graph_util.extract_sub_graph`\r\n", | |
"2020-12-27 12:44:43,823 - INFO - Using tensorflow=2.4.0, onnx=1.8.0, tf2onnx=1.7.2/995bd6\r\n", | |
"2020-12-27 12:44:43,823 - INFO - Using opset <onnx, 10>\r\n", | |
"2020-12-27 12:44:43,823 - INFO - Computed 0 values for constant folding\r\n", | |
"2020-12-27 12:44:43,829 - INFO - Optimizing ONNX model\r\n", | |
"2020-12-27 12:44:43,836 - INFO - After optimization: Identity -5 (5->0)\r\n", | |
"2020-12-27 12:44:43,836 - INFO - \r\n", | |
"2020-12-27 12:44:43,837 - INFO - Successfully converted TensorFlow model tf/saved to ONNX\r\n", | |
"2020-12-27 12:44:43,837 - INFO - ONNX model is saved at tf/output/simplemodel.onnx\r\n" | |
] | |
} | |
], | |
"source": [ | |
"# Convert the SavedModel to ONNX\n", | |
"\n", | |
"# By default tf2onnx uses opset 8 to generate the graph. Passing --opset overrides that default and generates a graph with the desired opset. For example, --opset 5 creates an ONNX graph that uses only ops available in opset 5. Because older opsets usually have fewer ops, some models might not convert on an older opset.\n", | |
"!python -m tf2onnx.convert --opset 10 --saved-model tf/saved --output tf/output/simplemodel.onnx\n" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"input name dense_2_input:0\n", | |
"input shape ['unk__6', 1]\n" | |
] | |
} | |
], | |
"source": [ | |
"path = 'tf/output/simplemodel.onnx'\n", | |
"sess = onnxruntime.InferenceSession(path)\n", | |
"# get name and shape\n", | |
"input_name = sess.get_inputs()[0].name\n", | |
"print(\"input name\", input_name)\n", | |
"input_shape = sess.get_inputs()[0].shape\n", | |
"print(\"input shape\", input_shape)\n", | |
"\n" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Input values to ONNX model = \n", | |
"[[10.]]\n", | |
"\n", | |
"Output value from ONNX model = \n", | |
"[array([[31.02069]], dtype=float32)]\n" | |
] | |
} | |
], | |
"source": [ | |
"# predict\n", | |
"x = np.array([[10.0]], dtype=np.float32)\n", | |
"print(\"\\nInput values to ONNX model = \")\n", | |
"print(x)\n", | |
"res = sess.run(None, {input_name: x})\n", | |
"print(\"\\nOutput value from ONNX model = \")\n", | |
"print(res)\n" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"outputs": [], | |
"source": [ | |
"# Check that ONNX Runtime reproduces the Keras prediction.\n", | |
"# sess.run() returns a list of output arrays, so compare its first element\n", | |
"# with the Keras result, and allow a small tolerance because two different\n", | |
"# runtimes are not guaranteed to produce bit-identical float32 values.\n", | |
"np.testing.assert_allclose(res[0], original, rtol=1e-5, atol=1e-6)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"outputs": [], | |
"source": [], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment