sergiopaniego (HF Staff) and onuralpszr committed
Commit b83d741 · verified · 1 parent: 99a53c4

feat: ✨ supervision from_vlm support added (#4)


- refactor: Clean up imports 🧹, improve code readability 📘, and add the from_vlm feature from Supervision 🕵️‍♂️ for simplified bounding boxes and annotations 🖼️ (c03a662fce0311080d6ecaef3341d84b914e44af); a usage sketch follows below
- fix: 🐞 re-add @GPU decorator to detection functions (de6ff1c222c2340bca1af52c9dcd8931ad211e2b)


Co-authored-by: Onuralp SEZER <[email protected]>
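For reference, a minimal sketch of the from_vlm pattern this commit adopts, assuming a supervision build that ships sv.VLM and sv.Detections.from_vlm (the requirements previously pinned supervision>=0.26.0rc7 for this); the response string, box values, and image path below are made-up placeholders, not values from the commit:

import numpy as np
import supervision as sv
from PIL import Image

# Hypothetical raw Qwen2.5-VL detection response: a ```json fenced block,
# which from_vlm parses directly (no manual json.loads or box rescaling).
qwen_result = """```json
[
    {"bbox_2d": [100, 150, 420, 380], "label": "red car"}
]
```"""

image = Image.open("examples/example_1.jpg")  # placeholder path
w, h = image.size

# input_wh is the resolution the coordinates refer to, resolution_wh the size to
# scale them to; the commit passes the original image size for both.
detections = sv.Detections.from_vlm(
    vlm=sv.VLM.QWEN_2_5_VL,
    result=qwen_result,
    input_wh=(w, h),
    resolution_wh=(w, h),
)

# Annotate boxes and labels exactly as the updated app.py does.
scene = np.array(image.convert("RGB"))
scene = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX).annotate(
    scene=scene, detections=detections
)
scene = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX).annotate(
    scene=scene, detections=detections
)
Image.fromarray(scene).save("annotated.jpg")

# The Moondream branch in the commit is analogous:
# sv.Detections.from_vlm(sv.VLM.MOONDREAM, response_dict, resolution_wh=(w, h))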

Files changed (2)
  1. app.py +174 -106
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,20 +1,19 @@
- import random
- import requests
  import json
- import ast
  import time

- import matplotlib.pyplot as plt
- import numpy as np
- import supervision as sv
- from PIL import Image, ImageDraw, ImageFont
-
  import gradio as gr
- import torch
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, AutoModelForCausalLM
+ import numpy as np
+ from gradio.themes.ocean import Ocean
+ from PIL import Image
  from qwen_vl_utils import process_vision_info
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     Qwen2_5_VLForConditionalGeneration,
+ )
+
  from spaces import GPU
- from gradio.themes.ocean import Ocean
+ import supervision as sv

  # --- Config ---
  model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -27,24 +26,29 @@ model_moondream = AutoModelForCausalLM.from_pretrained(
      model_moondream_id,
      revision="2025-06-21",
      trust_remote_code=True,
-     device_map={"": "cuda"}
+     device_map={"": "cuda"},
  )

+
  def extract_model_short_name(model_id):
      return model_id.split("/")[-1].replace("-", " ").replace("_", " ")

+
  model_qwen_name = extract_model_short_name(model_qwen_id) # → "Qwen2.5 VL 3B Instruct"
  model_moondream_name = extract_model_short_name(model_moondream_id) # → "moondream2"


  min_pixels = 224 * 224
  max_pixels = 1024 * 1024
- processor_qwen = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+ processor_qwen = AutoProcessor.from_pretrained(
+     "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
+ )
+

  def create_annotated_image(image, json_data, height, width):
      try:
-         json_data = json_data.split("```json")[1].split("```")[0]
-         bbox_data = json.loads(json_data)
+         parsed_json_data = json_data.split("```json")[1].split("```")[0]
+         bbox_data = json.loads(parsed_json_data)
      except Exception:
          return image

@@ -52,24 +56,11 @@ def create_annotated_image(image, json_data, height, width):
      x_scale = original_width / width
      y_scale = original_height / height

-     boxes = []
-     box_labels = []
      points = []
      point_labels = []

      for item in bbox_data:
          label = item.get("label", "")
-         if "bbox_2d" in item:
-             bbox = item["bbox_2d"]
-             scaled_bbox = [
-                 int(bbox[0] * x_scale),
-                 int(bbox[1] * y_scale),
-                 int(bbox[2] * x_scale),
-                 int(bbox[3] * y_scale)
-             ]
-             boxes.append(scaled_bbox)
-             box_labels.append(label)
-
          if "point_2d" in item:
              x, y = item["point_2d"]
              scaled_x = int(x * x_scale)
@@ -77,34 +68,34 @@ def create_annotated_image(image, json_data, height, width):
              points.append([scaled_x, scaled_y])
              point_labels.append(label)

-     annotated_image = np.array(image.convert("RGB"))
-
-     if boxes:
-         detections = sv.Detections(xyxy=np.array(boxes))
+     annotated_image = np.array(image.convert("RGB"))
+
+     detections = sv.Detections.from_vlm(vlm = sv.VLM.QWEN_2_5_VL,
+                                         result=json_data,
+                                         input_wh=(original_width,
+                                                   original_height),
+                                         resolution_wh=(original_width,
+                                                        original_height))
      bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
      label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

      annotated_image = bounding_box_annotator.annotate(
-         scene=annotated_image,
-         detections=detections
+         scene=annotated_image, detections=detections
      )
      annotated_image = label_annotator.annotate(
-         scene=annotated_image,
-         detections=detections,
-         labels=box_labels
+         scene=annotated_image, detections=detections
      )

      if points:
          points_array = np.array(points).reshape(1, -1, 2)
          key_points = sv.KeyPoints(xy=points_array)
          vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.BLUE)
-         #vertex_label_annotator = sv.VertexLabelAnnotator(text_scale=0.5, border_radius=2)
+         # vertex_label_annotator = sv.VertexLabelAnnotator(text_scale=0.5, border_radius=2)

          annotated_image = vertex_annotator.annotate(
-             scene=annotated_image,
-             key_points=key_points
+             scene=annotated_image, key_points=key_points
          )
-
+
          # annotated_image = vertex_label_annotator.annotate(
          #     scene=annotated_image,
          #     key_points=key_points,
@@ -113,6 +104,7 @@ def create_annotated_image(image, json_data, height, width):

      return Image.fromarray(annotated_image)

+
  def create_annotated_image_normalized(image, json_data, label="object"):
      if not isinstance(json_data, dict):
          return image
@@ -127,54 +119,43 @@ def create_annotated_image_normalized(image, json_data, label="object"):
              x = int(point["x"] * original_width)
              y = int(point["y"] * original_height)
              points.append([x, y])
-
+
      if "reasoning" in json_data:
          for grounding in json_data["reasoning"].get("grounding", []):
              for x_norm, y_norm in grounding.get("points", []):
                  x = int(x_norm * original_width)
                  y = int(y_norm * original_height)
-                 points.append([x,y])
+                 points.append([x, y])

      if points:
          points_array = np.array(points).reshape(1, -1, 2)
          key_points = sv.KeyPoints(xy=points_array)
          vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.RED)
-         annotated_image = vertex_annotator.annotate(scene=annotated_image, key_points=key_points)
+         annotated_image = vertex_annotator.annotate(
+             scene=annotated_image, key_points=key_points
+         )

-     # Handle boxes for object detection
-     boxes = []
      if "objects" in json_data:
-         for item in json_data.get("objects", []):
-             x_min = int(item["x_min"] * original_width)
-             y_min = int(item["y_min"] * original_height)
-             x_max = int(item["x_max"] * original_width)
-             y_max = int(item["y_max"] * original_height)
-             boxes.append([x_min, y_min, x_max, y_max])
-
-     if boxes:
-         detections = sv.Detections(xyxy=np.array(boxes))
+         detections = sv.Detections.from_vlm(sv.VLM.MOONDREAM,json_data,
+                                             resolution_wh=(original_width,
+                                                            original_height))
+
          bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
          label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
-
+
          labels = [label for _ in detections.xyxy]

          annotated_image = bounding_box_annotator.annotate(
-             scene=annotated_image,
-             detections=detections
+             scene=annotated_image, detections=detections
          )
          annotated_image = label_annotator.annotate(
-             scene=annotated_image,
-             detections=detections,
-             labels=labels
+             scene=annotated_image, detections=detections, labels=labels
          )

      return Image.fromarray(annotated_image)

-
-
  @GPU
  def detect_qwen(image, prompt):
-
      messages = [
          {
              "role": "user",
@@ -186,7 +167,9 @@ def detect_qwen(image, prompt):
      ]

      t0 = time.perf_counter()
-     text = processor_qwen.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     text = processor_qwen.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
      image_inputs, video_inputs = process_vision_info(messages)
      inputs = processor_qwen(
          text=[text],
@@ -198,17 +181,23 @@

      generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
      generated_ids_trimmed = [
-         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         out_ids[len(in_ids) :]
+         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
      ]
      output_text = processor_qwen.batch_decode(
-         generated_ids_trimmed, do_sample=True, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         generated_ids_trimmed,
+         do_sample=True,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
      )[0]
      elapsed_ms = (time.perf_counter() - t0) * 1_000

-     input_height = inputs['image_grid_thw'][0][1] * 14
-     input_width = inputs['image_grid_thw'][0][2] * 14
+     input_height = inputs["image_grid_thw"][0][1] * 14
+     input_width = inputs["image_grid_thw"][0][2] * 14

-     annotated_image = create_annotated_image(image, output_text, input_height, input_width)
+     annotated_image = create_annotated_image(
+         image, output_text, input_height, input_width
+     )

      time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
      return annotated_image, output_text, time_taken
@@ -222,22 +211,39 @@ def detect_moondream(image, prompt, category_input):
      elif category_input == "Visual Grounding + Keypoint Detection":
          output_text = model_moondream.point(image=image, object=prompt)
      else:
-         output_text = model_moondream.query(image=image, question=prompt, reasoning=True)
+         output_text = model_moondream.query(
+             image=image, question=prompt, reasoning=True
+         )
      elapsed_ms = (time.perf_counter() - t0) * 1_000

-     annotated_image = create_annotated_image_normalized(image=image, json_data=output_text, label="object")
+     annotated_image = create_annotated_image_normalized(
+         image=image, json_data=output_text, label="object"
+     )

      time_taken = f"**Inference time ({model_moondream_name}):** {elapsed_ms:.0f} ms"
      return annotated_image, output_text, time_taken

+
  def detect(image, prompt_model_1, prompt_model_2, category_input):
      STANDARD_SIZE = (1024, 1024)
      image.thumbnail(STANDARD_SIZE)
-
-     annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(image, prompt_model_1)
-     annotated_image_model_2, output_text_model_2, timing_2 = detect_moondream(image, prompt_model_2, category_input)

-     return annotated_image_model_1, output_text_model_1, timing_1, annotated_image_model_2, output_text_model_2, timing_2
+     annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(
+         image, prompt_model_1
+     )
+     annotated_image_model_2, output_text_model_2, timing_2 = detect_moondream(
+         image, prompt_model_2, category_input
+     )
+
+     return (
+         annotated_image_model_1,
+         output_text_model_1,
+         timing_1,
+         annotated_image_model_2,
+         output_text_model_2,
+         timing_2,
+     )
+

  css_hide_share = """
  button#gradio-share-link-button-0 {
@@ -247,11 +253,12 @@ button#gradio-share-link-button-0 {

  # --- Gradio Interface ---
  with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
-
      gr.Markdown("# 👓 Object Understanding with Vision Language Models")
-     gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
+     gr.Markdown(
+         "### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts."
+     )
      gr.Markdown("""
-     *Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Moondream 2B (revision="2025-06-21")](https://huggingface.co/vikhyatk/moondream2). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+     *Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Moondream 2B (revision="2025-06-21")](https://huggingface.co/vikhyatk/moondream2). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
      *Moondream 2B uses the [moondream.py API](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py), selecting `detect` for categories with "Object Detection" `point` for the ones with "Keypoint Detection", and reasoning-based querying for all others.*
      """)

@@ -260,66 +267,127 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
              image_input = gr.Image(label="Upload an image", type="pil", height=400)
              prompt_input_model_1 = gr.Textbox(
                  label=f"Enter your prompt for {model_qwen_name}",
-                 placeholder="e.g., Detect all red cars in the image"
+                 placeholder="e.g., Detect all red cars in the image",
              )

              prompt_input_model_2 = gr.Textbox(
                  label=f"Enter your prompt for {model_moondream_name}",
-                 placeholder="e.g., Detect all blue cars in the image"
+                 placeholder="e.g., Detect all blue cars in the image",
              )

-
              categories = [
                  "Object Detection",
                  "Object Counting",
                  "Visual Grounding + Keypoint Detection",
                  "Visual Grounding + Object Detection",
-                 "General query"
+                 "General query",
              ]

              category_input = gr.Dropdown(
-                 choices=categories,
-                 label="Category",
-                 interactive=True
+                 choices=categories, label="Category", interactive=True
              )
              generate_btn = gr.Button(value="Generate")

          with gr.Column(scale=1):
-             output_image_model_1 = gr.Image(type="pil", label=f"Annotated image for {model_qwen_name}", height=400)
-             output_textbox_model_1 = gr.Textbox(label=f"Model response for {model_qwen_name}", lines=10)
+             output_image_model_1 = gr.Image(
+                 type="pil", label=f"Annotated image for {model_qwen_name}", height=400
+             )
+             output_textbox_model_1 = gr.Textbox(
+                 label=f"Model response for {model_qwen_name}", lines=10
+             )
              output_time_model_1 = gr.Markdown()
-
+
          with gr.Column(scale=1):
-             output_image_model_2 = gr.Image(type="pil", label=f"Annotated image for {model_moondream_name}", height=400)
-             output_textbox_model_2 = gr.Textbox(label=f"Model response for {model_moondream_name}", lines=10)
+             output_image_model_2 = gr.Image(
+                 type="pil",
+                 label=f"Annotated image for {model_moondream_name}",
+                 height=400,
+             )
+             output_textbox_model_2 = gr.Textbox(
+                 label=f"Model response for {model_moondream_name}", lines=10
+             )
              output_time_model_2 = gr.Markdown()

      gr.Markdown("### Examples")
      example_prompts = [
-         ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "objects", "Object Detection"],
-         ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "candies", "Object Detection"],
-         ["examples/example_1.jpg", "Count the number of red cars in the image.", "Count the number of red cars in the image.", "Object Counting"],
-         ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Count the number of blue candies in the image.", "Object Counting"],
-         ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "red cars", "Visual Grounding + Keypoint Detection"],
-         ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "blue candies", "Visual Grounding + Keypoint Detection"],
-         ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "leading red car", "Visual Grounding + Object Detection"],
-         ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "blue candy located at the top of the group", "Visual Grounding + Object Detection"],
+         [
+             "examples/example_1.jpg",
+             "Detect all objects in the image and return their locations and labels.",
+             "objects",
+             "Object Detection",
+         ],
+         [
+             "examples/example_2.JPG",
+             "Detect all the individual candies in the image and return their locations and labels.",
+             "candies",
+             "Object Detection",
+         ],
+         [
+             "examples/example_1.jpg",
+             "Count the number of red cars in the image.",
+             "Count the number of red cars in the image.",
+             "Object Counting",
+         ],
+         [
+             "examples/example_2.JPG",
+             "Count the number of blue candies in the image.",
+             "Count the number of blue candies in the image.",
+             "Object Counting",
+         ],
+         [
+             "examples/example_1.jpg",
+             "Identify the red cars in this image, detect their key points and return their positions in the form of points.",
+             "red cars",
+             "Visual Grounding + Keypoint Detection",
+         ],
+         [
+             "examples/example_2.JPG",
+             "Identify the blue candies in this image, detect their key points and return their positions in the form of points.",
+             "blue candies",
+             "Visual Grounding + Keypoint Detection",
+         ],
+         [
+             "examples/example_1.jpg",
+             "Detect the red car that is leading in this image and return its location and label.",
+             "leading red car",
+             "Visual Grounding + Object Detection",
+         ],
+         [
+             "examples/example_2.JPG",
+             "Detect the blue candy located at the top of the group in this image and return its location and label.",
+             "blue candy located at the top of the group",
+             "Visual Grounding + Object Detection",
+         ],
      ]

      gr.Examples(
          examples=example_prompts,
-         inputs=[image_input, prompt_input_model_1, prompt_input_model_2, category_input],
-         label="Click an example to populate the input"
+         inputs=[
+             image_input,
+             prompt_input_model_1,
+             prompt_input_model_2,
+             category_input,
+         ],
+         label="Click an example to populate the input",
      )

      generate_btn.click(
          fn=detect,
-         inputs=[image_input, prompt_input_model_1, prompt_input_model_2, category_input],
+         inputs=[
+             image_input,
+             prompt_input_model_1,
+             prompt_input_model_2,
+             category_input,
+         ],
          outputs=[
-             output_image_model_1, output_textbox_model_1, output_time_model_1,
-             output_image_model_2, output_textbox_model_2, output_time_model_2
-         ]
+             output_image_model_1,
+             output_textbox_model_1,
+             output_time_model_1,
+             output_image_model_2,
+             output_textbox_model_2,
+             output_time_model_2,
+         ],
      )
-
+
  if __name__ == "__main__":
      demo.launch()
requirements.txt CHANGED
@@ -7,4 +7,4 @@ accelerate
  qwen-vl-utils
  torchvision
  matplotlib
- supervision>=0.26.0rc7
+ supervision