yolov8 rk3588
from rknn.api import RKNN

ONNX_MODEL = 'yolov8m.onnx'
RKNN_MODEL = 'yolov8m.rknn'
DATASET = './dataset.txt'
QUANTIZE_ON = True

if __name__ == '__main__':
    # Create RKNN object
    rknn = RKNN(verbose=True)

    # pre-process config
    print('--> Config model')
    rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform='RK3588')
    print('done')

    # Load ONNX model
    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    # Build model
    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE_ON, dataset=DATASET)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # Export RKNN model
    print('--> Export rknn model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export rknn model failed!')
        exit(ret)
    print('done')
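A note on DATASET: when do_quantization is enabled, rknn.build reads this file to calibrate the int8 quantization. It is a plain-text list of paths to representative images, one per line, not a label file. A minimal sketch for generating it, assuming the calibration images sit in a placeholder ./calib_images directory:

# Sketch: build dataset.txt for RKNN quantization calibration.
# Each line is the path to one calibration image; './calib_images' is a placeholder.
import glob

with open('dataset.txt', 'w') as f:
    for path in sorted(glob.glob('./calib_images/*.jpg'))[:100]:
        f.write(path + '\n')

The inference script below (yolov8.py) then loads the exported .rknn model on the RK3588 NPU with RKNNLite.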
""" | |
Usage example | |
python yolov8.py --model ./yolov8m.rknn --img bus.jpg | |
""" | |
import cv2 | |
import numpy as np | |
from rknnlite.api import RKNNLite | |
import time | |
import argparse | |
RKNN_MODEL = 'yolov8m_RK3588_i8.rknn' | |
IMGSZ = (416, 416) | |
CLASSES = ("person", "bicycle", "car", "motorbike ", "aeroplane ", "bus ", "train", "truck ", "boat", "traffic light", | |
"fire hydrant", "stop sign ", "parking meter", "bench", "bird", "cat", "dog ", "horse ", "sheep", "cow", "elephant", | |
"bear", "zebra ", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", | |
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife ", | |
"spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza ", "donut", "cake", "chair", "sofa", | |
"pottedplant", "bed", "diningtable", "toilet ", "tvmonitor", "laptop ", "mouse ", "remote ", "keyboard ", "cell phone", "microwave ", | |
"oven ", "toaster", "sink", "refrigerator ", "book", "clock", "vase", "scissors ", "teddy bear ", "hair drier", "toothbrush ") | |
def preprocess(img_path): | |
img = cv2.imread(img_path) | |
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
img = cv2.resize(img, IMGSZ) | |
return img | |
def postprocess(output, confidence_thres=0.5, iou_thres=0.5): | |
outputs = np.transpose(np.squeeze(output[0])) | |
# Get the number of rows in the outputs array | |
rows = outputs.shape[0] | |
# Lists to store the bounding boxes, scores, and class IDs of the detections | |
boxes = [] | |
scores = [] | |
class_ids = [] | |
# Calculate the scaling factors for the bounding box coordinates | |
x_factor = 1 | |
y_factor = 1 | |
# Iterate over each row in the outputs array | |
for i in range(rows): | |
# Extract the class scores from the current row | |
classes_scores = outputs[i][4:] | |
# Find the maximum score among the class scores | |
max_score = np.amax(classes_scores) | |
# If the maximum score is above the confidence threshold | |
if max_score >= confidence_thres: | |
# Get the class ID with the highest score | |
class_id = np.argmax(classes_scores) | |
# Extract the bounding box coordinates from the current row | |
x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3] | |
# Calculate the scaled coordinates of the bounding box | |
x1 = int((x - w / 2) * x_factor) | |
y1 = int((y - h / 2) * y_factor) | |
x2 = x1 + int(w * x_factor) | |
y2 = y1 + int(h * y_factor) | |
# Add the class ID, score, and box coordinates to the respective lists | |
class_ids.append(class_id) | |
scores.append(max_score) | |
boxes.append([x1, y1, x2, y2]) | |
# Apply non-maximum suppression to filter out overlapping bounding boxes | |
indices = cv2.dnn.NMSBoxes(boxes, scores, confidence_thres, iou_thres) | |
detections = [] | |
# Iterate over the selected indices after non-maximum suppression | |
for i in indices: | |
detections.append([ | |
boxes[i], | |
scores[i], | |
class_ids[i] | |
]) | |
# Return the modified input image | |
return detections | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--img', type=str, default='bus.jpg') | |
parser.add_argument('--model', type=str, default='yolov8m_RK3588_i8.rknn') | |
opt = parser.parse_args() | |
args = vars(opt) | |
rknn_lite = RKNNLite() | |
ret = rknn_lite.load_rknn(args['model']) | |
if ret != 0: | |
print('Load RKNN model failed') | |
exit(ret) | |
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2) | |
if ret != 0: | |
print('Init runtime environment failed') | |
exit(ret) | |
start = time.time() | |
img_data = preprocess(args['img']) | |
outputs = rknn_lite.inference(inputs=[img_data]) | |
print(f"inference time: {(time.time() - start) * 1000} ms") | |
detections = postprocess(outputs[0]) | |
print(f"detection time: {(time.time() - start) * 1000} ms") | |
img_orig = cv2.imread(args['img']) | |
img_orig = cv2.resize(img_orig, IMGSZ) | |
for d in detections: | |
score, class_id = d[1], d[2] | |
x1, y1, x2, y2 = d[0][0], d[0][1], d[0][2], d[0][3] | |
cv2.rectangle(img_orig, (x1, y1), (x2, y2), 2) | |
label = f'{CLASSES[class_id]}: {score:.2f}' | |
label_height = 10 | |
label_x = x1 | |
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 | |
cv2.putText(img_orig, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) | |
cv2.imwrite('yolov8_result.jpg', img_orig) | |
print(f"{(time.time() - start) * 1000} ms") |
Traceback (most recent call last):
File "/home/orangepi/test_yolov8.py", line 115, in <module>
detections = postprocess(outputs[0])
File "/home/orangepi/test_yolov8.py", line 66, in postprocess
x1 = int((x - w / 2) * x_factor)
TypeError: only length-1 arrays can be converted to Python scalars
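A likely cause for this traceback (an assumption, not verified against the commenter's setup) is that the exported model's output layout differs from the (1, 84, N) shape that postprocess expects, so x, y, w and h come out as arrays instead of scalars. A quick check is to print the raw output shapes right after the rknn_lite.inference() call in the script above:

# Debugging sketch: inspect the raw RKNN outputs before postprocess() is called.
for i, out in enumerate(outputs):
    arr = np.array(out)
    print(f"output {i}: shape={arr.shape}, dtype={arr.dtype}")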
Hello, can you answer a question: what does DATASET = './dataset.txt' look like? I guess it's just labels, or am I mistaken?