Spaces:
Running
Running
""" | |
Gradio app to compare object‑detection models: | |
• Ultralytics YOLOv12 (n, s, m, l, x) | |
• Ultralytics YOLOv11 (n, s, m, l, x) | |
• Roboflow RF‑DETR (Base, Large) | |
• Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files) | |
Python ≥3.9 | |
Install: | |
pip install -r requirements.txt | |
Optionally, add GPU‑specific PyTorch wheels or `rfdetr[onnxexport]` for ONNX export. | |
""" | |
from __future__ import annotations | |
import time | |
from pathlib import Path | |
from typing import List, Tuple, Optional | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
import supervision as sv | |
from ultralytics import YOLO | |
from rfdetr import RFDETRBase, RFDETRLarge | |
from rfdetr.util.coco_classes import COCO_CLASSES | |
############################################################################### | |
# Model registry & lazy loader | |
############################################################################### | |
# Display name -> Ultralytics weight file.
# Ultralytics publishes the v11 / v12 checkpoints as "yolo11*.pt" / "yolo12*.pt"
# (no "v" in the file name); the names must match the released assets exactly
# for the automatic download to work when the file is not available locally.
YOLO_MODEL_MAP = {
    "YOLOv12‑n": "yolo12n.pt",
    "YOLOv12‑s": "yolo12s.pt",
    "YOLOv12‑m": "yolo12m.pt",
    "YOLOv12‑l": "yolo12l.pt",
    "YOLOv12‑x": "yolo12x.pt",
    "YOLOv11‑n": "yolo11n.pt",
    "YOLOv11‑s": "yolo11s.pt",
    "YOLOv11‑m": "yolo11m.pt",
    "YOLOv11‑l": "yolo11l.pt",
    "YOLOv11‑x": "yolo11x.pt",
}

# Display name -> RF-DETR variant tag ("base" or "large").
RFDETR_MODEL_MAP = {
    "RF‑DETR‑Base (29M)": "base",
    "RF‑DETR‑Large (128M)": "large",
}

# Every entry offered in the UI: built-in models first, then the two
# "bring your own checkpoint" options.
ALL_MODELS = list(YOLO_MODEL_MAP) + list(RFDETR_MODEL_MAP) + [
    "Custom YOLO (.pt/.pth)",
    "Custom RF‑DETR (.pth)",
]

_loaded = {}  # cache of already-instantiated models
def load_model(choice: str, custom_file: Optional[Path] = None):
    """Return (and cache) a model instance for *choice*.

    Built-in models are cached under their display name. Custom checkpoints
    are cached per (choice, checkpoint path) so that uploading a different
    file loads that file instead of returning the previously cached model.

    Args:
        choice: Display name of the model (a key of ``YOLO_MODEL_MAP`` /
            ``RFDETR_MODEL_MAP`` or one of the "Custom ..." entries).
        custom_file: Path of the uploaded checkpoint; only used when
            *choice* is one of the custom entries.

    Returns:
        The instantiated ``YOLO``, ``RFDETRBase`` or ``RFDETRLarge`` object.

    Raises:
        ValueError: If a custom entry is selected without a checkpoint, or
            *choice* is not a supported model name.
        RuntimeError: If loading raised ``FileNotFoundError`` (the original
            error is chained and its text appended to the message).
    """
    is_builtin = choice in YOLO_MODEL_MAP or choice in RFDETR_MODEL_MAP
    # Custom entries share one display name, so the checkpoint path has to be
    # part of the key; without a file we fall through to the "upload first" error.
    if is_builtin or not custom_file:
        cache_key = choice
    else:
        cache_key = (choice, str(custom_file))

    if cache_key in _loaded:
        return _loaded[cache_key]

    try:
        if choice in YOLO_MODEL_MAP:
            weight_id = YOLO_MODEL_MAP[choice]
            mdl = YOLO(weight_id)  # Ultralytics downloads if not local
        elif choice in RFDETR_MODEL_MAP:
            mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
        elif choice.startswith("Custom YOLO"):
            if not custom_file:
                raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
            mdl = YOLO(str(custom_file))
        elif choice.startswith("Custom RF‑DETR"):
            if not custom_file:
                raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
            # NOTE: custom RF-DETR checkpoints are always loaded into the Base class.
            mdl = RFDETRBase(pretrain_weights=str(custom_file))
        else:
            raise ValueError(f"Unsupported model choice: {choice}")
    except FileNotFoundError as e:
        raise RuntimeError(
            f"Weights for '{choice}' not found locally and could not be downloaded. "
            "Place the .pt file in the working directory, supply a custom checkpoint, "
            "or ensure the model is released on the Ultralytics hub.\n" + str(e)
        ) from e

    _loaded[cache_key] = mdl
    return mdl
############################################################################### | |
# Inference helpers | |
############################################################################### | |
# Module-level annotators, created once and reused for every inference call.
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()


