Created
February 2, 2021 16:57
-
-
Save cbalint13/68953f677c995a2b7933c1f0ffef474f to your computer and use it in GitHub Desktop.
tvm-micro-pr7392
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import onnx | |
import numpy as np | |
import tvm | |
from tvm import te | |
import tvm.relay as relay | |
### | |
### | |
### | |
quantize = True | |
### | |
### | |
### | |
def get_model(onnx_file):
    """Load an ONNX file and import it as a TVM Relay module.

    The name and the per-dimension ``dim_value`` entries of the first graph
    input are read from the ONNX model and handed to the Relay ONNX frontend
    as its shape dictionary.

    Args:
        onnx_file: Path of the ONNX model, passed to ``onnx.load``.

    Returns:
        The ``(mod, params)`` pair produced by ``relay.frontend.from_onnx``.
    """
    model_proto = onnx.load(onnx_file)

    # Only the first graph input is described to the importer.
    first_input = model_proto.graph.input[0]
    input_dims = first_input.type.tensor_type.shape.dim
    shape_dict = {first_input.name: [dim.dim_value for dim in input_dims]}

    relay_mod, relay_params = relay.frontend.from_onnx(model_proto, shape_dict)
    return relay_mod, relay_params
def quantize(mod, params, data_aware, dataset=None):
    """Quantize a Relay module with ``relay.quantize``.

    Args:
        mod: Relay module to quantize.
        params: Parameters forwarded to ``relay.quantize.quantize``.
        data_aware: When true, use the ``kl_divergence`` calibration mode with
            ``max`` weight scaling and the samples in ``dataset``. Otherwise
            use the ``global_scale`` mode (scale 8.0) with ``power2`` weight
            scaling and 8-bit inputs/weights, which needs no calibration data.
        dataset: Calibration data forwarded as the ``dataset`` argument of
            ``relay.quantize.quantize``. Required when ``data_aware`` is true,
            ignored otherwise.

    Returns:
        The module returned by ``relay.quantize.quantize``.

    Raises:
        ValueError: If ``data_aware`` is true and no ``dataset`` was supplied.
    """
    if data_aware:
        # The calibration data must come from the caller; fail early with a
        # clear message instead of a NameError on an undefined helper.
        if dataset is None:
            raise ValueError(
                "data_aware quantization requires a calibration dataset; "
                "pass it via the 'dataset' argument"
            )
        with relay.quantize.qconfig(
            skip_conv_layers=[],
            calibrate_mode='kl_divergence',
            weight_scale='max',
        ):
            return relay.quantize.quantize(mod, params, dataset=dataset)

    with relay.quantize.qconfig(
        nbit_input=8,
        nbit_weight=8,
        skip_conv_layers=[],
        skip_dense_layer=False,
        calibrate_mode='global_scale',
        weight_scale='power2',
        global_scale=8.0,
        round_for_shift=False,
        do_simulation=False,
    ):
        return relay.quantize.quantize(mod, params)
def main(*, do_quantize=True):
    """Import the sine ONNX model, optionally quantize it, and build it for microTVM.

    Steps, in order:
      1. Load ``sine_model.onnx`` through ``get_model`` and print its Relay text.
      2. If ``do_quantize`` is true, quantize it with ``quantize(...,
         data_aware=False)`` and print the result.
      3. Build for the ``nrf5340dk`` micro target, write the generated C
         source to ``dump.c`` and export ``dump.o``.
      4. Build a static runtime binary in a ``tvm.micro.Workspace``.

    Args:
        do_quantize: Whether to quantize the model before building. This is a
            parameter because the module-level name ``quantize`` is bound to
            the ``quantize()`` function by the time ``main`` runs, so a check
            such as ``quantize == True`` can never succeed.
    """
    import os

    # The name itself is unused; presumably imported so that the `tvm.micro`
    # attribute accessed below is loaded.
    from tvm import micro  # noqa: F401

    board = "nrf5340dk"

    # Model file:
    # https://drive.google.com/file/d/1hTWILOzKmseA16wWu50cPebzV53M_vD-/view?usp=sharing
    mod, params = get_model('sine_model.onnx')
    print(mod.astext(show_meta_data=False))

    if do_quantize:
        mod = quantize(mod, params, data_aware=False)
        # Show the module again so the effect of quantization is visible.
        print(mod.astext(show_meta_data=False))

    target = tvm.target.target.micro(board)
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        build = relay.build(mod, target=target, params=params)

    # Dump the generated C source (build.lib.imported_modules[0].get_source()
    # returns the same text if you would rather print it).
    build.lib.imported_modules[0].save("dump.c")

    # Compile a local object with the network, e.g. for use in a shared library.
    build.export_library('dump.o')

    # Continue with the micro target: compile a standalone binary out of the
    # generated C code.
    compiler = tvm.micro.DefaultCompiler(target=target)
    opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host"))

    workspace = tvm.micro.Workspace()
    # Kept in a variable as the starting point for the flashing step below.
    micro_binary = tvm.micro.build_static_runtime(
        workspace,
        compiler,
        build,
        lib_opts=opts["lib_opts"],
        bin_opts=opts["bin_opts"],
        extra_libs=[os.path.join(tvm.micro.build.CRT_ROOT_DIR, "memory")],
    )

    # From here you can continue with the `flasher = compiler.flasher()` part
    # of the documentation (not done here; the author had no such board):
    # https://tvm.apache.org/docs/tutorials/micro/micro_tflite.html#defining-the-target
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment