File size: 5,614 Bytes
de40de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f61c335
 
633e7c4
f61c335
 
 
 
 
 
 
 
4e8cd1a
f61c335
 
 
 
 
 
 
 
 
 
e7ef62f
f61c335
 
 
 
196b516
e7ef62f
9bad826
e7ef62f
f61c335
 
 
 
 
 
 
 
 
 
32eca4a
 
 
 
 
 
 
 
9bad826
32eca4a
 
bf53212
32eca4a
196b516
 
 
 
 
 
633e7c4
7a9ce47
de40de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bad826
 
de40de9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
import gradio as gr

# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")  # Exported ONNX model file
LABELS_PATH = os.path.join("onnx", "labels.txt")  # One class label per line, index-aligned with model class ids

# Load labels at import time; LABELS[i] maps a detected class id i to its name.
# NOTE(review): assumes the file is newline-delimited with no trailing blank
# labels (trailing whitespace is stripped before splitting).
with open(LABELS_PATH, "r") as f:
    LABELS = f.read().strip().split("\n")

class Model:
    """Wrapper around an ONNX object-detection model (Custom Vision export).

    Reads the expected input geometry and tensor dtype from the inference
    session, and the exporter's pixel-format conventions (BGR channel order,
    nominal 0-255 range) from the ONNX model metadata.
    """

    def __init__(self, model_filepath):
        """Load the ONNX model at *model_filepath* and cache its I/O metadata."""
        self.session = onnxruntime.InferenceSession(model_filepath)
        assert len(self.session.get_inputs()) == 1
        # Input layout is (N, C, H, W); keep only the spatial dims.
        self.input_shape = self.session.get_inputs()[0].shape[2:]  # (H, W)
        self.input_name = self.session.get_inputs()[0].name
        # Map the ONNX tensor type string to a numpy dtype; default to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            self.session.get_inputs()[0].type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Preprocessing flags written into the model metadata by the exporter.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run inference on a PIL image.

        Returns:
            dict mapping each output name to its numpy array.

        Fixes vs. original:
        - force RGB so grayscale (2-D) or RGBA (4-channel) inputs no longer
          break the channel transpose / BGR swizzle;
        - PIL's resize takes (width, height) but input_shape is (H, W), so
          the dims are swapped here (non-square models were mis-resized).
        """
        height, width = self.input_shape
        image_resized = image.convert("RGB").resize((width, height))
        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]  # RGB -> BGR
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # Run inference with the dtype the model actually expects.
        outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
        return {name: outputs[i] for i, name in enumerate(self.output_names)}

def draw_boxes(image: Image.Image, outputs: dict):
    """Annotate *image* with red bounding boxes and score labels.

    Args:
        image: PIL image to draw on (modified in place and returned).
        outputs: model outputs with 'detected_boxes', 'detected_classes'
            and 'detected_scores' entries, each batch-of-one arrays.

    Returns:
        The annotated image.
    """
    draw = ImageDraw.Draw(image, "RGBA")  # RGBA so label backgrounds can be translucent

    image_width, image_height = image.size
    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])

    # Nothing to draw if any output is missing or empty (the original
    # indexed [0] unconditionally and would raise IndexError here).
    if len(boxes) == 0 or len(classes) == 0 or len(scores) == 0:
        return image

    # Common DejaVu path on Linux; if it is absent we fall back to PIL's
    # built-in bitmap font below instead of crashing with OSError.
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"

    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1].
        # NOTE(review): some Custom Vision exports use [xmin, ymin, xmax, ymax]
        # instead -- confirm against the actual model output.
        ymin, xmin, ymax, xmax = box
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # Label background spans the box width; height is 20% of box height
        # (original comment claimed 10%, but // 5 is one fifth).
        text_width = right - left
        text_height = (bottom - top) // 5

        # Clamp so the label stays inside the image when the box touches the top.
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind the text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160)  # Semi-transparent red
        )

        # Grow the font until the rendered text fills the label height,
        # capped at size 200 to bound the loop.
        try:
            font_size = 10
            while True:
                font = ImageFont.truetype(font_path, size=font_size)
                text_bbox = draw.textbbox((0, 0), text, font=font)
                text_pixel_height = text_bbox[3] - text_bbox[1]
                if text_pixel_height >= text_height or font_size > 200:
                    break
                font_size += 1
        except OSError:
            # Font file not installed on this system: use PIL's default font.
            font = ImageFont.load_default()

        # Draw text with the scaled font
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=font
        )

    return image

# Initialize model once at import time so every request reuses the same
# onnxruntime session (loading the model per request would be slow).
model = Model(MODEL_PATH)

def detect_objects(image):
    """Gradio handler: run detection and build outputs for the UI.

    Args:
        image: PIL image uploaded by the user.

    Returns:
        Tuple of (annotated PIL image, newline-joined text summary of
        detections above PROB_THRESHOLD, or a "no objects" message).
    """
    outputs = model.predict(image)
    # Draw on a copy so the user's original upload is not mutated.
    annotated_image = draw_boxes(image.copy(), outputs)

    boxes = outputs.get('detected_boxes', [])
    classes = outputs.get('detected_classes', [])
    scores = outputs.get('detected_scores', [])

    detections = []
    # Guard against missing/empty outputs (the original indexed [0]
    # unconditionally and could raise IndexError).
    if len(boxes) and len(classes) and len(scores):
        for box, cls, score in zip(boxes[0], classes[0], scores[0]):
            if score < PROB_THRESHOLD:
                continue
            label = LABELS[int(cls)]
            detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."

    return annotated_image, detection_summary

# Gradio Interface: one image in, annotated image + text summary out.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections")
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    # Fixed typo: "itentify" -> "identify".
    description="Upload an image to identify the set of the baseball card (1980-1999).",
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",
    allow_flagging="never"  # Disable flagging; not needed for this demo
)

# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()