sergiopaniego (HF Staff) committed on
Commit 1b98b3b · 1 Parent(s): f6057ac

Upload files

Files changed (4)
  1. app.py +179 -0
  2. examples/example_1.jpg +3 -0
  3. examples/example_2.JPG +3 -0
  4. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,179 @@
+ import random
+ import requests
+ import json
+
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageDraw, ImageFont
+
+ import gradio as gr
+ import torch
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+ from qwen_vl_utils import process_vision_info
+ from spaces import GPU
+ from gradio.themes.citrus import Citrus
+
+ # --- Config ---
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+
+ min_pixels = 224 * 224
+ max_pixels = 512 * 512
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+
+ label2color = {}
+
+ def get_color(label, explicit_color=None):
+     """Return a stable color per label, honoring an explicit color if the model provides one."""
+     if explicit_color:
+         return explicit_color
+     if label not in label2color:
+         label2color[label] = "#" + ''.join(random.choices('0123456789ABCDEF', k=6))
+     return label2color[label]
+
+ def create_annotated_image(image, json_data, height, width):
+     """Parse the model's json-fenced response and draw boxes/points back onto the original image."""
+     try:
+         json_data = json_data.split('```json')[1].split('```')[0]
+         bbox_data = json.loads(json_data)
+     except Exception:
+         return image
+
+     # Coordinates come back in the resized input space; rescale them to the original image.
+     original_width, original_height = image.size
+     x_scale = original_width / width
+     y_scale = original_height / height
+
+     scale_factor = max(original_width, original_height) / 512
+
+     draw_image = image.copy()
+     draw = ImageDraw.Draw(draw_image)
+
+     try:
+         font = ImageFont.truetype("arial.ttf", int(12 * scale_factor))
+     except Exception:
+         font = ImageFont.load_default()
+
+     for item in bbox_data:
+         label = item.get("label", "")
+         color = get_color(label, item.get("color", None))
+
+         if "bbox_2d" in item:
+             bbox = item["bbox_2d"]
+             scaled_bbox = [
+                 int(bbox[0] * x_scale),
+                 int(bbox[1] * y_scale),
+                 int(bbox[2] * x_scale),
+                 int(bbox[3] * y_scale)
+             ]
+             draw.rectangle(scaled_bbox, outline=color, width=int(2 * scale_factor))
+             draw.text(
+                 (scaled_bbox[0], max(0, scaled_bbox[1] - int(15 * scale_factor))),
+                 label,
+                 fill=color,
+                 font=font
+             )
+
+         if "point_2d" in item:
+             x, y = item["point_2d"]
+             scaled_x = int(x * x_scale)
+             scaled_y = int(y * y_scale)
+             r = int(5 * scale_factor)
+             draw.ellipse((scaled_x - r, scaled_y - r, scaled_x + r, scaled_y + r), fill=color, outline=color)
+             draw.text((scaled_x + int(6 * scale_factor), scaled_y), label, fill=color, font=font)
+
+     return draw_image
+
+ @GPU
+ def detect(image, prompt):
+     STANDARD_SIZE = (512, 512)
+     image.thumbnail(STANDARD_SIZE)
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to(model.device)
+
+     generated_ids = model.generate(**inputs, max_new_tokens=1024)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     # The vision encoder uses 14x14 patches; recover the resized input resolution from the grid.
+     input_height = inputs['image_grid_thw'][0][1] * 14
+     input_width = inputs['image_grid_thw'][0][2] * 14
+
+     annotated_image = create_annotated_image(image, output_text, input_height, input_width)
+
+     return annotated_image, output_text
+
+ css_hide_share = """
+ button#gradio-share-link-button-0 {
+     display: none !important;
+ }
+ """
+
+ # --- Gradio Interface ---
+ with gr.Blocks(theme=Citrus(), css=css_hide_share) as demo:
+
+     gr.Markdown("# Object Understanding with Vision-Language Models")
+     gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
+     gr.Markdown("""
+     *Powered by Qwen2.5-VL*
+     *Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+     """)
+
+     with gr.Column():
+         with gr.Row():
+             image_input = gr.Image(label="Upload an image", type="pil", height=500)
+
+             with gr.Column():
+                 prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Detect all red cars in the image")
+                 category_input = gr.Textbox(label="Category", interactive=False)
+
+         generate_btn = gr.Button(value="Generate")
+
+         with gr.Row():
+             output_image = gr.Image(type="pil", label="Annotated image", height=500)
+             output_textbox = gr.Textbox(label="Model response", lines=10)
+
+         example_prompts = [
+             ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Object Detection"],
+             ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
+             ["examples/example_1.jpg", "Count the number of red cars in the image.", "Object Counting"],
+             ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Object Counting"],
+             ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
+             ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
+             ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
+             ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
+         ]
+
+         gr.Markdown("### Examples")
+         gr.Examples(
+             examples=example_prompts,
+             inputs=[image_input, prompt_input, category_input],
+             label="Click an example to populate the input"
+         )
+
+     generate_btn.click(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])
+
+ if __name__ == "__main__":
+     demo.launch()
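
For reference, `create_annotated_image` only draws annotations when the model wraps its answer in a ```json fence containing a list of items with a `bbox_2d` box and/or a `point_2d` keypoint, a `label`, and an optional `color`, with coordinates in the resized input resolution that `detect` recovers from `image_grid_thw * 14`. A minimal sketch of the parsed shape (the labels and coordinates below are made up for illustration, not real model output):

```python
# Hypothetical parsed contents of the model's json-fenced reply; values are
# illustrative only. Coordinates arrive in the resized input space and are
# rescaled to the original image by create_annotated_image() before drawing.
sample_items = [
    {"bbox_2d": [84, 112, 230, 301], "label": "red car"},                  # box: x1, y1, x2, y2
    {"point_2d": [157, 206], "label": "blue candy", "color": "#00AAFF"},   # keypoint: x, y
]
```

If the reply is not valid fenced JSON, `create_annotated_image` falls back to returning the unannotated image.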
examples/example_1.jpg ADDED

Git LFS Details

  • SHA256: 459c6c274619b0e423f5d571a4cac75eadbea00997347bebab6df5609e6fdfe9
  • Pointer size: 130 Bytes
  • Size of remote file: 90.4 kB
examples/example_2.JPG ADDED

Git LFS Details

  • SHA256: fc417c899e94f8df465b7541c5a70f0eebb85c414d06345f0b290c061eccc84c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.29 MB
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ transformers
+ datasets
+ Pillow
+ gradio
+ accelerate
+ qwen-vl-utils
+ torchvision
+ matplotlib