import random
import requests
import json
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from spaces import GPU
from gradio.themes.ocean import Ocean
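# Assumed pip dependencies for this Space: torch, transformers, qwen-vl-utils,
# gradio, spaces (the ZeroGPU helper library), and Pillow.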

# --- Config ---
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
)
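
# min_pixels / max_pixels bound the resolution the processor resizes each image to,
# which caps the number of visual tokens per request (and therefore memory and latency).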
min_pixels = 224 * 224
max_pixels = 512 * 512
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

label2color = {}

def get_color(label, explicit_color=None):
    if explicit_color:
        return explicit_color
    if label not in label2color:
        label2color[label] = "#" + ''.join(random.choices('0123456789ABCDEF', k=6))
    return label2color[label]
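
# Illustrative (assumed) shape of the grounding response that create_annotated_image
# expects the model to wrap in a ```json ... ``` fence, based on the keys parsed below:
#   [{"bbox_2d": [x1, y1, x2, y2], "label": "red car"},
#    {"point_2d": [x, y], "label": "blue candy"}]
# Coordinates refer to the resized input the model saw and are rescaled to the original image.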
def create_annotated_image(image, json_data, height, width):
    try:
        json_data = json_data.split('```json')[1].split('```')[0]
        bbox_data = json.loads(json_data)
    except Exception:
        return image

    original_width, original_height = image.size
    x_scale = original_width / width
    y_scale = original_height / height
    scale_factor = max(original_width, original_height) / 512

    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    try:
        font = ImageFont.truetype("arial.ttf", int(12 * scale_factor))
    except Exception:
        font = ImageFont.load_default()
    for item in bbox_data:
        label = item.get("label", "")
        color = get_color(label, item.get("color", None))

        if "bbox_2d" in item:
            bbox = item["bbox_2d"]
            scaled_bbox = [
                int(bbox[0] * x_scale),
                int(bbox[1] * y_scale),
                int(bbox[2] * x_scale),
                int(bbox[3] * y_scale)
            ]
            draw.rectangle(scaled_bbox, outline=color, width=int(2 * scale_factor))
            draw.text(
                (scaled_bbox[0], max(0, scaled_bbox[1] - int(15 * scale_factor))),
                label,
                fill=color,
                font=font
            )

        if "point_2d" in item:
            x, y = item["point_2d"]
            scaled_x = int(x * x_scale)
            scaled_y = int(y * y_scale)
            r = int(5 * scale_factor)
            draw.ellipse((scaled_x - r, scaled_y - r, scaled_x + r, scaled_y + r), fill=color, outline=color)
            draw.text((scaled_x + int(6 * scale_factor), scaled_y), label, fill=color, font=font)

    return draw_image
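
# ZeroGPU ("Running on Zero") only attaches a GPU while a function decorated with the
# imported GPU decorator is executing; decorating detect is assumed to be the intended
# use of that import, since undecorated inference code gets no device on ZeroGPU.
@GPU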
def detect(image, prompt):
    STANDARD_SIZE = (512, 512)
    image.thumbnail(STANDARD_SIZE)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
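
    # image_grid_thw holds the (t, h, w) patch-grid shape produced by the vision encoder;
    # Qwen2.5-VL uses 14x14 pixel patches, so multiplying by 14 recovers the resized
    # resolution the model actually saw, which is the space its coordinates refer to.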
    input_height = inputs['image_grid_thw'][0][1] * 14
    input_width = inputs['image_grid_thw'][0][2] * 14

    annotated_image = create_annotated_image(image, output_text, input_height, input_width)
    return annotated_image, output_text

css_hide_share = """
button#gradio-share-link-button-0 {
    display: none !important;
}
"""

# --- Gradio Interface ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# Object Understanding with Vision Language Models")
    gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and object counting through natural language prompts.")
    gr.Markdown("""
    *Powered by Qwen2.5-VL*
    *Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload an image", type="pil", height=400)
            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Detect all red cars in the image")
            category_input = gr.Textbox(label="Category", interactive=False)
            generate_btn = gr.Button(value="Generate")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Annotated image", height=400)
            output_textbox = gr.Textbox(label="Model response", lines=10)
gr.Markdown("### Examples") | |
example_prompts = [ | |
["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Object Detection"], | |
["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"], | |
["examples/example_1.jpg", "Count the number of red cars in the image.", "Object Counting"], | |
["examples/example_2.JPG", "Count the number of blue candies in the image.", "Object Counting"], | |
["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"], | |
["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"], | |
["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"], | |
["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"], | |
] | |

    gr.Examples(
        examples=example_prompts,
        inputs=[image_input, prompt_input, category_input],
        label="Click an example to populate the input"
    )

    generate_btn.click(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])

if __name__ == "__main__":
    demo.launch()