File size: 5,452 Bytes
ae5135e
7a9ce47
ae5135e
 
7a9ce47
ae5135e
ecce323
ae5135e
 
 
 
2d23a8b
7a9ce47
ae5135e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2839c47
633e7c4
2839c47
633e7c4
2839c47
7a9ce47
2839c47
633e7c4
7a9ce47
2839c47
7a9ce47
b7bd655
ae5135e
 
 
029bb24
7a9ce47
ae5135e
7a9ce47
ae5135e
633e7c4
ae5135e
7a9ce47
633e7c4
 
 
 
 
 
2839c47
633e7c4
 
ae5135e
633e7c4
 
 
 
 
 
2839c47
 
 
 
 
 
 
 
 
 
633e7c4
 
2839c47
 
 
 
 
 
633e7c4
7a9ce47
 
ae5135e
 
029bb24
ae5135e
 
 
029bb24
ae5135e
 
 
 
 
029bb24
ae5135e
 
 
 
 
029bb24
ae5135e
 
 
2d23a8b
ae5135e
34f5d81
ae5135e
987957a
2839c47
 
 
 
ae5135e
 
633e7c4
2839c47
272be86
 
34f5d81
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
import gradio as gr

# Constants
PROB_THRESHOLD = 0.5  # Detections scoring below this are neither drawn nor listed
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one class name per line, where the line index is the class id
# used to look the name up. The encoding is given explicitly so the result
# does not depend on the platform default, and splitlines() is used so files
# with "\r\n" line endings do not leave a stray "\r" on every label.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().splitlines()

class Model:
    """Wrapper around an ONNX Runtime session for a single-input image model.

    The constructor reads the input name, spatial size and element type from
    the session, and inspects the ONNX metadata properties to learn whether
    the model expects BGR channel order and/or pixel values in [0, 255].
    """

    def __init__(self, model_filepath):
        """Load the model at *model_filepath* and cache its input description."""
        self.session = onnxruntime.InferenceSession(model_filepath)
        model_inputs = self.session.get_inputs()
        assert len(model_inputs) == 1, "expected a model with exactly one input"
        model_input = model_inputs[0]

        # The input is treated as NCHW, so dropping the first two axes
        # leaves the spatial size.
        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Unknown tensor types fall back to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Defaults (RGB, values scaled to [0, 1]) apply unless the model's
        # metadata says otherwise.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run the model on *image* and return its outputs keyed by output name.

        Args:
            image: Input picture; it is not modified.

        Returns:
            A dict mapping each session output name to the array produced
            for it by ``session.run``.
        """
        # Three channels are required for the NCHW transpose below; grayscale,
        # palette and RGBA images would otherwise produce the wrong rank or
        # channel count.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # input_shape is stored as (H, W) but PIL's resize takes (width, height).
        # Passing it unswapped only works for square inputs.
        height, width = self.input_shape
        image_resized = image.resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # Run inference
        outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
        return dict(zip(self.output_names, outputs))

def _load_label_font(font_size):
    """Return a font of about *font_size* pixels, falling back when needed."""
    try:
        # Attempt to load a truetype font; adjust the path if necessary
        return ImageFont.truetype("arial.ttf", size=font_size)
    except OSError:
        pass
    try:
        # Newer Pillow releases can scale the bundled default font, which
        # keeps labels legible on systems without arial.ttf.
        return ImageFont.load_default(size=font_size)
    except TypeError:
        # Older Pillow: load_default() takes no size argument.
        return ImageFont.load_default()


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw a labelled bounding box on *image* for each confident detection.

    Args:
        image: Image to annotate. It is drawn on in place.
        outputs: Mapping of model output name to array. The keys
            'detected_boxes', 'detected_classes' and 'detected_scores' are
            read, and only the first entry along their leading axis is used.

    Returns:
        The same image object that was passed in. It is returned unchanged
        when any of the three outputs is missing or empty.
    """
    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency

    # Dynamic font size based on image dimensions
    image_width, image_height = image.size
    font = _load_label_font(max(20, image_width // 50))

    # Take the first batch entry of each output, bailing out instead of
    # raising IndexError when the model produced nothing for a key.
    first_batch = []
    for key in ('detected_boxes', 'detected_classes', 'detected_scores'):
        values = outputs.get(key)
        if values is None or len(values) == 0:
            return image
        first_batch.append(values[0])

    for box, cls, score in zip(*first_batch):
        if score < PROB_THRESHOLD:
            continue

        # Fall back to the numeric id rather than failing (or silently
        # wrapping around on a negative id) when it has no label entry.
        class_id = int(cls)
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
        # NOTE(review): confirm this order against the model's export format.
        ymin, xmin, ymax, xmax = box
        # Sorting guards against inverted boxes, which ImageDraw rejects.
        left, right = sorted((xmin * image_width, xmax * image_width))
        top, bottom = sorted((ymin * image_height, ymax * image_height))

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=3)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # Calculate text size using textbbox
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Place the label just above the box, clamped so it never starts
        # above the top edge of the image.
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160),  # Semi-transparent red
        )

        # Draw text with 5 px of padding inside the background
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="white",
            font=font,
        )

    return image

# Initialize model
# Built once at import time; detect_objects reuses this single instance for
# every call instead of constructing a new inference session per request.
model = Model(MODEL_PATH)

def detect_objects(image):
    """Run detection on *image* and return the annotated image and a summary.

    Args:
        image: PIL image to analyse, or None when nothing was supplied.

    Returns:
        A ``(annotated_image, detection_summary)`` tuple. The image is an
        annotated copy of the input; the summary has one "label: score"
        line per detection at or above PROB_THRESHOLD, or a fixed message
        when there are none. For a None input the result is
        ``(None, "No image provided.")``.
    """
    # The UI can invoke this with no image (submit pressed on an empty input);
    # report that instead of failing with an AttributeError further down.
    if image is None:
        return None, "No image provided."

    # Both preprocessing and drawing expect a three-channel image.
    if image.mode != "RGB":
        image = image.convert("RGB")

    outputs = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    # Use the first batch entry of each output; a missing or empty output
    # yields no detections rather than an IndexError.
    first_batch = []
    for key in ('detected_boxes', 'detected_classes', 'detected_scores'):
        values = outputs.get(key)
        has_batch = values is not None and len(values) > 0
        first_batch.append(values[0] if has_batch else ())

    # Prepare detection summary
    detections = []
    for _box, cls, score in zip(*first_batch):
        if score < PROB_THRESHOLD:
            continue
        # Fall back to the numeric id when it has no entry in LABELS.
        class_id = int(cls)
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"
        detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."

    return annotated_image, detection_summary

# Web UI: one image in, the annotated image plus a text summary out.
iface = gr.Interface(
    title="Object Detection with ONNX Model",
    description="Upload an image to detect objects using the ONNX model.",
    fn=detect_objects,
    # The callback receives the upload as a PIL image.
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    # Sample images offered in the UI.
    examples=[
        "examples/card1.jpg",
        "examples/card2.jpg",
        "examples/card3.jpg",
    ],
    theme="default",
    allow_flagging="never",  # flagging is not used by this demo
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()