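# Page header captured along with the file (not part of the program; kept as
# comments so the module still parses):
#   Eric P. Nusbaum
#   Benchmarking
#   cb58d33
#   raw
#   history blame
#   6.13 kB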
import os
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import time # Import time for benchmarking
# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one class name per line. A detection's class index is used
# directly as an index into this list, so line order matters.
# The file is decoded explicitly as UTF-8 so the result does not depend on the
# platform's locale encoding, and splitlines() is used so CRLF line endings do
# not leave a stray "\r" at the end of every label.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().splitlines()
class Model:
    """Wrapper around an ONNX Runtime session for a single-input image model.

    Attributes:
        session: The ``onnxruntime.InferenceSession`` for the model file.
        input_shape: Dimensions 2 and onward of the model input, read as
            (H, W) for an (N, C, H, W) input.
        input_name: Name of the model's only input.
        input_type: numpy dtype fed to the model (float16 or float32;
            float32 when the declared type is anything else).
        output_names: Names of all model outputs, in session order.
        is_bgr: True when the model metadata declares ``Bgr8`` pixel format.
        is_range255: True when the model metadata declares a 0-255 pixel
            range, in which case pixels are NOT rescaled to [0, 1].
    """

    def __init__(self, model_filepath):
        """Create the inference session and read preprocessing hints.

        Args:
            model_filepath: Path to the ``.onnx`` model file.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)
        model_inputs = self.session.get_inputs()
        assert len(model_inputs) == 1, "expected a model with exactly one input"
        model_input = model_inputs[0]

        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Preprocessing flags default to RGB / [0, 1] unless the model's
        # metadata properties say otherwise.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run the model on *image* and time the inference call.

        Args:
            image: Input PIL image; converted to RGB if it is in another mode.

        Returns:
            A ``(outputs, execution_time)`` tuple where ``outputs`` maps each
            output name to its array and ``execution_time`` is the duration of
            ``session.run`` in milliseconds.
        """
        # The tensor layout below needs exactly three colour channels;
        # RGBA, palette and grayscale images would otherwise produce the
        # wrong channel count (or fail the 4-D indexing outright).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # input_shape is (H, W) but PIL's resize() takes (width, height).
        height, width = self.input_shape
        image_resized = image.resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # Cast before starting the clock so only the inference call is timed.
        feed = {self.input_name: input_array.astype(self.input_type)}

        # perf_counter() is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        outputs = self.session.run(self.output_names, feed)
        execution_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        print(f"Inference time: {execution_time:.2f} ms")

        return dict(zip(self.output_names, outputs)), execution_time
def _fit_font(draw, text, target_height, font_path):
    """Return the smallest font (from 10 px up) whose *text* is at least *target_height* tall.

    The search stops once the size exceeds 200 so it always terminates. If the
    font file at *font_path* cannot be loaded, Pillow's built-in default font
    is returned instead of raising.
    """
    font_size = 10  # Start with a small font size
    while True:
        try:
            font = ImageFont.truetype(font_path, size=font_size)
        except OSError:
            # Font file missing or unreadable (e.g. not a Linux host with
            # DejaVu installed): degrade to the default font.
            return ImageFont.load_default()
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_pixel_height = text_bbox[3] - text_bbox[1]
        if text_pixel_height >= target_height or font_size > 200:
            return font
        font_size += 1


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw a labelled bounding box on *image* for each confident detection.

    Detections whose score is below ``PROB_THRESHOLD`` are skipped. The image
    is modified in place and also returned.

    Args:
        image: PIL image to annotate.
        outputs: Mapping that may contain ``'detected_boxes'``,
            ``'detected_classes'`` and ``'detected_scores'``; element ``[0]``
            of each is iterated in lockstep.

    Returns:
        The same *image* object, annotated. It is returned unchanged when any
        of the three outputs is missing or empty.
    """
    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])
    # len() rather than truthiness: these values may be numpy arrays, whose
    # truth value is ambiguous. Indexing [0] on a missing/empty output would
    # raise IndexError.
    if len(boxes) == 0 or len(classes) == 0 or len(scores) == 0:
        return image

    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency
    image_width, image_height = image.size
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"  # Common path on Linux

    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
        ymin, xmin, ymax, xmax = box
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # Label strip is as wide as the box and 5% of the box height tall
        text_width = right - left
        text_height = (bottom - top) // 20

        # Place the strip just above the box, clamped to the top of the image
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160)  # Semi-transparent red
        )

        # Scale the font so the text is at least as tall as the label strip
        font = _fit_font(draw, text, text_height, font_path)

        # Draw text with the scaled font
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=font
        )
    return image
# Initialize model
# Created once at module import so every detect_objects() call reuses the
# same Model instance rather than rebuilding it per request.
model = Model(MODEL_PATH)
def detect_objects(image):
    """Run detection on *image* and build the outputs shown in the UI.

    Args:
        image: PIL image to analyse, or ``None`` when no image was supplied.

    Returns:
        A ``(annotated_image, detection_summary)`` tuple. ``annotated_image``
        is a copy of the input with boxes drawn on it, and
        ``detection_summary`` lists one ``"label: score"`` line per detection
        at or above ``PROB_THRESHOLD`` followed by the inference time. When
        *image* is ``None`` the result is ``(None, "No image provided.")``.
    """
    # Guard against a submission without an image, which would otherwise
    # fail with AttributeError inside predict().
    if image is None:
        return None, "No image provided."

    outputs, execution_time = model.predict(image)
    # Annotate a copy so the caller's image is left untouched.
    annotated_image = draw_boxes(image.copy(), outputs)

    # Prepare detection summary
    detections = []
    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])
    # len() rather than truthiness: these values may be numpy arrays. Skip the
    # loop when an output is missing/empty instead of raising IndexError on [0].
    if len(boxes) > 0 and len(classes) > 0 and len(scores) > 0:
        for _box, cls, score in zip(boxes[0], classes[0], scores[0]):
            if score < PROB_THRESHOLD:
                continue
            label = LABELS[int(cls)]
            detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    detection_summary += f"\n\nInference Time: {execution_time:.2f} ms"
    return annotated_image, detection_summary
# Enhanced Gradio Interface with Links to Model Card and Repository
# The title and description previously contained mis-decoded bytes in place
# of the baseball and link emoji; the intended characters are restored here.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description=(
        "Upload an image to identify the set of the baseball card (1980-1999).\n\n"
        "[🔗 Model Card & Repository](https://huggingface.co/enusbaum/JunkWaxHero)"
    ),
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never",  # Disable flagging if not needed
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()