Spaces:
No application file
No application file
import numpy as np | |
import onnxruntime as ort | |
import torch | |
import torchvision | |
from utils import yolo_insert_nms | |
class YOLO11(torch.nn.Module):
    """Wrap an Ultralytics YOLO model so that ``forward`` returns ``(boxes, scores)``.

    The wrapped network's first output is transposed so anchors come before
    channels. It is then split into 4 box coordinates and the remaining
    per-class scores. Boxes are converted from centre format (cx, cy, w, h)
    to corner format (x1, y1, x2, y2).

    Args:
        name: Model name supplied by the caller. NOTE(review): it is not used
            to select the weights (see ``weights``); it is kept so existing
            call sites such as ``YOLO11(name)`` keep working.
        weights: Checkpoint passed to ``ultralytics.YOLO``. Defaults to
            ``"yolo11n.pt"``, the checkpoint this class has always loaded.
    """

    def __init__(self, name, weights: str = "yolo11n.pt") -> None:
        super().__init__()
        # Local import: ultralytics is only required when a model is built,
        # not when this module is imported.
        from ultralytics import YOLO

        # Keep only the underlying torch module; the YOLO wrapper object is
        # not needed for tracing/export.
        self.model = YOLO(weights).model

    def forward(self, x):
        """Run the detector and return ``(boxes, scores)``.

        Reference:
        https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py#L216

        Args:
            x: Image batch; presumably ``(n, 3, H, W)`` float — TODO confirm.

        Returns:
            boxes: ``(n, num_anchors, 4)`` in xyxy format.
            scores: ``(n, num_anchors, C - 4)``, where ``C`` is the channel
                count of the raw prediction (``num_classes`` channels).
        """
        # Presumably (n, 4 + num_classes, num_anchors), e.g. (n, 84, 8400)
        # for an 80-class model at 640x640 — TODO confirm for other models.
        pred: torch.Tensor = self.model(x)[0]
        # -> (n, num_anchors, 4 + num_classes)
        pred = pred.permute(0, 2, 1)
        # Slice rather than split([4, 80]) so any class count is supported;
        # for 84 channels the result is identical.
        boxes, scores = pred[..., :4], pred[..., 4:]
        boxes = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
        return boxes, scores
def export_onnx(name="yolov8n"):
    """Export ``YOLO11(name)`` to ``{name}.onnx``, smoke-test it, and slim it.

    The exported graph has one input, ``image``, and two outputs, ``boxes``
    and ``scores``. All three have a dynamic batch axis named ``"-1"``.
    After export, the file is loaded with ONNX Runtime and run once on random
    data as a sanity check. It is then simplified with onnxslim and written
    back to the same path.

    Args:
        name: Base name of the output file (``{name}.onnx``); also forwarded
            to ``YOLO11``.
    """
    path = f"{name}.onnx"
    m = YOLO11(name)
    x = torch.rand(1, 3, 640, 640)

    # Declare the batch axis dynamic on the outputs too, so it is explicit in
    # the graph rather than left to shape inference. The same axis name as
    # the input is kept so they share one symbolic dimension.
    dynamic_axes = {
        "image": {0: "-1"},
        "boxes": {0: "-1"},
        "scores": {0: "-1"},
    }
    torch.onnx.export(
        m,
        x,
        path,
        input_names=["image"],
        output_names=["boxes", "scores"],
        opset_version=13,
        dynamic_axes=dynamic_axes,
    )

    # Smoke test: the exported file must load and run in ONNX Runtime.
    # The provider is pinned because ORT builds that bundle extra execution
    # providers (e.g. GPU) raise if `providers` is omitted.
    data = np.random.rand(1, 3, 640, 640).astype(np.float32)
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    sess.run(output_names=None, input_feed={"image": data})

    # Optional dependencies, imported only when actually slimming.
    import onnx
    import onnxslim

    model_onnx = onnxslim.slim(onnx.load(path))
    onnx.save(model_onnx, path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Export a YOLO11 model to <name>.onnx, then run yolo_insert_nms on it."
    )
    parser.add_argument(
        "--name",
        type=str,
        default="yolo11n_tuned",
        help="base name of the exported file (<name>.onnx)",
    )
    parser.add_argument(
        "--score_threshold",
        type=float,
        default=0.01,
        help="score threshold passed to yolo_insert_nms",
    )
    parser.add_argument(
        "--iou_threshold",
        type=float,
        default=0.6,
        help="IoU threshold passed to yolo_insert_nms, in [0, 1]",
    )
    parser.add_argument(
        "--max_output_boxes",
        type=int,
        default=300,
        help="maximum number of output boxes passed to yolo_insert_nms (> 0)",
    )
    args = parser.parse_args()

    # Validate before the (slow) export so bad arguments fail immediately.
    if not 0.0 <= args.iou_threshold <= 1.0:
        parser.error("--iou_threshold must be in [0, 1]")
    if args.max_output_boxes <= 0:
        parser.error("--max_output_boxes must be a positive integer")

    export_onnx(name=args.name)
    yolo_insert_nms(
        path=f"{args.name}.onnx",
        score_threshold=args.score_threshold,
        iou_threshold=args.iou_threshold,
        max_output_boxes=args.max_output_boxes,
    )