Created
July 8, 2024 17:44
-
-
Save DarthSim/216551dfd58e5628290e90c1d358704b to your computer and use it in GitHub Desktop.
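This patch allows exporting YOLOv10 to an ONNX file that is compatible with OpenCV. It does two things:
- It adds an `export_opencv.py` helper script.
- It changes `v10Detect.forward` in export mode to return the one-to-one head output directly, instead of passing it through `ops.v10postprocess`. That step produced boxes, scores and class labels limited to `max_det`. After this change, that post-processing must be done by the application consuming the model.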
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/export_opencv.py b/export_opencv.py | |
new file mode 100644 | |
index 00000000..15bfef90 | |
--- /dev/null | |
+++ b/export_opencv.py | |
@@ -0,0 +1,23 @@ | |
+from ultralytics import YOLOv10 | |
+import argparse | |
+ | |
+if __name__ == "__main__": | |
+ | |
+ parser = argparse.ArgumentParser() | |
+ parser.add_argument("--weights", type=str, | |
+ default="yolov10n.pt", | |
+ help="model.pt path") | |
+ parser.add_argument("--imgsz", type=int, nargs=2, | |
+ default=(640, 640), | |
+ help="Image size for the model") | |
+ parser.add_argument("--half", | |
+ action="store_true", | |
+ help="FP16 half-precision export") | |
+ args = parser.parse_args() | |
+ | |
+ model = YOLOv10(args.weights) | |
+ | |
+ model.export(format='onnx', | |
+ imgsz=args.imgsz, | |
+ simplify=True, | |
+ half=args.half) | |
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py | |
index a9c5d9ee..544ab8b7 100644 | |
--- a/ultralytics/nn/modules/head.py | |
+++ b/ultralytics/nn/modules/head.py | |
@@ -79,7 +79,7 @@ class Detect(nn.Module): | |
def forward(self, x): | |
"""Concatenates and returns predicted bounding boxes and class probabilities.""" | |
y = self.forward_feat(x, self.cv2, self.cv3) | |
- | |
+ | |
if self.training: | |
return y | |
@@ -507,7 +507,7 @@ class v10Detect(Detect): | |
self.one2one_cv2 = copy.deepcopy(self.cv2) | |
self.one2one_cv3 = copy.deepcopy(self.cv3) | |
- | |
+ | |
def forward(self, x): | |
one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3) | |
if not self.export: | |
@@ -519,8 +519,7 @@ class v10Detect(Detect): | |
return {"one2many": one2many, "one2one": one2one} | |
else: | |
assert(self.max_det != -1) | |
- boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc) | |
- return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) | |
+ return one2one | |
else: | |
return {"one2many": one2many, "one2one": one2one} | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.