Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,792 Bytes
1b98b3b 8f86518 1b98b3b 8f86518 1b98b3b 0e82716 1b98b3b b553066 1b98b3b b553066 1b98b3b b553066 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import random
import requests
import json
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from spaces import GPU
from gradio.themes.ocean import Ocean
# --- Config ---
# Hub checkpoint used for both the model weights and the matching processor.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# "auto" lets transformers choose the weight dtype and place the model on the
# available device(s).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)

# Pixel budget handed to the processor; presumably it resizes each image so its
# area falls within [min_pixels, max_pixels] (verify against the Qwen2.5-VL
# processor docs). The upper bound matches the 512x512 thumbnail taken in
# detect().
min_pixels = 224 * 224
max_pixels = 512 * 512
processor = AutoProcessor.from_pretrained(
    MODEL_ID, min_pixels=min_pixels, max_pixels=max_pixels
)
# Memoised label -> "#RRGGBB" colour so a label keeps one colour across calls.
label2color = {}


def get_color(label, explicit_color=None):
    """Return the drawing colour for *label*.

    A truthy *explicit_color* wins and is returned as-is without being cached.
    Otherwise the colour is looked up in ``label2color``; the first time a
    label is seen a random ``#RRGGBB`` hex string is generated and stored.
    """
    if explicit_color:
        return explicit_color
    try:
        return label2color[label]
    except KeyError:
        colour = "#" + "".join(random.choices("0123456789ABCDEF", k=6))
        label2color[label] = colour
        return colour
def _parse_detections(text):
    """Return the drawable items found in a model response.

    The JSON payload is taken from the first ```json fenced block when one is
    present; otherwise the whole response is tried as JSON. A single top-level
    object is treated as a one-item list. Only dict items carrying a "bbox_2d"
    or "point_2d" key are kept. An empty list means there is nothing to draw.
    """
    if not isinstance(text, str):
        return []
    if "```json" in text:
        text = text.split("```json", 1)[1].split("```", 1)[0]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return []
    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        return []
    return [
        item
        for item in data
        if isinstance(item, dict) and ("bbox_2d" in item or "point_2d" in item)
    ]


def _coords(value, count):
    """Return the first *count* entries of *value* as floats, or None if malformed."""
    if not isinstance(value, (list, tuple)) or len(value) < count:
        return None
    try:
        return [float(v) for v in value[:count]]
    except (TypeError, ValueError):
        return None


