# Scraped from a Hugging Face Spaces file viewer (Space status: Sleeping).
# File size: 8,068 Bytes
# Commits shown in the blame column: d15a538, bf50961, 4735f9f, f892a57, f6081d0
# (The viewer's line-number gutter, 1-253, was removed; it is not part of the program.)
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import GroundingDinoProcessor
from modeling_grounding_dino import GroundingDinoForObjectDetection
from PIL import Image, ImageDraw, ImageFont
from itertools import cycle
import os
from datetime import datetime
import gradio as gr
import tempfile
# Load model and processor.
# NOTE: "fushh7/llmdet_swin_large_hf" used to be assigned to model_id here and
# was immediately overwritten; only the checkpoint below was ever loaded.
model_id = "fushh7/llmdet_swin_tiny_hf"
DEVICE = "cpu"
print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
# Put the network in evaluation mode; it is only used for inference here.
model.eval()
print("[INFO] Model loaded successfully.")
# Default colour palette for detection boxes (one entry per line so it is
# easy to extend or reorder).
BOX_COLORS = [
    "deepskyblue",
    "red",
    "lime",
    "dodgerblue",
    "cyan",
    "magenta",
    "yellow",
    "orange",
    "chartreuse",
]
def save_cropped_images(original_image, boxes, labels, scores):
    """
    Save each bounding-box region of an image to its own temporary JPEG file.

    Box coordinates are rounded to whole pixels and clamped to the image
    bounds; boxes left with no area after clamping are skipped. The files
    are created with ``delete=False``, so they stay on disk until the caller
    removes them.

    :param original_image: Original PIL image
    :param boxes: List of bounding boxes [x_min, y_min, x_max, y_max]
    :param labels: List of labels, one per box (only used to pair with boxes)
    :param scores: List of confidence scores, one per box (only used to pair with boxes)
    :return: List of paths of the saved temporary files
    """
    # len() works for both lists and numpy arrays; nothing to do without boxes.
    if len(boxes) == 0:
        return []

    # Modes Pillow's JPEG writer accepts; anything else (e.g. RGBA, P) is
    # converted to RGB before saving.
    jpeg_modes = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}
    width, height = original_image.size

    saved_paths = []
    for box, _label, _score in zip(boxes, labels, scores):
        x_min, y_min, x_max, y_max = (int(round(float(v))) for v in box)

        # Keep the crop window inside the image.
        left = min(max(x_min, 0), width)
        upper = min(max(y_min, 0), height)
        right = min(max(x_max, 0), width)
        lower = min(max(y_max, 0), height)
        if right <= left or lower <= upper:
            # Degenerate box: an empty image cannot be written as JPEG.
            continue

        cropped_img = original_image.crop((left, upper, right, lower))
        if cropped_img.mode not in jpeg_modes:
            cropped_img = cropped_img.convert("RGB")

        # Reserve a unique path and close the handle before Pillow writes to
        # it (an open handle blocks the write on some platforms).
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            filepath = tmp_file.name
        try:
            cropped_img.save(filepath)
        except Exception:
            # Do not leave an empty temporary file behind on failure.
            os.unlink(filepath)
            raise
        saved_paths.append(filepath)
    return saved_paths
def _load_label_font(font_path, font_size):
    """Return the TTF font at ``font_path``, or Pillow's default font if it cannot be loaded."""
    try:
        return ImageFont.truetype(font_path, size=font_size)
    except OSError:
        pass
    try:
        return ImageFont.load_default(size=font_size)
    except TypeError:
        # Older Pillow releases do not accept a ``size`` argument here.
        return ImageFont.load_default()


def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
    """
    Draw bounding boxes and labels on a PIL Image.

    The image is modified in place and also returned. Every distinct label
    gets one colour, taken from ``colors`` in order of first appearance.

    :param image: PIL Image object
    :param boxes: Iterable of [x_min, y_min, x_max, y_max]
    :param labels: Iterable of label strings
    :param scores: Iterable of scalar confidences (0-1)
    :param colors: List/tuple of colour names or RGB tuples; an empty value
        falls back to BOX_COLORS
    :param font_path: Path to a TTF font for labels
    :param font_size: Int size of font to use, default 16
    :return: PIL Image with drawn boxes
    """
    # Ensure we can iterate colours indefinitely (and never over nothing).
    colour_cycle = cycle(colors if colors else BOX_COLORS)
    draw = ImageDraw.Draw(image)
    font = _load_label_font(font_path, font_size)

    label_to_colour = {}
    for box, label, score in zip(boxes, labels, scores):
        # Only advance the palette for labels not seen before; calling
        # next() unconditionally would skip colours for repeated labels.
        if label not in label_to_colour:
            label_to_colour[label] = next(colour_cycle)
        colour = label_to_colour[label]

        x_min, y_min, x_max, y_max = map(int, box)
        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        text = f"{label} ({score:.3f})"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]

        # Put the label just above the box; if that would fall off the top
        # of the image, put it just inside the box instead.
        label_top = y_min - text_h - 4
        if label_top < 0:
            label_top = y_min

        # Filled background for legibility, then the text itself.
        draw.rectangle(
            [x_min, label_top, x_min + text_w + 4, label_top + text_h + 4],
            fill=colour,
        )
        draw.text((x_min + 2, label_top + 2), text, fill="black", font=font)
    return image
