import contextlib
import time
from collections import OrderedDict

import numpy as np
import onnx
import onnx_graphsurgeon
import torch
from PIL import Image


def to_binary_data(path, size=(640, 640), output_name="input_tensor.bin"):
    """Dump a preprocessed image for trtexec: --loadInputs='image:input_tensor.bin'"""
    # Resize, convert HWC -> NCHW float32 in [0, 1], then write raw bytes.
    im = Image.open(path).resize(size)
    data = np.asarray(im, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0
    data.tofile(output_name)


def yolo_insert_nms(
    path, score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300, simplify=False
):
    """Append a TensorRT EfficientNMS_TRT node to a YOLO ONNX graph.

    References:
    http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxops/onnx__EfficientNMS_TRT.html
    https://huggingface.co/spaces/muttalib1326/Punjabi_Character_Detection/blob/3dd1e17054c64e5f6b2254278f96cfa2bf418cd4/utils/add_nms.py
    """
    onnx_model = onnx.load(path)

    if simplify:
        # Alias the import so it does not shadow the `simplify` argument.
        from onnxsim import simplify as simplify_model

        onnx_model, _ = simplify_model(
            onnx_model, overwrite_input_shapes={"image": [1, 3, 640, 640]}
        )

    graph = onnx_graphsurgeon.import_onnx(onnx_model)
    graph.toposort()
    graph.fold_constants()
    graph.cleanup()

    topk = max_output_boxes
    attrs = OrderedDict(
        plugin_version="1",
        background_class=-1,
        max_output_boxes=topk,
        score_threshold=score_threshold,
        iou_threshold=iou_threshold,
        score_activation=False,
        box_coding=0,
    )

    outputs = [
        onnx_graphsurgeon.Variable("num_dets", np.int32, [-1, 1]),
        onnx_graphsurgeon.Variable("det_boxes", np.float32, [-1, topk, 4]),
        onnx_graphsurgeon.Variable("det_scores", np.float32, [-1, topk]),
        onnx_graphsurgeon.Variable("det_classes", np.int32, [-1, topk]),
    ]

    # The first two graph outputs are expected to be the raw boxes and scores.
    graph.layer(
        op="EfficientNMS_TRT",
        name="batched_nms",
        inputs=[graph.outputs[0], graph.outputs[1]],
        outputs=outputs,
        attrs=attrs,
    )

    graph.outputs = outputs
    graph.cleanup().toposort()

    onnx.save(onnx_graphsurgeon.export_onnx(graph), "yolo_w_nms.onnx")


class TimeProfiler(contextlib.ContextDecorator):
    """Accumulating timer; synchronizes CUDA (if available) before reading the clock."""

    def __init__(self):
        self.total = 0

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.total += self.time() - self.start

    def reset(self):
        self.total = 0

    def time(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()
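

# --- Usage sketch (not part of the original utilities) ----------------------
# A minimal example of how these helpers might be combined: fuse the
# EfficientNMS_TRT plugin into an existing YOLO export, dump a test input for
# trtexec, and accumulate wall-clock time with TimeProfiler. The file paths
# ("yolo.onnx", "dog.jpg") and the dummy workload are assumptions for
# illustration only, not part of the source.
if __name__ == "__main__":
    # Write "yolo_w_nms.onnx" with NMS fused into the graph.
    yolo_insert_nms("yolo.onnx", score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300)

    # Serialize a preprocessed image so trtexec can consume it via
    # --loadInputs='image:input_tensor.bin'.
    to_binary_data("dog.jpg", size=(640, 640), output_name="input_tensor.bin")

    # Time a dummy tensor workload; each `with` block adds to profiler.total.
    profiler = TimeProfiler()
    x = torch.randn(1, 3, 640, 640)
    for _ in range(10):
        with profiler:
            _ = (x * 2).sum()
    print(f"total elapsed: {profiler.total:.4f} s")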