# Author: sergiopaniego
# Commit: b553066 (verified) — "Update app.py"
import random
import requests
import json
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from spaces import GPU
from gradio.themes.ocean import Ocean
# --- Config ---
# Hugging Face Hub checkpoint shared by the model and its processor. Keeping it
# in one constant prevents the two from silently drifting apart.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# torch_dtype="auto" / device_map="auto" let transformers choose the dtype and
# device placement for the weights.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype="auto", device_map="auto"
)

# Pixel budget handed to the Qwen2.5-VL processor, which resizes each image so
# its total pixel count falls within [min_pixels, max_pixels].
min_pixels = 224 * 224
max_pixels = 512 * 512
processor = AutoProcessor.from_pretrained(
    MODEL_ID, min_pixels=min_pixels, max_pixels=max_pixels
)
# Cache of label -> "#RRGGBB" color so each label keeps one color per process.
label2color = {}


def get_color(label, explicit_color=None):
    """Return the drawing color to use for *label*.

    A truthy *explicit_color* takes precedence and is returned unchanged
    without being cached. Otherwise the label gets a random "#RRGGBB" color
    on first use, memoised in ``label2color`` so later calls agree.
    """
    if explicit_color:
        return explicit_color
    try:
        return label2color[label]
    except KeyError:
        hex_digits = random.choices('0123456789ABCDEF', k=6)
        label2color[label] = "#" + "".join(hex_digits)
        return label2color[label]
def _extract_json_text(response):
    """Return the JSON payload embedded in a model response string.

    Uses the body of the first ```json fenced block when present; otherwise
    returns the response unchanged so bare (unfenced) JSON still parses.
    """
    if "```json" in response:
        return response.split("```json", 1)[1].split("```", 1)[0]
    return response


def create_annotated_image(image, json_data, height, width):
    """Draw the boxes/points described in a model response onto a copy of *image*.

    Args:
        image: PIL image to annotate; it is not modified (a copy is drawn on).
        json_data: Raw model response text. Expected to hold a JSON list of
            objects with optional "label", "color", "bbox_2d" ([x0, y0, x1, y1])
            and "point_2d" ([x, y]) keys, optionally inside a ```json fence.
            A single JSON object is treated as a one-element list.
        height: Height of the coordinate space the model's output refers to.
        width: Width of the coordinate space the model's output refers to.
            Coordinates are rescaled from (width, height) to ``image.size``.

    Returns:
        The annotated copy, or *image* itself when no JSON list/object can be
        parsed from *json_data*.
    """
    try:
        bbox_data = json.loads(_extract_json_text(json_data))
    except (TypeError, ValueError):
        # No parseable JSON in the response: nothing to draw.
        return image
    if isinstance(bbox_data, dict):
        bbox_data = [bbox_data]
    if not isinstance(bbox_data, list):
        return image

    original_width, original_height = image.size
    # float() also accepts 0-d tensors, e.g. values derived from image_grid_thw.
    x_scale = original_width / float(width)
    y_scale = original_height / float(height)
    # Drawing sizes are tuned for a 512px image and scaled to the real size.
    scale_factor = max(original_width, original_height) / 512
    # Clamp to >= 1: Pillow skips outlines of width 0, which would make boxes
    # disappear on images smaller than 256px.
    line_width = max(1, int(2 * scale_factor))
    radius = max(1, int(5 * scale_factor))
    label_offset = int(15 * scale_factor)
    point_label_dx = int(6 * scale_factor)

    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    try:
        font = ImageFont.truetype("arial.ttf", int(12 * scale_factor))
    except (OSError, ImportError, ValueError):
        # Font file missing, FreeType unavailable, or a non-positive size on
        # tiny images: fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for item in bbox_data:
        if not isinstance(item, dict):
            continue  # skip stray non-object entries instead of crashing
        label = str(item.get("label", ""))
        color = get_color(label, item.get("color", None))

        bbox = item.get("bbox_2d")
        if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
            x0, y0, x1, y1 = bbox[:4]
            # Sort each axis: Pillow rejects rectangles with x1 < x0 or y1 < y0.
            left, right = sorted((int(x0 * x_scale), int(x1 * x_scale)))
            top, bottom = sorted((int(y0 * y_scale), int(y1 * y_scale)))
            draw.rectangle([left, top, right, bottom], outline=color, width=line_width)
            draw.text(
                (left, max(0, top - label_offset)),
                label,
                fill=color,
                font=font,
            )

        point = item.get("point_2d")
        if isinstance(point, (list, tuple)) and len(point) == 2:
            px = int(point[0] * x_scale)
            py = int(point[1] * y_scale)
            draw.ellipse(
                (px - radius, py - radius, px + radius, py + radius),
                fill=color,
                outline=color,
            )
            draw.text((px + point_label_dx, py), label, fill=color, font=font)
    return draw_image
@GPU
def detect(image, prompt):
    """Run Qwen2.5-VL on an image/prompt pair and annotate the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        prompt: Natural-language instruction sent alongside the image.

    Returns:
        ``(annotated_image, output_text)``: the downscaled image with any
        parsed boxes/points drawn on it, and the decoded model response.

    Raises:
        gr.Error: if the image or prompt is missing (shown to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    STANDARD_SIZE = (512, 512)
    # thumbnail() resizes in place; work on a copy so the caller's image is
    # not mutated as a side effect.
    image = image.copy()
    image.thumbnail(STANDARD_SIZE)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    # Drop the prompt tokens echoed at the start of each output sequence.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # Decoding only; sampling behavior is governed by generate(), not here.
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # image_grid_thw is (t, h, w) in vision patches per image; multiplying by
    # the patch size gives the resized resolution used to rescale coordinates.
    patch_size = getattr(processor.image_processor, "patch_size", 14)
    _, grid_h, grid_w = (int(v) for v in inputs["image_grid_thw"][0])
    input_height = grid_h * patch_size
    input_width = grid_w * patch_size

    annotated_image = create_annotated_image(image, output_text, input_height, input_width)
    return annotated_image, output_text
# Hides the Gradio share-link button (matched by its element id).
css_hide_share = """
button#gradio-share-link-button-0 {
display: none !important;
}
"""

# --- Gradio Interface ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# Object Understanding with Vision Language Models")
    gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
    # The blank line keeps the two credits as separate paragraphs; Markdown
    # otherwise joins consecutive lines into one.
    gr.Markdown("""
*Powered by Qwen2.5-VL*

*Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
""")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload an image", type="pil", height=400)
            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Detect all red cars in the image")
            # Read-only; filled in only when an example row is selected.
            category_input = gr.Textbox(label="Category", interactive=False)
            generate_btn = gr.Button(value="Generate")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Annotated image", height=400)
            output_textbox = gr.Textbox(label="Model response", lines=10)

    gr.Markdown("### Examples")
    # Each row is [image path, prompt, category], matching the Examples inputs.
    example_prompts = [
        ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Object Detection"],
        ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
        ["examples/example_1.jpg", "Count the number of red cars in the image.", "Object Counting"],
        ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Object Counting"],
        ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
        ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
    ]
    gr.Examples(
        examples=example_prompts,
        inputs=[image_input, prompt_input, category_input],
        label="Click an example to populate the input",
    )

    # Run detection from the button or by pressing Enter in the prompt box.
    generate_btn.click(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])
    prompt_input.submit(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])

if __name__ == "__main__":
    demo.launch()