Spaces:
Sleeping
Sleeping
""" | |
Gradio app to compare object‑detection models:
• Ultralytics YOLOv12 (n, s, m, l, x)
• Ultralytics YOLOv11 (n, s, m, l, x)
• Roboflow RF‑DETR (Base, Large)
• Custom fine‑tuned checkpoints (.pt/.pth upload)
Revision 2025‑04‑19‑c:
• Re‑indented the entire file with consistent 4‑space indentation to remove an `IndentationError`.
• Thin, semi‑transparent (60 %) boxes; concise error labels.
""" | |
from __future__ import annotations | |
import time | |
from pathlib import Path | |
from typing import Dict, List, Optional, Tuple | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
import supervision as sv | |
from ultralytics import YOLO | |
from rfdetr import RFDETRBase, RFDETRLarge | |
from rfdetr.util.coco_classes import COCO_CLASSES | |
############################################################################### | |
# Model registry & lazy loader | |
############################################################################### | |
# UI label -> weight file handed to ``YOLO(...)``. The labels are shown verbatim
# in the checkbox group and are also used as cache keys by ``load_model``.
YOLO_MODEL_MAP: Dict[str, str] = {
    # NOTE(review): official Ultralytics releases name these weights
    # "yolo12{n,s,m,l,x}.pt"; the "yolov12*.pt" spelling matches the standalone
    # YOLOv12 fork. Confirm which package provides `ultralytics` here.
    "YOLOv12‑n": "yolov12n.pt",
    "YOLOv12‑s": "yolov12s.pt",
    "YOLOv12‑m": "yolov12m.pt",
    "YOLOv12‑l": "yolov12l.pt",
    "YOLOv12‑x": "yolov12x.pt",
    # Ultralytics publishes the YOLO11 weights without the "v" ("yolo11n.pt");
    # "yolov11*.pt" is not a published asset name and cannot be auto-downloaded.
    "YOLOv11‑n": "yolo11n.pt",
    "YOLOv11‑s": "yolo11s.pt",
    "YOLOv11‑m": "yolo11m.pt",
    "YOLOv11‑l": "yolo11l.pt",
    "YOLOv11‑x": "yolo11x.pt",
}

# UI label -> size tag; ``load_model`` maps "base"/"large" to the RF-DETR class.
RFDETR_MODEL_MAP: Dict[str, str] = {
    "RF‑DETR‑Base (29M)": "base",
    "RF‑DETR‑Large (128M)": "large",
}

# Every selectable entry, in display order: built-in models first, then the
# two choices that require an uploaded checkpoint.
ALL_MODELS: List[str] = [
    *YOLO_MODEL_MAP,
    *RFDETR_MODEL_MAP,
    "Custom YOLO (.pt/.pth)",
    "Custom RF‑DETR (.pth)",
]

# Process-wide cache of instantiated detectors, filled lazily by ``load_model``.
_loaded: Dict[str, object] = {}
def load_model(choice: str, custom_file: Optional[Path] = None):
    """Return a detector for *choice*, building it on first use and caching it.

    Args:
        choice: A label from ``ALL_MODELS`` (built-in YOLO / RF-DETR entry or
            one of the two "Custom ..." entries).
        custom_file: Path of the uploaded checkpoint. Required for the
            "Custom ..." entries and ignored for built-in models.

    Returns:
        The instantiated ``YOLO``, ``RFDETRBase`` or ``RFDETRLarge`` object.

    Raises:
        RuntimeError: If *choice* is unknown, a custom entry is selected
            without a checkpoint, or the model constructor fails. The original
            exception is chained as ``__cause__``.
    """
    is_custom = choice.startswith("Custom")
    # Custom entries are cached per checkpoint path. Keying on the label alone
    # would keep returning the first uploaded model after a new file is chosen.
    if is_custom and custom_file is not None:
        cache_key = f"{choice}::{custom_file}"
    else:
        cache_key = choice

    if cache_key in _loaded:
        return _loaded[cache_key]

    try:
        if choice in YOLO_MODEL_MAP:
            model = YOLO(YOLO_MODEL_MAP[choice])
        elif choice in RFDETR_MODEL_MAP:
            model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
        elif choice.startswith("Custom YOLO"):
            if custom_file is None:
                raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
            model = YOLO(str(custom_file))
        elif choice.startswith("Custom RF‑DETR"):
            if custom_file is None:
                raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
            # Custom RF-DETR checkpoints are always loaded into the Base variant.
            model = RFDETRBase(pretrain_weights=str(custom_file))
        else:
            raise ValueError(f"Unsupported model choice: {choice}")
    except Exception as exc:
        # Normalise every failure to RuntimeError so callers see one type while
        # the root cause stays reachable through exception chaining.
        raise RuntimeError(str(exc)) from exc

    _loaded[cache_key] = model
    return model
############################################################################### | |
# Inference helpers | |
############################################################################### | |
# Stroke width passed to the box annotator.
BOX_THICKNESS: int = 2
# Weight given to the annotated overlay when it is mixed back onto the image.
BOX_ALPHA: float = 0.6