def resize_image_max_dimension(image, max_size=4096):
    """
    Resize an image so that the longest side is at most max_size pixels,
    while maintaining the aspect ratio.

    Images that already fit are returned unchanged (same object).

    :param image: PIL Image object
    :param max_size: Maximum dimension in pixels (default: 4096)
    :return: PIL Image object (resized, or the original if no resize is needed)
    """
    width, height = image.size
    longest_side = max(width, height)

    # Check if resizing is needed
    if longest_side <= max_size:
        return image

    # Scale both sides by the same factor; never let the short side of a
    # very elongated image truncate to 0 pixels.
    ratio = max_size / longest_side
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))

    # Resize the image using high-quality resampling
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def detect_and_draw(
    img: Image.Image,
    text_query: str,
    box_threshold: float = 0.14,
    text_threshold: float = 0.13,
    save_crops: bool = True
):
    """
    Detect objects described in `text_query`, draw boxes, return the image and crops.

    The query is lowercased and stripped, and a trailing dot is appended if
    missing, so that it matches the expected form where each concept ends
    with a dot (e.g. 'a cat. a remote control.').

    :param img: Input PIL image; ``None`` (no image supplied) yields ``(None, [])``
    :param text_query: Concepts to look for
    :param box_threshold: Box confidence threshold passed to post-processing
    :param text_threshold: Text confidence threshold passed to post-processing
    :param save_crops: Whether to also write each detected region to a temporary file
    :return: Tuple ``(annotated image, list of crop file paths)``
    """
    # Nothing to do without an image (e.g. the button was pressed before upload).
    if img is None:
        return None, []

    # Normalise the query; with no text there is nothing to detect.
    text_query = (text_query or "").lower().strip()
    if not text_query:
        return img, []
    if not text_query.endswith("."):
        text_query += "."

    # If the image size is too large, we make it smaller
    img = resize_image_max_dimension(img, max_size=4096)

    # Preprocess and run the model without tracking gradients
    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL reports (width, height); reversed here to (height, width).
        target_sizes=[img.size[::-1]]
    )[0]

    # Extract the detections once and share them between drawing and cropping.
    boxes = results["boxes"].cpu().numpy()
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"]

    # Draw on a copy so the crops below come from the clean image.
    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)

    crop_paths = []
    if save_crops:
        crop_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
        print(f"Generated {len(crop_paths)} cropped images")
    return img_out, crop_paths
# Example rows for the demo; each row is
# [image path, text query, box threshold, text threshold].
examples = [
    [
        "examples/stickers(1).jpg",
        "stickers. labels.",
        0.24,
        0.23,
    ],
]
# Helper to remove temporary files once they are no longer needed
def cleanup_temp_files(crop_paths):
    """
    Delete temporary files on a best-effort basis.

    Paths that are missing, invalid or otherwise cannot be removed are
    skipped without raising.

    :param crop_paths: Iterable of file paths to delete; ``None`` is treated as empty
    :return: None
    """
    for path in crop_paths or ():
        try:
            os.unlink(path)
        except (OSError, TypeError, ValueError):
            # Deliberate best effort: a file that is already gone or a
            # malformed entry must not stop the remaining deletions.
            continue
# Build the Gradio interface
with gr.Blocks(title="Stikkiers", css=".gradio-container {max-width: 100% !important}") as demo:
    gr.Markdown("# Sticker Finder")
    gr.Markdown("Upload an image and adjust thresholds to see detections.")

    with gr.Row():
        # Left column: image, query and thresholds
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_query = gr.Textbox(
                value="stickers. labels. postcards.",
                label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')",
            )
            box_threshold = gr.Slider(0.0, 1.0, 0.14, step=0.05, label="Box Threshold")
            text_threshold = gr.Slider(0.0, 1.0, 0.13, step=0.05, label="Text Threshold")
            submit_btn = gr.Button("Detect")

        # Right column: annotated image and a gallery of the crops
        with gr.Column():
            image_output = gr.Image(type="pil", label="Detections")
            gallery = gr.Gallery(
                label="Detected Crops",
                columns=[4],
                rows=[2],
                object_fit="contain",
                height="auto",
            )

    # The examples and the button drive the same function with the same
    # components, so the component lists are defined once.
    detection_inputs = [image_input, text_query, box_threshold, text_threshold]
    detection_outputs = [image_output, gallery]

    # Examples
    gr.Examples(
        examples=examples,
        inputs=detection_inputs,
        outputs=detection_outputs,
        fn=detect_and_draw,
        cache_examples=True,
    )

    # Submit button
    submit_btn.click(
        fn=detect_and_draw,
        inputs=detection_inputs,
        outputs=detection_outputs,
    )

    # Page-load hook. The handler is a no-op, so no temporary files are
    # removed here.
    # NOTE(review): cleanup_temp_files is not called anywhere in the visible code.
    demo.load(
        fn=lambda: None,
        inputs=None,
        outputs=None,
    )
if __name__ == "__main__":
    # Start the app only when run as a script: listen on all network
    # interfaces and do not create a public share link.
    demo.launch(server_name="0.0.0.0", share=False)