Last active
May 5, 2025 12:09
Export YOLO-NAS model to ONNX from super-gradients, used for C++ inference engine in MRo47/yolo_nas_cpp
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNFJp1W7CsI5i7TqYMxCROc"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Export YOLO-NAS ONNX model and metadata for inference\n","\n","Follow and execute the notebook instructions below to fetch and export yolo_nas models, change variables where mentioned to export a different model.\n","\n","> __NOTE__: GPU session is not required to export the model.\n","\n","## 1. Fetch and install super-gradients"],"metadata":{"id":"11H6LoJoLqZD"}},{"cell_type":"code","source":["!git clone https://github.com/Deci-AI/super-gradients.git"],"metadata":{"id":"Rn5XtEBd_178"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We switch to PR #2061 to get the fixes for model URLs.\n","\n","> __IMPORTANT__: After this installation step the session will ask for restart. Click on restart and **follow from step 2**. No need to run any installation steps below."],"metadata":{"id":"u_gH3lLvMQ2V"}},{"cell_type":"code","source":["%cd /content/super-gradients/\n","!git fetch origin pull/2061/head:url_fixes\n","!git switch url_fixes\n","!pip3 install -r requirements.txt && python3 -m pip install -e ."],"metadata":{"id":"0qlAcPOWC6ZA"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 2. 
Import model and fetch model weights"],"metadata":{"id":"8LervBWANtJG"}},{"cell_type":"code","source":["from super_gradients.training import models"],"metadata":{"id":"V9rSMMM8B9bB"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Change **model_type** to the required model type *(yolo_nas_s, yolo_nas_m, yolo_nas_l)*"],"metadata":{"id":"mY34hau0N9Ae"}},{"cell_type":"code","source":["model_type = \"yolo_nas_s\"\n","\n","model = models.get(model_type, pretrained_weights=\"coco\")"],"metadata":{"id":"-AFkqLi6q4Hi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 3. Export model to ONNX\n","\n","Here the ONNX opset version can be set, currently set to version 11. This is to match the opset used by ONNX runtime compiled with OpenCV 4.6.0.\n","> __NOTE__: This is not the same as ONNX runtime version\n","\n","After the step completes the model will be available in the current directory as `yolo_nas_<version>.onnx`."],"metadata":{"id":"ZHCj7cYdOaeP"}},{"cell_type":"code","source":["model = models.get(model_type, pretrained_weights=\"coco\")\n","model.eval()\n","model.prep_model_for_conversion(input_size=(1, 3, 640, 640))\n","# models.convert_to_onnx(model=model, prep_model_for_conversion_kwargs={\"input_size\":(1, 3, 640, 640)}, out_path=\"yolo_nas_s.onnx\",\n","# torch_onnx_export_kwargs={\"opset_version\":14})\n","model.export(f\"{model_type}.onnx\", postprocessing=None, preprocessing=None,\n"," onnx_export_kwargs={\"opset_version\":11})"],"metadata":{"id":"ejfT_oXlDyJs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 4. 
Export metadata\n","\n","The pre and post processing steps are serialised here into json, this will be used in the inference engine to create the pre and post processing pipelines.\n","\n","The metadata will be available in current directory as `yolo_nas_<version>-metadata.json`."],"metadata":{"id":"0bbjPBCePlYE"}},{"cell_type":"code","source":["import super_gradients.training.processing as processing\n","\n","def get_preprocessing_steps(preprocessing):\n"," if isinstance(preprocessing, processing.StandardizeImage):\n"," return {\"StandardizeImage\": {\"max_value\": preprocessing.max_value}}\n"," elif isinstance(preprocessing, processing.DetectionRescale):\n"," return {\n"," \"DetectionRescale\": {\n"," \"output_shape\": preprocessing.output_shape\n"," }\n"," }\n"," elif isinstance(preprocessing, processing.DetectionLongestMaxSizeRescale):\n"," return {\n"," \"DetectionLongestMaxSizeRescale\": {\n"," \"output_shape\": preprocessing.output_shape\n"," }\n"," }\n"," elif isinstance(preprocessing, processing.DetectionBottomRightPadding):\n"," return {\n"," \"DetectionBottomRightPadding\": {\n"," \"pad_value\": preprocessing.pad_value,\n"," \"output_shape\": preprocessing.output_shape\n"," }\n"," }\n"," elif isinstance(preprocessing, processing.DetectionCenterPadding):\n"," return {\n"," \"DetectionCenterPadding\": {\n"," \"pad_value\": preprocessing.pad_value,\n"," \"output_shape\": preprocessing.output_shape\n"," }\n"," }\n"," elif isinstance(preprocessing, processing.NormalizeImage):\n"," return {\n"," \"NormalizeImage\": {\"mean\": preprocessing.mean.tolist(), \"std\": preprocessing.std.tolist()}\n"," }\n"," elif isinstance(preprocessing, processing.ImagePermute):\n"," return {\n"," \"ImagePermute\": {\n"," \"order\": preprocessing.permutation\n"," }\n"," }\n"," elif isinstance(preprocessing, processing.ReverseImageChannels):\n"," return {\n"," \"ReverseImageChannels\": True\n"," }\n"," else:\n"," raise NotImplemented(\"Model have processing steps that haven't been 
implemented\")\n","\n","def get_postprocessing_steps(postprocessing):\n"," if isinstance(postprocessing, processing.DetectionLongestMaxSizeRescale):\n"," return {\n"," \"DetectionLongestMaxSizeRescale\": {\n"," \"output_shape\": postprocessing.output_shape\n"," }\n"," }\n","\n","# serialise preprocessing\n","preprocessing_steps = [\n"," get_preprocessing_steps(st) for st in model._image_processor.processings\n","]\n","\n","import numpy as np\n","dummy = np.random.randint(0, 255, (1000, 800, 3), dtype=np.uint8)\n","\n","input_shape = np.expand_dims(model._image_processor.preprocess_image(dummy)[0], 0).shape\n","\n","postprocessing_steps = {\n"," \"NMS\": {\n"," \"iou\": model._default_nms_iou,\n"," \"conf\": model._default_nms_conf\n"," }\n","}\n","\n","# get coco labels\n","labels = model.get_class_names()\n","\n","# create metadata object\n","metadata = {\n"," \"type\": model_type,\n"," \"input_shape\": input_shape,\n"," \"post_processing\": postprocessing_steps,\n"," \"pre_processing\": preprocessing_steps,\n"," \"labels\": labels,\n"," }\n","\n","# export metadata\n","import json\n","\n","filename = f\"{model_type}-metadata.json\"\n","with open(filename, \"w\") as f:\n"," f.write(json.dumps(metadata))\n","\n","metadata"],"metadata":{"id":"C-I0AlalX9xK"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 5. Download model and metadata\n","- `/content/yolo_nas_<version>.onnx`\n","- `/content/yolo_nas_<version>-metadata.json`\n","\n","Click the folder icon on the left > right click on the file > click download."],"metadata":{"id":"uzNTroE7R_uG"}}]} |