| """ | |
| Gradio app to compare object‑detection models: | |
| • Ultralytics YOLOv12 (n, s, m, l, x) | |
| • Ultralytics YOLOv11 (n, s, m, l, x) | |
| • Roboflow RF‑DETR (Base, Large) | |
| • Custom fine‑tuned checkpoints (.pt/.pth upload) | |
| Revision 2025‑04‑19‑d: | |
| • Pre‑loads all selected models before running detections, with a visible progress bar. | |
| • Progress shows two phases: *Loading weights* and *Running inference*. | |
| • Keeps thin, semi‑transparent boxes and concise error labels. | |
| """ | |
from __future__ import annotations

import time
from typing import Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import supervision as sv
from PIL import Image
from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.util.coco_classes import COCO_CLASSES
from ultralytics import YOLO

###############################################################################
# Model registry & cache
###############################################################################

YOLO_MODEL_MAP: Dict[str, str] = {
    # NOTE: Ultralytics filenames do NOT include the "v" character.
    "YOLOv12-n": "yolo12n.pt",
    "YOLOv12-s": "yolo12s.pt",
    "YOLOv12-m": "yolo12m.pt",
    "YOLOv12-l": "yolo12l.pt",
    "YOLOv12-x": "yolo12x.pt",
    "YOLOv11-n": "yolo11n.pt",
    "YOLOv11-s": "yolo11s.pt",
    "YOLOv11-m": "yolo11m.pt",
    "YOLOv11-l": "yolo11l.pt",
    "YOLOv11-x": "yolo11x.pt",
}

RFDETR_MODEL_MAP: Dict[str, str] = {
    "RF-DETR-Base (29M)": "base",
    "RF-DETR-Large (128M)": "large",
}

ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
    "Custom YOLO (.pt/.pth)",
    "Custom RF-DETR (.pth)",
]

# Cache of instantiated detectors, keyed by model choice (plus checkpoint path
# for custom uploads; see load_model).
_loaded: Dict[str, object] = {}

def load_model(choice: str, custom_file: Optional[str] = None):
    """Fetch and cache a detector instance for *choice*."""
    # Custom checkpoints must be keyed by path as well; caching them under the
    # bare choice name would silently return a previously uploaded model.
    key = choice if custom_file is None else f"{choice}:{custom_file}"
    if key in _loaded:
        return _loaded[key]
    if choice in YOLO_MODEL_MAP:
        model = YOLO(YOLO_MODEL_MAP[choice])  # Ultralytics auto-downloads if missing
    elif choice in RFDETR_MODEL_MAP:
        model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
    elif choice.startswith("Custom YOLO"):
        if custom_file is None:
            raise RuntimeError("Upload a YOLO .pt/.pth checkpoint first.")
        model = YOLO(str(custom_file))
    elif choice.startswith("Custom RF-DETR"):
        if custom_file is None:
            raise RuntimeError("Upload an RF-DETR .pth checkpoint first.")
        model = RFDETRBase(pretrain_weights=str(custom_file))
    else:
        raise RuntimeError(f"Unsupported model choice: {choice}")
    _loaded[key] = model
    return model
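
# A minimal usage sketch (the model name is one of the registry keys above;
# the second call returns the cached instance rather than reloading weights):
#   detector = load_model("YOLOv12-n")   # first call downloads/loads weights
#   detector = load_model("YOLOv12-n")   # served from the _loaded cache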

###############################################################################
# Inference helpers
###############################################################################

BOX_THICKNESS = 2   # thinner boxes
BOX_ALPHA = 0.6     # 60 % opacity

box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
label_annotator = sv.LabelAnnotator()


def _blend(base: np.ndarray, overlay: np.ndarray, alpha: float = BOX_ALPHA) -> np.ndarray:
    return cv2.addWeighted(overlay, alpha, base, 1 - alpha, 0)
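
# cv2.addWeighted computes alpha * overlay + (1 - alpha) * base per pixel, so
# annotations drawn on the overlay copy appear at BOX_ALPHA (60 %) opacity,
# while unannotated pixels are identical in both inputs and stay unchanged.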

def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
    start = time.perf_counter()
    if isinstance(model, (RFDETRBase, RFDETRLarge)):
        detections = model.predict(image, threshold=threshold)
        label_src = COCO_CLASSES
    else:
        ul_res = model.predict(image, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(ul_res)
        label_src = model.names  # type: ignore[attr-defined]
    runtime = time.perf_counter() - start  # measures model latency only, not drawing

    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    overlay = img_bgr.copy()
    overlay = box_annotator.annotate(overlay, detections)
    overlay = label_annotator.annotate(
        overlay,
        detections,
        labels=[f"{label_src[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)],
    )
    blended = _blend(img_bgr, overlay)
    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
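
# Standalone sketch outside the UI (the file names here are hypothetical):
#   image = Image.open("street.jpg").convert("RGB")
#   annotated, seconds = run_single_inference(load_model("YOLOv11-n"), image, 0.5)
#   annotated.save("street_annotated.jpg")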

###############################################################################
# Gradio generator callback with progress
###############################################################################

def compare_models(
    models: List[str],
    img: Image.Image,
    threshold: float,
    custom_file: Optional[str],  # gr.File delivers a filepath string (or None)
    progress: gr.Progress = gr.Progress(),  # Gradio injects the tracker via this default arg
):
    if img is None:
        raise gr.Error("Please upload an image first.")
    if img.mode != "RGB":
        img = img.convert("RGB")

    total_steps = len(models) * 2  # phase 1: load, phase 2: inference

    # ----- Phase 1: preload weights -----
    detectors: Dict[str, object] = {}
    for i, name in enumerate(models, 1):
        try:
            detectors[name] = load_model(name, custom_file)
        except Exception as exc:
            detectors[name] = exc  # store the exception for later reporting
        progress(i, total=total_steps, desc=f"Loading {name}")

    # ----- Phase 2: run inference -----
    results: List[Image.Image] = []
    legends: Dict[str, str] = {}
    for j, name in enumerate(models, 1):
        detector_or_err = detectors[name]
        step_index = len(models) + j
        if isinstance(detector_or_err, Exception):
            # The model failed to load; show a placeholder tile instead.
            results.append(Image.new("RGB", img.size, (40, 40, 40)))
            emsg = str(detector_or_err)
            if "No such file" in emsg or "not found" in emsg:
                legends[name] = "Unavailable (weights not found)"
            else:
                legends[name] = f"ERROR: {emsg.splitlines()[0][:120]}"
            progress(step_index, total=total_steps, desc=f"Skipped {name}")
            continue
        try:
            annotated, latency = run_single_inference(detector_or_err, img, threshold)
            results.append(annotated)
            legends[name] = f"{latency * 1000:.1f} ms"
        except Exception as exc:
            results.append(Image.new("RGB", img.size, (40, 40, 40)))
            legends[name] = f"ERROR: {str(exc).splitlines()[0][:120]}"
        progress(step_index, total=total_steps, desc=f"Inference {name}")

    yield results, legends  # single terminal yield: gallery images + legend JSON
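
# compare_models yields once at the end; a plain return would also work, but
# keeping it a generator leaves room for streaming partial results per model
# later, while gr.Progress (the default argument above) reports the two phases.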

###############################################################################
# Gradio UI
###############################################################################

def build_demo():
    with gr.Blocks(title="CV Model Comparison") as demo:
        gr.Markdown(
            """# 🔍 Compare Object-Detection Models
Upload an image, select detectors, then click **Run Inference**.
Thin, semi-transparent boxes highlight detections."""
        )
        with gr.Row():
            sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12-n"], label="Models")
            # Slider's step (and later options) are keyword-only in Gradio.
            conf_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence")
            ckpt_file = gr.File(label="Custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
        img_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
        with gr.Row():
            gallery = gr.Gallery(label="Results", columns=2, height="auto")
            legend_out = gr.JSON(label="Latency / status by model")
        run_btn = gr.Button("Run Inference", variant="primary")
        run_btn.click(compare_models, [sel_models, img_in, conf_slider, ckpt_file], [gallery, legend_out])
    return demo

if __name__ == "__main__":
    build_demo().launch()
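
# Deployment note: recent Gradio versions enable the request queue by default,
# so plain launch() works for this generator callback. On older Gradio 3.x the
# queue had to be enabled explicitly:
#   build_demo().queue().launch()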