Eric P. Nusbaum
UI Fixes
9bad826
raw
history blame
5.61 kB
import os
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one class name per line; a detection's class index selects its line.
# The encoding is explicit so non-ASCII label names decode identically on every
# platform instead of depending on the locale's default encoding.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().split("\n")
class Model:
    """Wrap an ONNX Runtime session together with its image preprocessing.

    Preprocessing flags are read from the model's metadata properties:
    ``Image.BitmapPixelFormat == 'Bgr8'`` turns on RGB -> BGR channel
    swapping, and ``Image.NominalPixelRange == 'NominalRange_0_255'`` turns
    off the division by 255.
    """

    def __init__(self, model_filepath):
        """Create the inference session and read the preprocessing metadata.

        Args:
            model_filepath: Path of the ONNX model file.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)
        model_inputs = self.session.get_inputs()
        assert len(model_inputs) == 1, "expected a model with exactly one input"
        model_input = model_inputs[0]

        # The input tensor is fed as (N, C, H, W), so the trailing dims are (H, W).
        self.input_shape = model_input.shape[2:]
        self.input_name = model_input.name
        # Unknown tensor types fall back to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: "Image.Image"):
        """Run the model on a PIL image and return its raw outputs.

        Args:
            image: Input picture. It is converted to RGB and resized to the
                model's input resolution before inference.

        Returns:
            Dict mapping each output name to the matching value returned by
            the inference session.
        """
        # PIL's resize() takes (width, height) whereas input_shape is (H, W);
        # passing it unchanged would transpose non-square model inputs.
        height, width = self.input_shape
        # Force three channels so grayscale / RGBA / palette images still
        # produce an (H, W, 3) array for the indexing below.
        image_resized = image.convert("RGB").resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # Run inference
        outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
        return dict(zip(self.output_names, outputs))
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"  # Common path on Linux
_MIN_FONT_SIZE = 10  # Font size the search starts from
_MAX_FONT_SIZE = 200  # Cap font size to prevent infinite loops


def _fit_label_font(draw, text, target_height):
    """Return a font whose rendering of ``text`` is at least ``target_height`` tall.

    The size grows one point at a time from ``_MIN_FONT_SIZE`` and stops once
    the text is tall enough or the size exceeds ``_MAX_FONT_SIZE``. When the
    TrueType file cannot be read, PIL's built-in default font is returned
    instead so annotation still works on systems without DejaVu installed.
    """
    font_size = _MIN_FONT_SIZE
    while True:
        try:
            font = ImageFont.truetype(_LABEL_FONT_PATH, size=font_size)
        except OSError:
            return ImageFont.load_default()
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_pixel_height = text_bbox[3] - text_bbox[1]
        if text_pixel_height >= target_height or font_size > _MAX_FONT_SIZE:
            return font
        font_size += 1


def draw_boxes(image: "Image.Image", outputs: dict):
    """Draw a red box and a captioned label for every confident detection.

    The image is modified in place and also returned.

    Args:
        image: Picture to annotate.
        outputs: Model outputs. The ``detected_boxes``, ``detected_classes``
            and ``detected_scores`` entries are read and indexed with ``[0]``.

    Returns:
        The same ``image`` object. It is returned untouched when any of the
        three outputs is missing or empty.
    """
    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])
    # Indexing [0] on a missing/empty output would raise, so bail out early.
    if len(boxes) == 0 or len(classes) == 0 or len(scores) == 0:
        return image

    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency
    # Pixel dimensions used to scale the normalized box coordinates
    image_width, image_height = image.size

    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
        ymin, xmin, ymax, xmax = box
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # Label strip is as wide as the box and one fifth (20%) of its height
        text_width = right - left
        text_height = (bottom - top) // 5

        # Place the strip just above the box, clamped to the top edge of the image
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160),  # Semi-transparent red
        )

        # Draw text with a font scaled to the strip height
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=_fit_label_font(draw, text, text_height),
        )
    return image
# Build the detector once at import time; detect_objects() reuses this instance
model = Model(model_filepath=MODEL_PATH)
def detect_objects(image):
    """Detect objects in an image and summarize the confident detections.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when the
            user submitted without providing a picture.

    Returns:
        Tuple ``(annotated_image, detection_summary)``: a copy of the input
        with boxes drawn on it, and one ``"label: score"`` line per detection
        whose score reaches ``PROB_THRESHOLD`` (or a placeholder message when
        there is none). For a ``None`` input the image slot is ``None``.
    """
    if image is None:
        return None, "No image provided."

    outputs = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    # Prepare detection summary
    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])
    detections = []
    # Indexing [0] on a missing/empty output would raise, so only iterate
    # when all three outputs are actually present.
    if len(boxes) > 0 and len(classes) > 0 and len(scores) > 0:
        for _box, cls, score in zip(boxes[0], classes[0], scores[0]):
            if score < PROB_THRESHOLD:
                continue
            detections.append(f"{LABELS[int(cls)]}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    return annotated_image, detection_summary
# Gradio Interface: one image in, the annotated image plus a text summary out
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description="Upload an image to identify the set of the baseball card (1980-1999).",
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never",  # Disable flagging if not needed
)

# Start the web server only when executed as a script, not when imported
if __name__ == "__main__":
    iface.launch()