# Hugging Face Spaces status banner (page residue): "Spaces: Sleeping"
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import GroundingDinoProcessor | |
from modeling_grounding_dino import GroundingDinoForObjectDetection | |
from PIL import Image, ImageDraw, ImageFont | |
from itertools import cycle | |
import os | |
from datetime import datetime | |
import gradio as gr | |
import tempfile | |
# Load model and processor.
# The checkpoint can be overridden through the LLMDET_MODEL_ID environment
# variable (e.g. "fushh7/llmdet_swin_large_hf"); the swin-tiny checkpoint is
# the default, which is what the original (overwritten) assignments resolved to.
model_id = os.environ.get("LLMDET_MODEL_ID", "fushh7/llmdet_swin_tiny_hf")
DEVICE = "cpu"
print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
# Put the model in inference mode (e.g. dropout disabled) before serving requests.
model.eval()
print("[INFO] Model loaded successfully.")
# Colour palette for detection boxes, assigned per distinct label in this order
# and cycled when there are more labels than colours. Edit freely.
BOX_COLORS = [
    "deepskyblue",
    "red",
    "lime",
    "dodgerblue",
    "cyan",
    "magenta",
    "yellow",
    "orange",
    "chartreuse",
]
def save_cropped_images(original_image, boxes, labels, scores):
    """
    Save each bounding-box region of an image to its own temporary JPEG file.

    :param original_image: Original PIL image
    :param boxes: Iterable of boxes [x_min, y_min, x_max, y_max] in pixel coordinates
    :param labels: Label for each box (not written anywhere; only bounds the iteration)
    :param scores: Confidence score for each box (not written anywhere; only bounds the iteration)
    :return: List of paths of the saved temporary files. The files are created
        with ``delete=False``, so the caller is responsible for removing them.
        Boxes with zero width or height produce no file.
    """
    # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...): convert once up
    # front so that such inputs do not make `save` raise OSError.
    if original_image.mode not in ("RGB", "L"):
        original_image = original_image.convert("RGB")

    saved_paths = []
    # zip() stops at the shortest of boxes/labels/scores, as before.
    for box, _label, _score in zip(boxes, labels, scores):
        cropped_img = original_image.crop(tuple(box))
        if cropped_img.width == 0 or cropped_img.height == 0:
            # Degenerate box: an empty image cannot be encoded as JPEG.
            continue
        # Write through the already-open handle rather than reopening the file
        # by name: reopening an open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            cropped_img.save(tmp_file, format="JPEG")
        saved_paths.append(tmp_file.name)
    return saved_paths
def _load_label_font(font_path, font_size):
    """Return the TrueType font at ``font_path``, or Pillow's built-in font if it cannot be loaded."""
    try:
        return ImageFont.truetype(font_path, size=font_size)
    except OSError:
        pass
    try:
        # `size=` is only accepted by load_default() in Pillow >= 10.1; older
        # versions raise TypeError and only provide the fixed-size bitmap font.
        return ImageFont.load_default(size=font_size)
    except TypeError:
        return ImageFont.load_default()


def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
    """
    Draw bounding boxes with "label (score)" captions onto a PIL image.

    Each distinct label gets its own colour, taken in order from ``colors``
    and cycling when there are more labels than colours.

    :param image: PIL Image object; it is drawn on in place and also returned
    :param boxes: Iterable of [x_min, y_min, x_max, y_max] in pixel coordinates
    :param labels: Iterable of label strings (used as dict keys, so must be hashable)
    :param scores: Iterable of scalar confidences (0-1); anything float() accepts
    :param colors: Non-empty list/tuple of colour names or RGB tuples
    :param font_path: Path to a TTF font for the captions
    :param font_size: Int size of the caption font, default 16
    :return: The same PIL Image with the boxes drawn
    """
    colour_cycle = cycle(colors)
    draw = ImageDraw.Draw(image)
    font = _load_label_font(font_path, font_size)

    label_to_colour = {}
    for box, label, score in zip(boxes, labels, scores):
        # Advance the cycle only for labels not seen before. Calling next()
        # inside setdefault() would consume a colour on every box and skip
        # palette entries for later labels.
        if label not in label_to_colour:
            label_to_colour[label] = next(colour_cycle)
        colour = label_to_colour[label]

        x_min, y_min, x_max, y_max = map(int, box)
        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        text = f"{label} ({float(score):.3f})"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]

        # Caption sits just above the box; if the box touches the top edge there
        # is no room above it, so draw the caption just inside the box instead.
        label_top = y_min - text_h - 4
        if label_top < 0:
            label_top = y_min

        # Filled background for legibility, then the caption text.
        draw.rectangle(
            [x_min, label_top, x_min + text_w + 4, label_top + text_h + 4],
            fill=colour,
        )
        draw.text((x_min + 2, label_top + 2), text, fill="black", font=font)
    return image