# Shared annotator instances, created once at import time and reused per call.
label_annotator = sv.LabelAnnotator()
box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
def _blend(base: np.ndarray, overlay: np.ndarray, alpha: float = BOX_ALPHA) -> np.ndarray:
    """Mix *overlay* onto *base* via ``cv2.addWeighted``.

    The overlay is weighted by *alpha* and the base image by ``1 - alpha``.
    """
    base_weight = 1 - alpha
    mixed = cv2.addWeighted(overlay, alpha, base, base_weight, 0)
    return mixed
def _class_label(label_src, cid) -> str:
    """Return the display name for class id *cid*, or the raw id if unknown."""
    try:
        return str(label_src[cid])
    except (KeyError, IndexError, TypeError):
        # e.g. a fine-tuned checkpoint emitting ids absent from the label table.
        return str(cid)


def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
    """Run *model* on *image* and return the annotated image plus its latency.

    Args:
        model: An RF-DETR instance (``RFDETRBase`` / ``RFDETRLarge``); any
            other object is treated as an Ultralytics ``YOLO`` model.
        image: RGB input image.
        threshold: Minimum confidence for a detection to be kept. It is
            forwarded to both back-ends.

    Returns:
        ``(annotated_image, seconds)``, where ``seconds`` covers only the
        ``predict`` call (plus the Ultralytics -> supervision conversion),
        not the drawing.
    """
    start = time.perf_counter()
    if isinstance(model, (RFDETRBase, RFDETRLarge)):
        detections = model.predict(image, threshold=threshold)
        label_src = COCO_CLASSES
    else:
        # Forward the slider value as `conf`; without it Ultralytics uses its
        # own default and the UI threshold has no effect on YOLO models.
        ul_result = model.predict(image, conf=threshold, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(ul_result)
        label_src = model.names  # type: ignore
    runtime = time.perf_counter() - start

    labels = [
        f"{_class_label(label_src, cid)} {conf:.2f}"
        for cid, conf in zip(detections.class_id, detections.confidence)
    ]

    # Annotate a copy, then blend it with the untouched frame so the drawn
    # boxes and labels end up semi-transparent.
    base_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    overlay = base_bgr.copy()
    overlay = box_annotator.annotate(overlay, detections)
    overlay = label_annotator.annotate(overlay, detections, labels)
    blended = _blend(base_bgr, overlay)

    out_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
    return out_pil, runtime
############################################################################### | |
# Gradio callback | |
############################################################################### | |
def _failure_legend(exc: Exception) -> str:
    """Condense *exc* into the short status string shown in the legend."""
    emsg = str(exc)
    if "No such file" in emsg or "not found" in emsg:
        return "Unavailable (weights not found)"
    # Exceptions can carry an empty message; indexing `splitlines()[0]` would
    # then raise inside the error handler. Fall back to the exception type.
    first_line = next((ln for ln in emsg.splitlines() if ln.strip()), type(exc).__name__)
    return f"ERROR: {first_line[:120]}"


def compare_models(
    models: List[str],
    img: Image.Image,
    threshold: float,
    custom_file: Optional[Path],
):
    """Gradio callback: run every selected model on *img*.

    Args:
        models: Labels of the selected detectors (may be empty).
        img: Uploaded image; converted to RGB if it is in another mode.
        threshold: Confidence threshold forwarded to each detector.
        custom_file: Uploaded checkpoint for the "Custom ..." entries, if any.

    Returns:
        ``(images, legend)``: one image per selected model, in selection
        order, and a dict mapping each model label to its latency string or
        a short failure description. A failed model contributes a dark grey
        placeholder image so the gallery stays aligned with the selection.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if img.mode != "RGB":
        img = img.convert("RGB")

    results: List[Image.Image] = []
    legends: Dict[str, str] = {}
    for model_name in models or []:
        try:
            detector = load_model(model_name, custom_file)
            annotated, latency = run_single_inference(detector, img, threshold)
        except Exception as exc:
            # Per-model boundary: one broken model must not abort the others.
            results.append(Image.new("RGB", img.size, (40, 40, 40)))
            legends[model_name] = _failure_legend(exc)
        else:
            results.append(annotated)
            legends[model_name] = f"{latency*1000:.1f} ms"
    return results, legends
############################################################################### | |
# Gradio UI | |
############################################################################### | |
def build_demo():
    """Assemble the Gradio Blocks UI and wire the button to ``compare_models``.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation, so which components sit inside each ``gr.Row`` is a best
    # guess - confirm the intended layout.
    with gr.Blocks(title="CV Model Comparison") as demo:
        gr.Markdown(
            """# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin (2 px) and 60 % transparent for clarity."""
        )
        with gr.Row():
            sel_models = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Models")
            # `step` is passed by keyword: it is keyword-only in newer Gradio
            # releases, where the positional form raises TypeError.
            conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence")
        ckpt_file = gr.File(label="Custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
        img_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
        with gr.Row():
            gallery = gr.Gallery(label="Results", columns=2, height="auto")
        legend_out = gr.JSON(label="Latency / status by model")
        run_btn = gr.Button("Run Inference", variant="primary")
        # Input order must match the parameter order of `compare_models`.
        run_btn.click(
            fn=compare_models,
            inputs=[sel_models, img_in, conf_slider, ckpt_file],
            outputs=[gallery, legend_out],
        )
    return demo
if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    app = build_demo()
    app.launch()