evaltest / app.py
wuhp's picture
Update app.py
22b976e verified
raw
history blame
8.3 kB
"""
app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator (v3)
─────────────────────────────────────────────────────────────
Changelog (2025‑04‑17 b)
• **Cleanlab** integration → extra *label‑issue* metric (skips gracefully if lib missing).
• New **BBox validity** check: flags coords outside [0, 1].
• Weight table updated (Integrity 25%, Model 20%, Cleanlab 10%, etc.).
• Minor: switched to cached NumPy reader for labels, clarified env vars.
"""
from __future__ import annotations

import imghdr
import json
import os
import re
import shutil
import tempfile
import warnings
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import yaml
from PIL import Image
from tqdm import tqdm
# ───────────────────────────── Optional heavy deps (fail‑soft) ──
try:
import cv2 # type: ignore
except ImportError:
cv2 = None
try:
import imagehash # type: ignore
except ImportError:
imagehash = None
try:
import fastdup # type: ignore
except ImportError:
fastdup = None
try:
from cleanlab.object_detection import find_label_issues # type: ignore
except (ImportError, AttributeError):
find_label_issues = None # type: ignore
try:
from ultralytics import YOLO # type: ignore
except ImportError:
YOLO = None # noqa: N806
try:
from roboflow import Roboflow # type: ignore
except ImportError:
Roboflow = None # type: ignore
# ───────────────────────────────────────── Config & constants ──
# Scratch directory for dataset working files; created eagerly at import time.
TMP_ROOT = Path(tempfile.gettempdir()).joinpath("rf_datasets")
TMP_ROOT.mkdir(exist_ok=True, parents=True)