def _load_font(size):
    """Return a font of about *size* px: Arial if installed, else Pillow's default."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf is commonly missing (e.g. on Linux hosts).
        pass
    try:
        # Pillow >= 10.1 can scale its built-in font; older versions take no size.
        return ImageFont.load_default(size=size)
    except TypeError:
        return ImageFont.load_default()


def create_annotated_image(image, json_data, height, width):
    """Draw the boxes and points from a model response onto a copy of *image*.

    Args:
        image: PIL image to annotate.
        json_data: raw model response text. It should hold JSON (optionally
            inside a ```json fence) with items carrying a "label", a
            "bbox_2d" [x1, y1, x2, y2] and/or a "point_2d" [x, y], and an
            optional "color" that overrides the per-label colour.
        height, width: size of the coordinate space the model answered in;
            coordinates are rescaled from it to *image*'s size.

    Returns:
        An annotated copy of *image*, or *image* itself when the response
        contains nothing drawable (no JSON, invalid JSON, or no boxes/points).
        Malformed coordinate entries are skipped rather than raising.
    """
    items = _parse_detections(json_data)
    if not items:
        return image

    original_width, original_height = image.size
    # float() also accepts 0-d tensors, yielding plain Python numbers.
    x_scale = original_width / float(width)
    y_scale = original_height / float(height)
    # Line widths, radii and font sizes are tuned for a 512 px image.
    scale_factor = max(original_width, original_height) / 512

    def px(value):
        # Clamp to 1: Pillow skips outlines of width 0 and a font size of 0
        # is invalid, which made annotations vanish on small images.
        return max(1, int(value * scale_factor))

    font = _load_font(px(12))
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for item in items:
        label = item.get("label")
        label = "" if label is None else str(label)
        color = get_color(label, item.get("color"))

        box = _coords(item.get("bbox_2d"), 4)
        if box is not None:
            # Sort each axis so inverted boxes still satisfy x0 <= x1, y0 <= y1.
            x0, x1 = sorted((int(box[0] * x_scale), int(box[2] * x_scale)))
            y0, y1 = sorted((int(box[1] * y_scale), int(box[3] * y_scale)))
            draw.rectangle([x0, y0, x1, y1], outline=color, width=px(2))
            draw.text((x0, max(0, y0 - px(15))), label, fill=color, font=font)

        point = _coords(item.get("point_2d"), 2)
        if point is not None:
            cx = int(point[0] * x_scale)
            cy = int(point[1] * y_scale)
            r = px(5)
            draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=color, outline=color)
            draw.text((cx + px(6), cy), label, fill=color, font=font)

    return draw_image
@GPU
def detect(image, prompt):
    """Run Qwen2.5-VL on *image* with *prompt* and draw the grounded result.

    Args:
        image: PIL image from the Gradio input; None when nothing was uploaded.
        prompt: natural-language instruction, e.g. "Detect all red cars".

    Returns:
        A tuple ``(annotated_image, output_text)``: a copy of *image* shrunk
        to fit within 512x512 with any boxes/points from the response drawn on
        it, and the decoded model response.

    Raises:
        gr.Error: if no image was uploaded or the prompt is blank, so the UI
            shows a message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    STANDARD_SIZE = (512, 512)
    # thumbnail() resizes in place (keeping aspect ratio); work on a copy so
    # the caller's image object is left untouched.
    image = image.copy()
    image.thumbnail(STANDARD_SIZE)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    # generate() returns prompt + answer; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # Only decoding options belong here. generate() above gets no sampling
    # arguments, so sampling follows the checkpoint's generation config.
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # Size (in pixels) of the image the model actually saw: image_grid_thw
    # holds it in patches as (t, h, w). create_annotated_image rescales the
    # model's coordinates from this space onto the thumbnail.
    patch_size = getattr(processor.image_processor, "patch_size", 14)
    _, grid_h, grid_w = inputs["image_grid_thw"][0].tolist()
    input_height = grid_h * patch_size
    input_width = grid_w * patch_size

    annotated_image = create_annotated_image(image, output_text, input_height, input_width)
    return annotated_image, output_text
# Extra CSS for the Blocks page: hides the button with id
# gradio-share-link-button-0 (presumably Gradio's share-link button).
css_hide_share = (
    "\n"
    "button#gradio-share-link-button-0 {\n"
    "display: none !important;\n"
    "}\n"
)
# --- Gradio Interface ---
# Layout: header text; a row with the inputs (image, prompt, read-only example
# category, button) on the left and the annotated image plus raw model
# response on the right; clickable examples underneath.
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    credits_md = (
        "\n"
        "*Powered by Qwen2.5-VL*\n"
        "*Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*\n"
    )
    gr.Markdown("# Object Understanding with Vision Language Models")
    gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
    gr.Markdown(credits_md)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image", height=400)
            prompt_input = gr.Textbox(
                placeholder="e.g., Detect all red cars in the image",
                label="Enter your prompt",
            )
            category_input = gr.Textbox(interactive=False, label="Category")
            generate_btn = gr.Button(value="Generate")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Annotated image", type="pil", height=400)
            output_textbox = gr.Textbox(lines=10, label="Model response")

    gr.Markdown("### Examples")
    # Each row fills [image, prompt, category]; the category is display-only.
    cars = "examples/example_1.jpg"
    candies = "examples/example_2.JPG"
    example_prompts = [
        [cars, "Detect all objects in the image and return their locations and labels.", "Object Detection"],
        [candies, "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
        [cars, "Count the number of red cars in the image.", "Object Counting"],
        [candies, "Count the number of blue candies in the image.", "Object Counting"],
        [cars, "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        [candies, "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
        [cars, "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
        [candies, "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
    ]
    gr.Examples(
        examples=example_prompts,
        inputs=[image_input, prompt_input, category_input],
        label="Click an example to populate the input",
    )

    # Only the image and prompt are sent to the model.
    generate_btn.click(
        detect,
        inputs=[image_input, prompt_input],
        outputs=[output_image, output_textbox],
    )

if __name__ == "__main__":
    demo.launch()
|