File size: 6,792 Bytes
1b98b3b
 
 
 
 
 
 
 
 
 
 
 
8f86518
1b98b3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f86518
1b98b3b
0e82716
1b98b3b
 
 
 
 
 
b553066
 
 
 
 
 
 
 
 
1b98b3b
 
b553066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b98b3b
 
 
 
b553066
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import random
import requests
import json

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from spaces import GPU
from gradio.themes.ocean import Ocean

# --- Config ---
# Model weights and processor are loaded from the same Hub checkpoint.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # dtype selected automatically by transformers
    device_map="auto",   # device placement selected automatically
)

# Pixel-count bounds handed to the processor for resizing input images
# (between 224x224 and 512x512 pixels).
min_pixels = 224 * 224
max_pixels = 512 * 512
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Cache of label -> generated "#RRGGBB" colour, so a label keeps one colour.
label2color = {}

def get_color(label, explicit_color=None):
    """Return the colour used to draw annotations for *label*.

    A truthy *explicit_color* is returned as-is and never cached. Otherwise a
    random ``#RRGGBB`` colour is generated the first time *label* is seen and
    the cached value is returned on every later call.
    """
    if explicit_color:
        return explicit_color
    try:
        return label2color[label]
    except KeyError:
        colour = "#" + "".join(random.choices("0123456789ABCDEF", k=6))
        label2color[label] = colour
        return colour

def _parse_detections(text):
    """Extract the list of detection dicts from a raw model response.

    A fenced block (```json ... ``` or ``` ... ```) is preferred; otherwise the
    whole text is tried as JSON. A single JSON object is treated as a
    one-element list and non-dict entries are dropped.

    Returns:
        A possibly empty list of dicts; empty when no valid JSON was found.
    """
    if not isinstance(text, str):
        return []
    payload = text
    if "```" in payload:
        # Body of the first fenced block, minus an optional "json" language tag.
        payload = payload.split("```")[1]
        if payload[:4].lower() == "json":
            payload = payload[4:]
    try:
        data = json.loads(payload)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return []
    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]


def _as_numbers(value, count):
    """Return *value* as a tuple of *count* floats, or None if it is not one."""
    if not isinstance(value, (list, tuple)) or len(value) != count:
        return None
    try:
        return tuple(float(v) for v in value)
    except (TypeError, ValueError):
        return None


def _valid_color(value):
    """Return *value* if Pillow can parse it as a colour string, else None."""
    if not isinstance(value, str):
        return None
    from PIL import ImageColor  # only needed to validate model-supplied colours

    try:
        ImageColor.getrgb(value)
    except ValueError:
        return None
    return value


def create_annotated_image(image, json_data, height, width):
    """Draw the boxes/points described in a model response onto a copy of *image*.

    Args:
        image: PIL image that was sent to the model.
        json_data: Raw model response text. Its JSON payload is expected to be
            a list of objects with an optional ``"label"``, an optional
            ``"color"``, and ``"bbox_2d"`` ([x0, y0, x1, y1]) and/or
            ``"point_2d"`` ([x, y]).
        height, width: Size of the image in the coordinate space of the
            response; coordinates are rescaled from it to ``image.size``.

    Returns:
        An annotated copy of *image*, or *image* itself when the response
        contains nothing that can be drawn. Malformed entries are skipped.
    """
    detections = _parse_detections(json_data)
    if not detections:
        return image

    original_width, original_height = image.size
    # float() also unwraps 0-d tensors (e.g. sizes derived from image_grid_thw).
    x_scale = original_width / float(width)
    y_scale = original_height / float(height)

    # Stroke/font sizes are tuned for a 512px image. Clamp to >= 1 so small
    # images still get visible marks (PIL draws no outline when width == 0).
    scale_factor = max(original_width, original_height) / 512
    line_width = max(1, int(2 * scale_factor))
    radius = max(1, int(5 * scale_factor))
    label_offset = int(15 * scale_factor)
    point_offset = int(6 * scale_factor)

    try:
        font = ImageFont.truetype("arial.ttf", max(1, int(12 * scale_factor)))
    except (OSError, ImportError):  # font file missing / no FreeType support
        font = ImageFont.load_default()

    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for item in detections:
        raw_label = item.get("label")
        label = "" if raw_label is None else str(raw_label)
        color = get_color(label, _valid_color(item.get("color")))

        bbox = _as_numbers(item.get("bbox_2d"), 4)
        if bbox is not None:
            x0, y0, x1, y1 = bbox
            # PIL rejects boxes whose corners are out of order; normalise them.
            left, right = sorted((int(x0 * x_scale), int(x1 * x_scale)))
            top, bottom = sorted((int(y0 * y_scale), int(y1 * y_scale)))
            draw.rectangle((left, top, right, bottom), outline=color, width=line_width)
            draw.text(
                (left, max(0, top - label_offset)),
                label,
                fill=color,
                font=font,
            )

        point = _as_numbers(item.get("point_2d"), 2)
        if point is not None:
            px = int(point[0] * x_scale)
            py = int(point[1] * y_scale)
            draw.ellipse(
                (px - radius, py - radius, px + radius, py + radius),
                fill=color,
                outline=color,
            )
            draw.text((px + point_offset, py), label, fill=color, font=font)

    return draw_image

@GPU
def detect(image, prompt):
    """Run Qwen2.5-VL on *image* with *prompt* and annotate its answer.

    Args:
        image: PIL image from the Gradio input. It is not modified; a resized
            copy (at most 512x512) is sent to the model and annotated.
        prompt: Free-form instruction, e.g. "Detect all red cars in the image".

    Returns:
        ``(annotated_image, output_text)``: the resized copy with any boxes or
        points from the response drawn on it, and the raw response text.

    Raises:
        gr.Error: when no image or an empty prompt is supplied, so the UI shows
            a message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    STANDARD_SIZE = (512, 512)
    # thumbnail() resizes in place; work on a copy so the caller's image is
    # left untouched.
    image = image.copy()
    image.thumbnail(STANDARD_SIZE)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    # Drop the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # Sampling options belong to generate(); batch_decode only detokenizes.
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # image_grid_thw[0] is (t, h, w) measured in 14px vision patches, so
    # multiplying by the patch size gives the pixel size the model saw; the
    # returned coordinates are rescaled from that size.
    patch_size = 14
    _, grid_h, grid_w = inputs["image_grid_thw"][0].tolist()
    input_height = int(grid_h) * patch_size
    input_width = int(grid_w) * patch_size

    annotated_image = create_annotated_image(image, output_text, input_height, input_width)

    return annotated_image, output_text

css_hide_share = """
button#gradio-share-link-button-0 {
    display: none !important;
}
"""

# --- Gradio Interface ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:

    gr.Markdown("# Object Understanding with Vision Language Models")
    gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
    gr.Markdown("""
    *Powered by Qwen2.5-VL*  
    *Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload an image", type="pil", height=400)
            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Detect all red cars in the image")
            category_input = gr.Textbox(label="Category", interactive=False)
            generate_btn = gr.Button(value="Generate")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Annotated image", height=400)
            output_textbox = gr.Textbox(label="Model response", lines=10)

    gr.Markdown("### Examples")
    example_prompts = [
        ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Object Detection"],
        ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
        ["examples/example_1.jpg", "Count the number of red cars in the image.", "Object Counting"],
        ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Object Counting"],
        ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
        ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
    ]

    gr.Examples(
        examples=example_prompts,
        inputs=[image_input, prompt_input, category_input],
        label="Click an example to populate the input"
    )

    generate_btn.click(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])

if __name__ == "__main__":
    demo.launch()