# Worker-process count for the parallel checks. Override with QC_CPU; the
# default is half the logical CPUs (4 assumed if unknown), never below 1.
CPU_COUNT = int(os.environ.get("QC_CPU", max((os.cpu_count() or 4) // 2, 1)))

# Batch size, overridable with QC_BATCH (consumer not visible in this view).
BATCH = int(os.environ.get("QC_BATCH", 16))

# Relative weight of each check in the overall score; the weights sum to 1.0.
# Keys mirror the "name" field returned by the qc_* check functions.
DEFAULT_W = dict(
    [
        ("Integrity", 0.25),
        ("Class balance", 0.15),
        ("Image quality", 0.15),
        ("Duplicates", 0.10),
        ("BBox validity", 0.05),
        ("Model QA", 0.20),
        ("Cleanlab QA", 0.10),
    ]
)
@dataclass()
class DuplicateGroup:
    """A group of image files that share a single hash value."""

    # Hash string common to every member of the group.
    hash_val: str
    # Image files that produced ``hash_val``.
    paths: List[Path]
# ───────────────────────────────────────── Generic helpers ─────
def load_yaml(path: Path) -> Dict:
    """Parse the UTF-8 YAML file at *path* with ``yaml.safe_load``.

    Returns whatever ``safe_load`` produces for the document.
    """
    stream = path.open(encoding="utf-8")
    with stream:
        return yaml.safe_load(stream)
# Per-process memo of parsed label files, keyed by path (never evicted).
_label_cache: dict[Path, np.ndarray] = {}


def load_labels_np(path: Path) -> np.ndarray:
    """Load a whitespace-separated label file as a 2-D float array (memoised).

    Each text line becomes one row. Missing, unreadable, or malformed files
    (e.g. rows with differing column counts) yield an empty ``(0, 5)`` array.
    Files with no data rows (e.g. labels for background images) also yield
    ``(0, 5)``.

    The cached array is returned to every caller for the same *path*, so
    callers must not mutate it.
    """
    cached = _label_cache.get(path)
    if cached is not None:
        return cached
    try:
        with warnings.catch_warnings():
            # np.loadtxt warns for inputs without data rows; those are expected
            # here (empty label files) and would otherwise spam one warning each.
            warnings.simplefilter("ignore", UserWarning)
            arr = np.loadtxt(path, dtype=float)
    except Exception:
        # Best effort: anything unparsable counts as "no labels".
        arr = np.empty((0, 5))
    else:
        if arr.size == 0:
            # Empty input comes back as a size-0 1-D array; reshape(1, -1) would
            # turn it into a phantom row of shape (1, 0).
            arr = np.empty((0, 5))
        elif arr.ndim < 2:
            # A single line (or single value) is squeezed by loadtxt; restore
            # the one-row 2-D shape.
            arr = arr.reshape(1, -1)
    _label_cache[path] = arr
    return arr
def guess_image_dirs(root: Path) -> List[Path]:
    """Return the conventional YOLO image directories that exist under *root*.

    Candidates, in order: ``images/``, then ``<split>/images/`` for the
    splits ``train``, ``valid``, ``val`` and ``test``.
    """
    candidates = [root / "images"]
    candidates.extend(root / split / "images" for split in ("train", "valid", "val", "test"))
    found: List[Path] = []
    for directory in candidates:
        if directory.exists():
            found.append(directory)
    return found
def _find_label(img: Path, img_dir: Path, lbl_dir: Path) -> Path | None:
    """Return the label file for *img* from its own split's ``labels`` dir, or None."""
    rel = img.relative_to(img_dir).with_suffix(".txt")
    # Prefer the mirrored sub-path, then fall back to a flat ``<stem>.txt``.
    for candidate in (lbl_dir / rel, lbl_dir / f"{img.stem}.txt"):
        if candidate.exists():
            return candidate
    return None


def gather_dataset(root: Path, yaml_path: Path | None = None):
    """Collect images, their label files and the parsed dataset YAML under *root*.

    Args:
        root: Dataset root directory.
        yaml_path: Explicit dataset YAML. When omitted, the first ``*.yaml``
            file (then ``*.yml``) directly under *root*, in sorted order, is used.

    Returns:
        ``(imgs, lbls, meta)``. ``imgs`` lists every file under the image
        directories that ``imghdr`` recognises as an image. ``lbls[i]`` is the
        label file for ``imgs[i]``, looked up only in the ``labels`` directory
        beside that image's own ``images`` directory, or ``None`` if absent.
        ``meta`` is the parsed YAML.

    Raises:
        FileNotFoundError: If no YAML or no image directory is found.
    """
    if yaml_path is None:
        # Sorted so the choice does not depend on filesystem listing order.
        yamls = sorted(root.glob("*.yaml")) + sorted(root.glob("*.yml"))
        if not yamls:
            raise FileNotFoundError("Dataset YAML not found")
        yaml_path = yamls[0]
    meta = load_yaml(yaml_path)
    img_dirs = guess_image_dirs(root)
    if not img_dirs:
        raise FileNotFoundError("images/ directory hierarchy missing")
    imgs = []
    lbls = []
    for img_dir in img_dirs:
        # Each split's labels live next to its images; never borrow another
        # split's label that merely shares the same file stem.
        lbl_dir = img_dir.parent / "labels"
        for p in sorted(img_dir.rglob("*.*")):
            # rglob also yields directories whose names contain a dot.
            if not p.is_file() or imghdr.what(p) is None:
                continue
            imgs.append(p)
            lbls.append(_find_label(p, img_dir, lbl_dir))
    return imgs, lbls, meta
# ───────────────────────────────────────── Quality checks ─────
# Integrity -----------------------------------------------------
def _is_corrupt(path: Path) -> bool:
    """Tell whether Pillow fails to open or verify the image at *path*.

    Any exception raised while opening, verifying or closing the file is
    treated as corruption.
    """
    try:
        with Image.open(path) as img:
            img.verify()
    except Exception:
        return True
    else:
        return False
def qc_integrity(imgs: List[Path], lbls: List[Path | None]):
    """Score dataset integrity from missing label files and undecodable images.

    Args:
        imgs: Image paths to check.
        lbls: Label paths aligned index-by-index with *imgs*; ``None`` marks an
            image without a label file.

    Returns:
        Check-result dict with ``name``, ``score`` (0-100: percentage of images
        that are neither unlabeled nor corrupt) and ``details`` listing the
        offending paths in input order.
    """
    miss_lbl = [img for img, lbl in zip(imgs, lbls) if lbl is None]
    with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
        # map() yields results in input order, so the report is deterministic
        # (as_completed() ordered it by finish time). Chunking cuts per-image
        # inter-process overhead on large datasets.
        flags = list(
            tqdm(
                ex.map(_is_corrupt, imgs, chunksize=max(1, len(imgs) // (CPU_COUNT * 4))),
                total=len(imgs),
                desc="integrity",
                leave=False,
            )
        )
    corrupt = [img for img, is_bad in zip(imgs, flags) if is_bad]
    # An image that is both unlabeled and corrupt is counted once.
    n_bad = len(set(miss_lbl).union(corrupt))
    score = 100 - n_bad / max(len(imgs), 1) * 100
    return {
        "name": "Integrity",
        "score": max(score, 0),
        "details": {
            "missing_label_files": [str(p) for p in miss_lbl],
            "corrupt_images": [str(p) for p in corrupt],
        },
    }
# Class balance -------------------------------------------------
def qc_class_balance(lbls: List[Path]):
cls_counts = Counter()
boxes_per_img = []
for l in lbls:
arr = load_labels_np(l) if l else np.empty((0, 5))
boxes_per_img.append(len(arr))
cls_counts.update(arr[:, 0].astype(int) if arr.size else [])
if not cls_counts:
return {"name": "Class balance", "score": 0, "details": "No labels"}
bal = (min(cls_counts.values()) / max(cls_counts.values())) * 100
return {
"name": "Class balance",
"score": bal,
"details": {
"class_counts": dict(cls_counts),
"boxes_per_image": {
"min": int(np.min(boxes_per_img)),
"max": int(np.max(boxes_per_img)),
"mean": float(np.mean(boxes_per_img)),
},
},
}
# Image quality -------------------------------------------------
def _quality_stat(path: Path, blur_thr: float):
    """Compute blur / darkness / brightness flags for a single image.

    Returns ``(path, is_blurry, is_dark, is_bright)``. All flags are False when
    cv2 is unavailable or the image cannot be read. An image is blurry when the
    variance of its Laplacian is below *blur_thr*. It is dark or bright when its
    mean grayscale level is below 25 or above 230, respectively.
    """
    if not cv2:
        return path, False, False, False
    bgr = cv2.imread(str(path))
    if bgr is None:
        return path, False, False, False
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    mean_level = gray.mean()
    sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
    return path, sharpness < blur_thr, mean_level < 25, mean_level > 230
def qc_image_quality(imgs: List[Path], blur_thr: float = 100.0):
    """Score image quality as the share of images not flagged blurry/dark/bright.

    Args:
        imgs: Image paths to analyse.
        blur_thr: Laplacian-variance threshold below which an image is blurry.

    Returns:
        Check-result dict with per-category path lists in ``details``. When cv2
        is not installed, the check is skipped and the score is 100.
    """
    if cv2 is None:
        return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
    blurry, dark, bright = [], [], []
    with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
        # Pass blur_thr as a second iterable rather than wrapping _quality_stat
        # in a lambda: the executor must pickle the callable to send it to the
        # worker processes, and lambdas cannot be pickled.
        results = ex.map(
            _quality_stat,
            imgs,
            [blur_thr] * len(imgs),
            chunksize=max(1, len(imgs) // (CPU_COUNT * 4)),
        )
        for p, is_blur, is_dark, is_bright in tqdm(
            results,
            total=len(imgs),
            desc="img‑quality",
            leave=False,
        ):
            if is_blur:
                blurry.append(p)
            if is_dark:
                dark.append(p)
            if is_bright:
                bright.append(p)
    # An image flagged in several categories counts once.
    bad = len(set(blurry + dark + bright))
    score = 100 - bad / max(len(imgs), 1) * 100
    return {
        "name": "Image quality",
        "score": score,
        "details": {
            "blurry": [str(p) for p in blurry],
            "dark": [str(p) for p in dark],
            "bright": [str(p) for p in bright],
        },
    }
# Duplicate images ---------------------------------------------
def qc_duplicates(imgs: List[Path]):
if fastdup is not None and len(imgs) > 50:
try:
fd = fastdup.create(input_dir=str(Path(imgs[0]).parent.parent), work_dir=str(TMP_ROOT / "fastdup"))
fd.run()
clusters = fd.get_clusters()
dup = sum(len(c) - 1 for c in clusters)
score = 100 - dup / max(len(imgs), 1) * 100
return {"name": "Duplicates", "score": score, "details": {"groups": clusters[:50]}}
except Exception:
pass
if imagehash is None:
return {"name": "Duplicates", "score": 100, "details": "skipped (deps)"}
def _hash(p):
return str(imagehash.average_hash(Image.open(p)))
hashes: Dict[str, List[Path]] = defaultdict(list)
with ProcessPoolExecutor(max_workers=CPU