def resize_image_max_dimension(image, max_size=4096):
    """
    Downscale an image so that its longest side is at most ``max_size`` pixels,
    preserving the aspect ratio.

    Images already within the limit are returned unchanged (the same object,
    not a copy).

    :param image: PIL Image object
    :param max_size: Maximum length of the longest side in pixels (default: 4096)
    :return: ``image`` itself, or a new LANCZOS-resampled PIL Image
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= max_size:
        return image

    ratio = max_size / longest
    # Clamp to 1 px: for extreme aspect ratios the short side would otherwise
    # truncate to 0 and yield an empty, unusable image.
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def detect_and_draw(
    img: Image.Image,
    text_query: str,
    box_threshold: float = 0.14,
    text_threshold: float = 0.13,
    save_crops: bool = True
):
    """
    Detect objects described in `text_query`, draw boxes, and return the image and crops.

    The query is lowercased, and a trailing '.' is appended if missing. The
    model expects concepts in the form 'a cat. a remote control.'.

    :param img: Input PIL image (Gradio passes None when no image was uploaded)
    :param text_query: Dot-separated phrases describing what to detect
    :param box_threshold: Box score threshold passed to post-processing
    :param text_threshold: Text score threshold passed to post-processing
    :param save_crops: If True, also save each detected region to a temp JPEG
    :return: (annotated copy of the possibly-resized image, list of crop file paths)
    :raises gr.Error: if no image or an empty query is given
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    text_query = (text_query or "").strip().lower()
    if not text_query:
        raise gr.Error("Please enter a text query, e.g. 'a bird. a tree.'")
    if not text_query.endswith("."):
        text_query += "."

    # Very large uploads are downscaled before inference.
    img = resize_image_max_dimension(img, max_size=4096)

    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL's size is (width, height); reversed to (height, width) here.
        target_sizes=[img.size[::-1]],
    )[0]

    # Extract detections once and reuse them for drawing and cropping.
    boxes = results["boxes"].cpu().numpy()
    # String labels live under "text_labels" or "labels" (presumably depending
    # on the transformers version) — prefer the former when present.
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"]

    # Draw on a copy so `img` stays clean for cropping.
    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)

    crop_paths = []
    if save_crops:
        crop_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
    print(f"Generated {len(crop_paths)} cropped images")
    return img_out, crop_paths
# Example rows for the demo. Each row is
# [image path, text query, box threshold, text threshold],
# in the same order as the inputs of the gr.Examples component.
examples = [
    [
        "examples/stickers(1).jpg",
        "stickers. labels.",
        0.24,
        0.23,
    ],
]
# Helper to delete temporary crop files once they are no longer needed.
def cleanup_temp_files(crop_paths):
    """
    Best-effort removal of temporary files.

    Paths that no longer exist are skipped silently. Any other OS error (e.g.
    a permission problem) is reported, and cleanup continues with the
    remaining paths.

    :param crop_paths: Iterable of file paths (e.g. the list returned by
        save_cropped_images); None is treated as empty
    """
    for path in crop_paths or ():
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Already gone: nothing to do.
            pass
        except OSError as exc:
            print(f"[WARN] Could not remove temp file {path}: {exc}")
def _detect_with_crop_cleanup(img, query, box_thr, text_thr, previous_crops):
    """
    Click handler for the "Detect" button: run detection, then delete this
    session's crop files from the previous run.

    The previous crops are replaced in the gallery by this response, so they are
    no longer displayed. Without this, every click would leave its crop files
    in the temp directory.

    :return: (annotated image, crop paths for the gallery, crop paths to remember in session state)
    """
    img_out, crop_paths = detect_and_draw(img, query, box_thr, text_thr)
    cleanup_temp_files(previous_crops)
    return img_out, crop_paths, crop_paths


# Create Gradio demo
with gr.Blocks(title="Stikkiers", css=".gradio-container {max-width: 100% !important}") as demo:
    gr.Markdown("# Sticker Finder")
    gr.Markdown("Upload an image and adjust thresholds to see detections.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_query = gr.Textbox(
                value="stickers. labels. postcards.",
                label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"
            )
            # step=0.01 so the 0.14 / 0.13 defaults (and the example's 0.24 / 0.23)
            # lie on the slider grid; with step=0.05 they were not selectable.
            box_threshold = gr.Slider(0.0, 1.0, 0.14, step=0.01, label="Box Threshold")
            text_threshold = gr.Slider(0.0, 1.0, 0.13, step=0.01, label="Text Threshold")
            submit_btn = gr.Button("Detect")
        with gr.Column():
            image_output = gr.Image(type="pil", label="Detections")
            # Gallery of the cropped detections
            gallery = gr.Gallery(
                label="Detected Crops",
                columns=[4],
                rows=[2],
                object_fit="contain",
                height="auto"
            )

    # Per-session list of crop file paths from the last "Detect" click, so they
    # can be deleted on the next click.
    crop_state = gr.State([])

    # Examples
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_query, box_threshold, text_threshold],
        outputs=[image_output, gallery],
        fn=detect_and_draw,
        cache_examples=True
    )

    # Submit button
    submit_btn.click(
        fn=_detect_with_crop_cleanup,
        inputs=[image_input, text_query, box_threshold, text_threshold, crop_state],
        outputs=[image_output, gallery, crop_state]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)