def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
    """Run *model* on *image* and return the annotated image plus the runtime.

    Args:
        model: An ``RFDETRBase`` / ``RFDETRLarge`` instance, or otherwise an
            object treated as an Ultralytics ``YOLO`` model.
        image: Input image; it is copied before annotation, so the caller's
            image is not modified.
        threshold: Confidence threshold forwarded to the model's ``predict``
            call for both model families.

    Returns:
        ``(annotated_image, seconds)`` where ``seconds`` covers the ``predict``
        call (and, for YOLO, the conversion to ``sv.Detections``) but not the
        drawing of boxes and labels.
    """
    start = time.perf_counter()
    if isinstance(model, (RFDETRBase, RFDETRLarge)):
        detections = model.predict(image, threshold=threshold)
        label_source = COCO_CLASSES
    else:  # Ultralytics YOLO
        # Pass the threshold explicitly; otherwise YOLO falls back to its own
        # default confidence and the UI slider only affects RF-DETR.
        result = model.predict(image, conf=threshold, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(result)
        label_source = model.names
    runtime = time.perf_counter() - start

    labels = [
        f"{label_source[cid]} {conf:.2f}"
        for cid, conf in zip(detections.class_id, detections.confidence)
    ]
    annotated = box_annotator.annotate(image.copy(), detections)
    annotated = label_annotator.annotate(annotated, detections, labels)
    return annotated, runtime
############################################################################### | |
# Gradio UI logic | |
############################################################################### | |
def _error_tile(size: Tuple[int, int], message: str) -> Image.Image:
    """Return a dark placeholder image of *size* with *message* drawn on it."""
    import textwrap

    from PIL import ImageDraw

    tile = Image.new("RGB", size, color=(30, 30, 30))
    # Keep the overlay ASCII-only: Pillow's built-in default font may not cover
    # other characters (e.g. the typographic dashes used in the model names).
    printable = message.encode("ascii", "replace").decode("ascii")
    # Rough wrap width derived from the image width (approx. 7 px per glyph).
    wrapped = textwrap.fill(printable, width=max(20, size[0] // 7))
    ImageDraw.Draw(tile).multiline_text((10, 10), wrapped, fill=(255, 255, 255))
    return tile


def compare_models(models: List[str], img: Image.Image, threshold: float, custom_file: Optional[Path]):
    """Run every selected model on *img* and collect annotated images and legends.

    Args:
        models: Display names of the models to run.
        img: Uploaded image; converted to RGB when it is in another mode.
        threshold: Confidence threshold passed on to each inference call.
        custom_file: Uploaded checkpoint, forwarded to ``load_model``.

    Returns:
        ``(results, legends)``: one image and one legend string per entry of
        *models*, in the same order. A model that fails contributes a dark
        tile showing the error text and a legend of the form
        ``"<model> – ERROR: <message>"`` instead of aborting the whole run.

    Raises:
        gr.Error: If no image was uploaded or no model is selected.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if not models:
        raise gr.Error("Please select at least one model.")
    if img.mode != "RGB":
        img = img.convert("RGB")

    results, legends = [], []
    for m in models:
        try:
            model_obj = load_model(m, custom_file)
            annotated, t = run_single_inference(model_obj, img, threshold)
            results.append(annotated)
            legends.append(f"{m} – {t*1000:.1f} ms")
        except Exception as e:
            # UI boundary: one failing model must not hide the other results,
            # so report the failure in place of its annotated image.
            legends.append(f"{m} – ERROR: {e}")
            results.append(_error_tile(img.size, f"ERROR: {e}"))
    return results, legends
############################################################################### | |
# Build & launch demo | |
############################################################################### | |
def build_demo():
    """Assemble the Gradio Blocks UI and wire the run button to ``compare_models``.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    # NOTE(review): the row grouping below is reconstructed from a source whose
    # indentation was lost - confirm against the deployed layout.
    intro_md = """# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, and optionally upload a custom checkpoint.\nThe app annotates predictions and reports per‑model latency."""

    with gr.Blocks(title="CV Model Comparison") as app:
        gr.Markdown(intro_md)

        # Model selection and confidence threshold side by side.
        with gr.Row():
            detector_picker = gr.CheckboxGroup(
                choices=ALL_MODELS,
                value=["YOLOv12‑n"],
                label="Select models",
            )
            conf_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Confidence threshold",
            )

        checkpoint_upload = gr.File(
            label="Upload custom YOLO / RF‑DETR checkpoint",
            file_types=[".pt", ".pth"],
            interactive=True,
        )
        photo_input = gr.Image(
            type="pil",
            label="Upload image",
            sources=["upload", "webcam"],
            show_label=True,
        )

        # Annotated images next to the per-model runtime / error report.
        with gr.Row():
            result_gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
            legend_json = gr.JSON(label="Runtime (ms) or error messages")

        trigger = gr.Button("Run Inference", variant="primary")
        handler_inputs = [detector_picker, photo_input, conf_slider, checkpoint_upload]
        handler_outputs = [result_gallery, legend_json]
        trigger.click(fn=compare_models, inputs=handler_inputs, outputs=handler_outputs)

    return app
if __name__ == "__main__": | |
build_demo().launch() | |