import numpy as np import onnxruntime as ort import torch import torchvision from utils import yolo_insert_nms class YOLO11(torch.nn.Module): def __init__(self, name) -> None: super().__init__() from ultralytics import YOLO # Load a model # build a new model from scratch # model = YOLO(f'{name}.yaml') # load a pretrained model (recommended for training) model = YOLO("yolo11n.pt") self.model = model.model def forward(self, x): """https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py#L216""" pred: torch.Tensor = self.model(x)[0] # n 84 8400, pred = pred.permute(0, 2, 1) boxes, scores = pred.split([4, 80], dim=-1) boxes = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy") return boxes, scores def export_onnx(name="yolov8n"): """export onnx""" m = YOLO11(name) x = torch.rand(1, 3, 640, 640) dynamic_axes = {"image": {0: "-1"}} torch.onnx.export( m, x, f"{name}.onnx", input_names=["image"], output_names=["boxes", "scores"], opset_version=13, dynamic_axes=dynamic_axes, ) data = np.random.rand(1, 3, 640, 640).astype(np.float32) sess = ort.InferenceSession(f"{name}.onnx") _ = sess.run(output_names=None, input_feed={"image": data}) import onnx import onnxslim model_onnx = onnx.load(f"{name}.onnx") model_onnx = onnxslim.slim(model_onnx) onnx.save(model_onnx, f"{name}.onnx") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--name", type=str, default="yolo11n_tuned") parser.add_argument("--score_threshold", type=float, default=0.01) parser.add_argument("--iou_threshold", type=float, default=0.6) parser.add_argument("--max_output_boxes", type=int, default=300) args = parser.parse_args() export_onnx(name=args.name) yolo_insert_nms( path=f"{args.name}.onnx", score_threshold=args.score_threshold, iou_threshold=args.iou_threshold, max_output_boxes=args.max_output_boxes, )