import contextlib
import time
from collections import OrderedDict

import numpy as np
import onnx
import onnx_graphsurgeon
import torch
from PIL import Image


def to_binary_data(path, size=(640, 640), output_name="input_tensor.bin"):
    """Dump an image as a float32 NCHW binary for trtexec: --loadInputs='image:input_tensor.bin'"""
    # Convert to RGB so RGBA/grayscale inputs still produce 3 channels.
    im = Image.open(path).convert("RGB").resize(size)
    # HWC uint8 -> NCHW float32 in [0, 1]
    data = np.asarray(im, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0
    data.tofile(output_name)
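
# Usage sketch (hedged): "dog.jpg" is a placeholder path, not part of this script.
# The binary written above matches the docstring's trtexec flag, for an engine whose
# input tensor is named "image":
#   to_binary_data("dog.jpg")  # -> input_tensor.bin
#   trtexec --loadEngine=yolo_w_nms.engine --loadInputs='image:input_tensor.bin'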


def yolo_insert_nms(
    path, score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300, simplify=False
):
    """Append a TensorRT EfficientNMS_TRT node to an exported YOLO ONNX graph.

    References:
    http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxops/onnx__EfficientNMS_TRT.html
    https://huggingface.co/spaces/muttalib1326/Punjabi_Character_Detection/blob/3dd1e17054c64e5f6b2254278f96cfa2bf418cd4/utils/add_nms.py
    """
    onnx_model = onnx.load(path)

    if simplify:
        # Import under an alias so the onnxsim function does not shadow the `simplify` flag.
        from onnxsim import simplify as onnxsim_simplify

        onnx_model, _ = onnxsim_simplify(
            onnx_model, overwrite_input_shapes={"image": [1, 3, 640, 640]}
        )

    graph = onnx_graphsurgeon.import_onnx(onnx_model)
    graph.toposort()
    graph.fold_constants()
    graph.cleanup()

    topk = max_output_boxes
    attrs = OrderedDict(
        plugin_version="1",
        background_class=-1,  # -1: no background class to discard
        max_output_boxes=topk,
        score_threshold=score_threshold,
        iou_threshold=iou_threshold,
        score_activation=False,  # scores are used as-is, no sigmoid inside the plugin
        box_coding=0,  # 0: corner format (x1, y1, x2, y2)
    )

    outputs = [
        onnx_graphsurgeon.Variable("num_dets", np.int32, [-1, 1]),
        onnx_graphsurgeon.Variable("det_boxes", np.float32, [-1, topk, 4]),
        onnx_graphsurgeon.Variable("det_scores", np.float32, [-1, topk]),
        onnx_graphsurgeon.Variable("det_classes", np.int32, [-1, topk]),
    ]

    # EfficientNMS_TRT consumes the model's two outputs: boxes first, per-class scores second.
    graph.layer(
        op="EfficientNMS_TRT",
        name="batched_nms",
        inputs=[graph.outputs[0], graph.outputs[1]],
        outputs=outputs,
        attrs=attrs,
    )

    graph.outputs = outputs
    graph.cleanup().toposort()

    onnx.save(onnx_graphsurgeon.export_onnx(graph), "yolo_w_nms.onnx")
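
# Usage sketch (hedged): "yolo.onnx" is a placeholder for an exported detector whose first
# graph output is boxes and whose second output is per-class scores, the layout
# EfficientNMS_TRT expects. Building the result requires a TensorRT with that plugin, e.g.:
#   yolo_insert_nms("yolo.onnx", score_threshold=0.25, iou_threshold=0.7)
#   trtexec --onnx=yolo_w_nms.onnx --saveEngine=yolo_w_nms.engine --fp16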


class TimeProfiler(contextlib.ContextDecorator):
    """Accumulates wall-clock time spent inside `with` blocks, synchronizing CUDA first."""

    def __init__(self):
        self.total = 0

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.total += self.time() - self.start

    def reset(self):
        self.total = 0

    def time(self):
        # Synchronize so queued CUDA kernels are included in the measurement.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()
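

if __name__ == "__main__":
    # Minimal timing sketch (an added example, not part of the original utilities):
    # time a dummy matmul with TimeProfiler. Because time() synchronizes CUDA,
    # kernels launched inside the block are fully counted.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)
    profiler = TimeProfiler()
    for _ in range(10):
        with profiler:
            _ = x @ x
    print(f"10 matmuls on {device}: {profiler.total * 1000:.2f} ms")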