ethiotech4848 committed
Commit c76d131 · verified · 1 Parent(s): 7fa7f06

Upload 2 files

Files changed (2):
  1. app.py +334 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,334 @@
+ import json
+ import time
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ # from gradio.themes import Soft
+ from PIL import Image
+ from qwen_vl_utils import process_vision_info
+ from transformers import (
+     AutoProcessor,
+     Gemma3ForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+ )
+
+ from spaces import GPU
+ import supervision as sv
+
+ # --- Config ---
+ # IMPORTANT: Both models are gated. You must be logged in to your Hugging Face account
+ # and have been granted access to use them.
+ # from huggingface_hub import login
+ # login()
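+ # A minimal sketch of a non-interactive alternative, assuming an HF_TOKEN
+ # environment variable / Space secret is configured (an assumption, not part of this app):
+ # import os
+ # from huggingface_hub import login
+ # if os.environ.get("HF_TOKEN"):
+ #     login(token=os.environ["HF_TOKEN"])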
+
+ model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
+ model_gemma_id = "google/gemma-3-4b-it"
+
+ # Load Qwen Model
+ model_qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     model_qwen_id, torch_dtype="auto", device_map="auto"
+ )
+ min_pixels = 224 * 224
+ max_pixels = 1024 * 1024
+ processor_qwen = AutoProcessor.from_pretrained(
+     model_qwen_id, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+
+ # Load Gemma Model
+ model_gemma = Gemma3ForConditionalGeneration.from_pretrained(
+     model_gemma_id,
+     torch_dtype=torch.bfloat16,  # Recommended dtype for Gemma
+     device_map="auto",
+ )
+ processor_gemma = AutoProcessor.from_pretrained(model_gemma_id)
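+
+ # Optional, hedged sketch: bitsandbytes is listed in requirements.txt but is unused here.
+ # If GPU memory is tight, Gemma could instead be loaded in 4-bit roughly like this
+ # (the config values below are assumptions, not part of the original app):
+ # from transformers import BitsAndBytesConfig
+ # bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+ # model_gemma = Gemma3ForConditionalGeneration.from_pretrained(
+ #     model_gemma_id, quantization_config=bnb_config, device_map="auto"
+ # )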
+
+
+ def extract_model_short_name(model_id):
+     return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
+
+
+ model_qwen_name = extract_model_short_name(model_qwen_id)  # → "Qwen2.5 VL 3B Instruct"
+ model_gemma_name = extract_model_short_name(model_gemma_id)  # → "gemma 3 4b it"
+
+
+ def create_annotated_image(image, json_data, height, width):
+     try:
+         # Standardize parsing for outputs wrapped in markdown
+         if "```json" in json_data:
+             parsed_json_data = json_data.split("```json")[1].split("```")[0]
+         else:
+             parsed_json_data = json_data
+         bbox_data = json.loads(parsed_json_data)
+     except Exception:
+         # If parsing fails, return the original image
+         return image
+
+     # Ensure bbox_data is a list
+     if not isinstance(bbox_data, list):
+         bbox_data = [bbox_data]
+
+     original_width, original_height = image.size
+     x_scale = original_width / width
+     y_scale = original_height / height
+
+     points = []
+     point_labels = []
+
+     annotated_image = np.array(image.convert("RGB"))
+
+     # Check if there are bounding boxes in the data to create detections
+     if any("box_2d" in item for item in bbox_data):
+         # Use the Qwen parser as a generic VLM parser for bounding boxes
+         detections = sv.Detections.from_vlm(
+             vlm=sv.VLM.QWEN_2_5_VL,
+             result=json_data,
+             # resolution_wh is the size the model "sees"
+             resolution_wh=(width, height),
+         )
+         bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+         label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+
+         annotated_image = bounding_box_annotator.annotate(
+             scene=annotated_image, detections=detections
+         )
+         annotated_image = label_annotator.annotate(
+             scene=annotated_image, detections=detections
+         )
+
+     # Handle points separately
+     for item in bbox_data:
+         label = item.get("label", "")
+         if "point_2d" in item:
+             x, y = item["point_2d"]
+             scaled_x = int(x * x_scale)
+             scaled_y = int(y * y_scale)
+             points.append([scaled_x, scaled_y])
+             point_labels.append(label)
+
+     if points:
+         points_array = np.array(points).reshape(1, -1, 2)
+         key_points = sv.KeyPoints(xy=points_array)
+         vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.BLUE)
+         annotated_image = vertex_annotator.annotate(
+             scene=annotated_image, key_points=key_points
+         )
+
+     return Image.fromarray(annotated_image)
+
+
+ @GPU
+ def detect_qwen(image, prompt):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     t0 = time.perf_counter()
+     text = processor_qwen.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor_qwen(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to(model_qwen.device)
+
+     generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):]
+         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor_qwen.batch_decode(
+         generated_ids_trimmed,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
+     )[0]
+     elapsed_ms = (time.perf_counter() - t0) * 1_000
+
+     # These dimensions are specific to how Qwen's processor handles images:
+     # image_grid_thw holds the (t, h, w) patch grid and each patch covers 14 pixels.
+     input_height = int(inputs["image_grid_thw"][0][1]) * 14
+     input_width = int(inputs["image_grid_thw"][0][2]) * 14
+
+     annotated_image = create_annotated_image(
+         image, output_text, input_height, input_width
+     )
+
+     time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
+     return annotated_image, output_text, time_taken
+
+
+ @GPU
+ def detect_gemma(image, prompt):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     t0 = time.perf_counter()
+     inputs = processor_gemma.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(model_gemma.device)
+
+     input_len = inputs["input_ids"].shape[-1]
+
+     with torch.inference_mode():
+         generation = model_gemma.generate(**inputs, max_new_tokens=1024, do_sample=False)
+
+     generation_trimmed = generation[0][input_len:]
+     output_text = processor_gemma.decode(generation_trimmed, skip_special_tokens=True)
+     elapsed_ms = (time.perf_counter() - t0) * 1_000
+
+     # Gemma's vision encoder normalizes images to a fixed size (e.g., 896x896)
+     input_height = 896
+     input_width = 896
+
+     annotated_image = create_annotated_image(
+         image, output_text, input_height, input_width
+     )
+
+     time_taken = f"**Inference time ({model_gemma_name}):** {elapsed_ms:.0f} ms"
+     return annotated_image, output_text, time_taken
+
+
+ def detect(image, prompt_model_1, prompt_model_2):
+     STANDARD_SIZE = (1024, 1024)
+     image.thumbnail(STANDARD_SIZE)
+
+     annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(
+         image, prompt_model_1
+     )
+     annotated_image_model_2, output_text_model_2, timing_2 = detect_gemma(
+         image, prompt_model_2
+     )
+
+     return (
+         annotated_image_model_1,
+         output_text_model_1,
+         timing_1,
+         annotated_image_model_2,
+         output_text_model_2,
+         timing_2,
+     )
+
+
+ css_hide_share = """
+ button#gradio-share-link-button-0 {
+     display: none !important;
+ }
+ """
+
+ # --- Gradio Interface ---
+ with gr.Blocks(theme=gr.themes.Soft(), css=css_hide_share) as demo:
+     gr.Markdown("# Object Detection & Understanding: Qwen vs. Gemma")
+     gr.Markdown(
+         "### Compare object detection, visual grounding, and keypoint detection using natural language prompts with two leading VLMs."
+     )
+     gr.Markdown("""
+     *Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Gemma 3 4B IT](https://huggingface.co/google/gemma-3-4b-it). For best results, ask the model to return a JSON list inside a markdown code block. Inspired by the [HF Team's space](https://huggingface.co/spaces/sergiopaniego/vlm_object_understanding), which uses `detect` for the "Object Detection" categories, `point` for the "Keypoint Detection" categories, and reasoning-based querying for all others.*
+     """)
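+
+     # A hedged illustration of the response shape the prompts above aim for and that
+     # create_annotated_image can parse ('box_2d', 'point_2d', and 'label' are the keys
+     # the parser looks for; the values below are made-up examples):
+     # ```json
+     # [
+     #   {"box_2d": [x_min, y_min, x_max, y_max], "label": "red car"},
+     #   {"point_2d": [x, y], "label": "blue candy"}
+     # ]
+     # ```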
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             image_input = gr.Image(label="Upload an image", type="pil", height=400)
+             prompt_input_model_1 = gr.Textbox(
+                 label=f"Enter your prompt for {model_qwen_name}",
+                 placeholder="e.g., Detect all red cars. Return a JSON list with 'box_2d' and 'label'.",
+             )
+             prompt_input_model_2 = gr.Textbox(
+                 label=f"Enter your prompt for {model_gemma_name}",
+                 placeholder="e.g., Detect all red cars. Return a JSON list with 'box_2d' and 'label'.",
+             )
+             generate_btn = gr.Button(value="Generate")
+
+         with gr.Column(scale=1):
+             output_image_model_1 = gr.Image(
+                 type="pil", label=f"Annotated image from {model_qwen_name}", height=400
+             )
+             output_textbox_model_1 = gr.Textbox(
+                 label=f"Model response from {model_qwen_name}", lines=10
+             )
+             output_time_model_1 = gr.Markdown()
+
+         with gr.Column(scale=1):
+             output_image_model_2 = gr.Image(
+                 type="pil",
+                 label=f"Annotated image from {model_gemma_name}",
+                 height=400,
+             )
+             output_textbox_model_2 = gr.Textbox(
+                 label=f"Model response from {model_gemma_name}", lines=10
+             )
+             output_time_model_2 = gr.Markdown()
+
+     gr.Markdown("### Examples")
+
+     prompt_obj_detect = "Detect all objects in this image. For each object, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
+     prompt_candy_detect = "Detect all individual candies in this image. For each, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
+     prompt_car_count = "Count the number of red cars in the image."
+     prompt_candy_count = "Count the number of blue candies in the image."
+     prompt_car_keypoint = "Identify the red cars in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
+     prompt_candy_keypoint = "Identify the blue candies in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
+     prompt_car_ground = "Detect the red car that is leading in this image. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."
+     prompt_candy_ground = "Detect the blue candy at the top of the group. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."
+
+     example_prompts = [
+         ["examples/example_1.jpg", prompt_obj_detect, prompt_obj_detect],
+         ["examples/example_2.JPG", prompt_candy_detect, prompt_candy_detect],
+         ["examples/example_1.jpg", prompt_car_count, prompt_car_count],
+         ["examples/example_2.JPG", prompt_candy_count, prompt_candy_count],
+         ["examples/example_1.jpg", prompt_car_keypoint, prompt_car_keypoint],
+         ["examples/example_2.JPG", prompt_candy_keypoint, prompt_candy_keypoint],
+         ["examples/example_1.jpg", prompt_car_ground, prompt_car_ground],
+         ["examples/example_2.JPG", prompt_candy_ground, prompt_candy_ground],
+     ]
+
+     gr.Examples(
+         examples=example_prompts,
+         inputs=[
+             image_input,
+             prompt_input_model_1,
+             prompt_input_model_2,
+         ],
+         label="Click an example to populate the input",
+     )
+
+     generate_btn.click(
+         fn=detect,
+         inputs=[
+             image_input,
+             prompt_input_model_1,
+             prompt_input_model_2,
+         ],
+         outputs=[
+             output_image_model_1,
+             output_textbox_model_1,
+             output_time_model_1,
+             output_image_model_2,
+             output_textbox_model_2,
+             output_time_model_2,
+         ],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ transformers
+ datasets
+ bitsandbytes
+ Pillow
+ gradio
+ accelerate
+ qwen-vl-utils
+ torchvision
+ matplotlib
+ supervision