# Hugging Face Space: LLMDet demo (running on ZeroGPU)
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import GroundingDinoProcessor
from modeling_grounding_dino import GroundingDinoForObjectDetection
from itertools import cycle
import gradio as gr
import spaces
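# `spaces` provides the ZeroGPU decorator used below; `modeling_grounding_dino`
# is presumably a local module shipped alongside this script in the Space repo.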
# Load model and processor
model_id = "fushh7/llmdet_swin_large_hf"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
model.eval()
print("[INFO] Model loaded successfully.")
# Pre-defined palette (extend or tweak as you like)
BOX_COLORS = [
"deepskyblue", "red", "lime", "dodgerblue",
"cyan", "magenta", "yellow",
"orange", "chartreuse"
]
def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
"""
Draw bounding boxes and labels on a PIL Image.
:param image: PIL Image object
:param boxes: Iterable of [x_min, y_min, x_max, y_max]
:param labels: Iterable of label strings
:param scores: Iterable of scalar confidences (0-1)
:param colors: List/tuple of colour names or RGB tuples
:param font_path: Path to a TTF font for labels
    :param font_size: Font size (default: 16)
:return: PIL Image with drawn boxes
"""
# Ensure we can iterate colours indefinitely
colour_cycle = cycle(colors)
draw = ImageDraw.Draw(image)
# Pick a font (fallback to default if missing)
try:
font = ImageFont.truetype(font_path, size=font_size)
except IOError:
font = ImageFont.load_default(size=font_size)
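        # Pillow >= 10.1 is assumed here: load_default() only accepts a
        # size argument from that release onward.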
    # Map each label to a fixed colour so repeated detections match
label_to_colour = {}
for box, label, score in zip(boxes, labels, scores):
# Reuse colour if label seen before, else take next from cycle
colour = label_to_colour.setdefault(label, next(colour_cycle))
x_min, y_min, x_max, y_max = map(int, box)
# Draw rectangle
draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)
# Compose text
text = f"{label} ({score:.3f})"
text_size = draw.textbbox((0, 0), text, font=font)[2:]
# Draw text background for legibility
bg_coords = [x_min, y_min - text_size[1] - 4,
x_min + text_size[0] + 4, y_min]
draw.rectangle(bg_coords, fill=colour)
# Draw text
draw.text((x_min + 2, y_min - text_size[1] - 2),
text, fill="black", font=font)
return image
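# Usage sketch for draw_boxes (commented out; the box, label, and score are
# made-up values for illustration):
# canvas = Image.new("RGB", (200, 200), "white")
# annotated = draw_boxes(canvas, boxes=[[20, 30, 120, 160]], labels=["cat"], scores=[0.87])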
def resize_image_max_dimension(image, max_size=1024):
"""
Resize an image so that the longest side is at most max_size pixels,
while maintaining the aspect ratio.
:param image: PIL Image object
:param max_size: Maximum dimension in pixels (default: 1024)
:return: PIL Image object (resized)
"""
width, height = image.size
# Check if resizing is needed
if max(width, height) <= max_size:
return image
# Calculate new dimensions maintaining aspect ratio
ratio = max_size / max(width, height)
new_width = int(width * ratio)
new_height = int(height * ratio)
# Resize the image using high-quality resampling
return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
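# Worked example: a 4032x3024 photo is scaled by 1024/4032, giving 1024x768;
# an 800x600 image is already within the limit and is returned unchanged.

# The ZeroGPU decorator below requests a GPU for at most 120 s per call.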
@spaces.GPU(duration=120)
def detect_and_draw(
img: Image.Image,
text_query: str,
box_threshold: float = 0.4,
text_threshold: float = 0.3
) -> Image.Image:
"""
Detect objects described in `text_query`, draw boxes, return the image.
    Note: `text_query` must be lowercase, and each concept must end with a period
    (e.g. 'a cat. a remote control.')
"""
    # Normalize the query to lowercase, as the model expects
    text_query = text_query.lower()
    # Downscale oversized images to keep inference fast
    img = resize_image_max_dimension(img, max_size=1024)
    # Preprocess the image and text query
    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
        target_sizes=[img.size[::-1]]  # PIL gives (width, height); the model expects (height, width)
)[0]
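    # `results` holds "boxes" (xyxy pixel coordinates), "scores", and label strings
    # ("text_labels" in newer transformers releases, "labels" in older ones).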
img_out = img.copy()
    img_out = draw_boxes(
        img_out,
        boxes=results["boxes"].cpu().numpy(),
        labels=results.get("text_labels", results.get("labels", [])),
        scores=results["scores"].cpu().numpy(),
    )
return img_out
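# Standalone usage sketch (commented out; reuses one of the bundled examples):
# out = detect_and_draw(Image.open("examples/IMG_8920.jpeg"), "a shoe. a hand.", 0.4, 0.3)
# out.save("detections.jpeg")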
# Create example list
examples = [
["examples/IMG_8920.jpeg", "bin. water bottle. hand. shoe.", 0.4, 0.3],
["examples/IMG_9435.jpeg", "lettuce. orange slices (group). eggs (group). cheese (group). red cabbage. pear slices (group).", 0.4, 0.3],
]
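# The example image paths above are assumed to be bundled with the Space repo.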
# Create Gradio demo
demo = gr.Interface(
    fn=detect_and_draw,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(value="",
                   label="Text Query (lowercase, end each concept with '.', e.g. 'a bird. a tree.')"),
        gr.Slider(0.0, 1.0, 0.4, 0.05, label="Box Threshold"),
        gr.Slider(0.0, 1.0, 0.3, 0.05, label="Text Threshold")
    ],
    outputs=gr.Image(type="pil", label="Detections"),
    title="LLMDet Demo: Open-Vocabulary Grounded Object Detection",
    description=f"""Upload an image, enter a text query, and adjust the thresholds to see detections.
Adapted from the [Hugging Face demo](https://github.com/iSEE-Laboratory/LLMDet/tree/main/hf_model) in the LLMDet GitHub repo.
This Space uses: {model_id}

See the original resources:
* [LLMDet GitHub](https://github.com/iSEE-Laboratory/LLMDet/tree/main?tab=readme-ov-file)
* [LLMDet paper](https://arxiv.org/abs/2501.18954) - LLMDet: Learning Strong Open-Vocabulary Object Detectors under the Supervision of Large Language Models
* LLMDet model checkpoints:
  * [Tiny](https://huggingface.co/fushh7/llmdet_swin_tiny_hf) (173M params)
  * [Base](https://huggingface.co/fushh7/llmdet_swin_base_hf) (233M params)
  * [Large](https://huggingface.co/fushh7/llmdet_swin_large_hf) (344M params)
""",
    examples=examples,
    cache_examples=True,
)
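# On Spaces this script runs at startup; launched locally, the UI is served
# at http://127.0.0.1:7860 by default.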
demo.launch()