Last active
June 19, 2023 19:28
-
-
Save LukeAI/bbfc3ab749601ab0f2cb06e4b8fc75cb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import os

from super_gradients.training import models
from super_gradients.common.object_names import Models
import onnx
import torch
import torch.nn as nn
# CONFIG
NO_CLASSES = 80  # forwarded to ONNX_TRT as n_classes; also used when loading a custom checkpoint
batch_size = 1  # only used to label the ONNX output dimensions after export
topk_all = 100  # forwarded to ONNX_TRT as max_obj (max_output_boxes of the NMS op)
input_shape = (3, 640, 640)  # presumably (channels, height, width) - confirm against convert_to_onnx
iou_thres = 0.45  # IoU threshold handed to the NMS op
score_thres = 0.25  # score threshold handed to the NMS op
end2end = True  # True: append the TensorRT NMS op to the exported graph
onnx_path = "yolo_nas_s.onnx"

net = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
# To export a custom-trained checkpoint instead, replace the line above with:
# net = models.get(Models.YOLO_NAS_L, num_classes=NO_CLASSES,
#                  checkpoint_path="/home/luke/yoloNAS/checkpoints/yolo_nas_l_spss_export/ckpt_latest.pth")
class TRT_NMS(torch.autograd.Function):
    """TensorRT NMS operation.

    ``forward`` does not perform NMS: it returns random placeholder tensors
    with the output shapes and dtypes of the op so the model can be traced.
    ``symbolic`` is what the ONNX exporter uses; it emits a single
    ``TRT::EfficientNMS_TRT`` node with four outputs.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
        class_agnostic=1,
    ):
        """Return placeholder outputs shaped like the NMS results.

        Only ``scores`` (expected to be 3-D: batch, boxes, classes) and
        ``max_output_boxes`` influence the result; the remaining arguments
        exist so the signature mirrors ``symbolic``.

        Returns:
            (num_det, det_boxes, det_scores, det_classes) with shapes
            (B, 1) int32, (B, max_output_boxes, 4), (B, max_output_boxes)
            and (B, max_output_boxes) int32.
        """
        batch_size, num_boxes, num_classes = scores.shape
        # Allocate the placeholders next to the inputs so tracing a model that
        # lives on an accelerator does not mix devices in the outputs.
        device = scores.device
        num_det = torch.randint(
            0, max_output_boxes, (batch_size, 1), dtype=torch.int32, device=device
        )
        det_boxes = torch.randn(batch_size, max_output_boxes, 4, device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_classes = torch.randint(
            0,
            num_classes,
            (batch_size, max_output_boxes),
            dtype=torch.int32,
            device=device,
        )
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(
        g,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
        class_agnostic=1,
    ):
        """Emit the ``TRT::EfficientNMS_TRT`` node into the ONNX graph ``g``.

        Each keyword is attached as a node attribute; the ``_i`` / ``_f`` /
        ``_s`` suffixes select the attribute type (int / float / string).
        """
        out = g.op(
            "TRT::EfficientNMS_TRT",
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            class_agnostic_i=class_agnostic,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4,
        )
        nums, boxes, scores, classes = out
        return nums, boxes, scores, classes
class ONNX_TRT(nn.Module):
    """onnx module with TensorRT NMS operation.

    Post-processing module that feeds the model's ``(boxes, scores)`` pair
    into ``TRT_NMS`` using the thresholds configured at construction time.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
        """Store the NMS settings.

        Args:
            max_obj: passed to the NMS op as ``max_output_boxes``.
            iou_thres: passed to the NMS op as ``iou_threshold``.
            score_thres: passed to the NMS op as ``score_threshold``.
            max_wh: must be ``None``; any other value fails the assertion.
            device: stored on the module (CPU when omitted); not read by
                ``forward``.
            n_classes: stored on the module; not read by ``forward``.
        """
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # Plain ints: a trailing comma here would turn these into 1-tuples,
        # which the exporter would then write as sequence attributes.
        self.background_class = -1
        self.box_coding = 0
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres
        self.n_classes = n_classes

    def forward(self, x):
        """Run the NMS op on ``x``, a ``(boxes, scores)`` pair.

        Returns:
            The four outputs of ``TRT_NMS``:
            (num_det, det_boxes, det_scores, det_classes).
        """
        boxes, confscores = x
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(
            boxes,
            confscores,
            self.background_class,
            self.box_coding,
            self.iou_threshold,
            self.max_obj,
            self.plugin_version,
            self.score_activation,
            self.score_threshold,
        )
        return num_det, det_boxes, det_scores, det_classes
net.eval()
net.prep_model_for_conversion()

# https://github.com/Deci-AI/super-gradients/blob/master/documentation/source/BenchmarkingYoloNAS.md
if end2end:
    # Export with the NMS op appended; written to "<name>_nms.onnx".
    onnx_path = os.path.splitext(onnx_path)[0] + "_nms" + ".onnx"
    NMS = ONNX_TRT(
        max_obj=topk_all, iou_thres=iou_thres, score_thres=score_thres, max_wh=None, device=None, n_classes=NO_CLASSES
    )
    NMS.eval()
    onnx_export_kwargs = {
        'input_names': ['images'],
        'output_names': ["num_dets", "det_boxes", "det_scores", "det_classes"],
    }
    models.convert_to_onnx(model=net, input_shape=input_shape, post_process=NMS, out_path=onnx_path,
                           torch_onnx_export_kwargs=onnx_export_kwargs)
else:
    models.convert_to_onnx(model=net, input_shape=input_shape, out_path=onnx_path)

onnx_model = onnx.load(onnx_path)  # load onnx model
onnx.checker.check_model(onnx_model)  # check onnx model

if end2end:
    # set output dimensions
    # note: this makes no functional difference, just explicitly labels output dims
    # so can be understood better when onnx inspected with netron etc.
    # These labels describe the four NMS outputs only, so they are applied
    # only to the end2end export; the raw model outputs have other shapes.
    output_dims = [
        (batch_size, 1),            # num_dets
        (batch_size, topk_all, 4),  # det_boxes
        (batch_size, topk_all),     # det_scores
        (batch_size, topk_all),     # det_classes
    ]
    for graph_output, dims in zip(onnx_model.graph.output, output_dims):
        for dim, size in zip(graph_output.type.tensor_type.shape.dim, dims):
            dim.dim_param = str(size)
    onnx.save(onnx_model, onnx_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
No idea, I'm afraid. If you copy and paste the code above and run it as-is, without modification, does it still do that? Are your super-gradients, onnxsim, onnxruntime, etc. up to date?