evaltest / app.py
wuhp's picture
Update app.py
e09a48c verified
raw
history blame
16.1 kB
"""
app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator (v2)
Changelog (2025‑04‑17)
──────────────────────
β€’ **CPU‑bound loops parallelised** with `concurrent.futures.ProcessPoolExecutor`.
β€’ **Batch inference** in `qc_model_qa()` (GPU util ↑, latency ↓).
β€’ Optional **fastdup** path for duplicate detection (β‰ˆβ€―10Γ— faster on large sets).
β€’ Faster NumPy‑based `parse_label_file()`.
β€’ Small refactors β†’ clearer separation of stages & fewer globals.
β€’ Graceful degradation if heavy deps unavailable (cv2, imagehash, fastdup).
β€’ Tunable `CPU_COUNT` + env‑var guard for HF Spaces quota.
"""
from __future__ import annotations
import imghdr
import json
import os
import re
import shutil
import tempfile
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import gradio as gr
import numpy as np
import pandas as pd
import yaml
from PIL import Image
from tqdm import tqdm
# ───────────────────────────────────────── Heavy optional deps ──
try:
import cv2 # type: ignore
except ImportError:
cv2 = None
try:
import imagehash # type: ignore
except ImportError:
imagehash = None
try:
import fastdup # type: ignore
except ImportError:
fastdup = None
try:
from ultralytics import YOLO # type: ignore
except ImportError:
YOLO = None # noqa: N806
try:
from roboflow import Roboflow # type: ignore
except ImportError:
Roboflow = None # type: ignore
# ───────────────────────────────────────── Config & constants ──
# Scratch area for downloaded Roboflow datasets (and fastdup work files).
TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
TMP_ROOT.mkdir(parents=True, exist_ok=True)
# Limit CPU workers on HF Spaces (feel free to raise locally).
# Both knobs are clamped to >= 1: ProcessPoolExecutor rejects max_workers=0 and
# the batching loop in qc_model_qa would call range(..., step=0), which raises.
CPU_COUNT = max(1, int(os.getenv("QC_CPU", max(1, (os.cpu_count() or 4) // 2))))
BATCH = max(1, int(os.getenv("QC_BATCH", 16)))
# Per-check weights used by `aggregate()`. They sum to 1.0, so a weighted sum of
# 0-100 check scores is itself on a 0-100 scale.
DEFAULT_W = {
    "Integrity": 0.30,
    "Class balance": 0.15,
    "Image quality": 0.15,
    "Duplicates": 0.10,
    "Model QA": 0.30,
}
@dataclass
class DuplicateGroup:
    """A bundle of image files that share one hash value.

    NOTE(review): not referenced elsewhere in this file — ``qc_duplicates``
    reports its groups as plain lists of path strings. Presumably intended to
    hold the result of the duplicate check; confirm before relying on it.
    """

    hash_val: str  # the hash string every member of the group produced
    paths: List[Path]  # image files that share `hash_val`
# ───────────────────────────────────────── Generic helpers ─────
def load_yaml(path: Path) -> Dict:
    """Parse the UTF-8 YAML file at *path* with ``yaml.safe_load``.

    ``safe_load`` returns ``None`` for an empty or comment-only document. That
    case is normalised to ``{}``, so callers can use ``.get`` on the result
    without a separate None check.
    """
    with path.open(encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def parse_label_file(path: Path) -> list[tuple[int, float, float, float, float]]:
if not path.exists() or path.stat().st_size == 0:
return []
try:
arr = np.loadtxt(path, dtype=float)
if arr.ndim == 1:
arr = arr.reshape(1, -1)
return [tuple(row) for row in arr]
except Exception:
return []
def guess_image_dirs(root: Path) -> List[Path]:
    """List the conventional image folders that exist below *root*.

    The candidates are checked in a fixed order: a top-level ``images/``
    folder first, then ``images/`` inside the ``train``, ``valid``, ``val`` and
    ``test`` split folders.
    """
    relative = [Path("images")] + [Path(split, "images") for split in ("train", "valid", "val", "test")]
    found: List[Path] = []
    for rel in relative:
        folder = root / rel
        if folder.exists():
            found.append(folder)
    return found
def _find_dataset_yaml(root: Path) -> Path:
    """Return the dataset YAML directly inside *root*, preferring ``data.yaml``.

    Without a ``data.yaml``/``data.yml`` the alphabetically first ``*.yaml``
    (then ``*.yml``) file is used, so the choice is deterministic.

    Raises:
        FileNotFoundError: if *root* contains no YAML file.
    """
    candidates = sorted(root.glob("*.yaml")) + sorted(root.glob("*.yml"))
    if not candidates:
        raise FileNotFoundError("Dataset YAML not found")
    return next((c for c in candidates if c.name in ("data.yaml", "data.yml")), candidates[0])


def _label_for(img: Path, img_dir: Path) -> Path | None:
    """Return the label file mirroring *img* under the sibling ``labels/`` folder, or None.

    Example: ``train/images/sub/x.jpg`` -> ``train/labels/sub/x.txt``.
    """
    label = img_dir.parent / "labels" / img.relative_to(img_dir).with_suffix(".txt")
    return label if label.exists() else None


def gather_dataset(root: Path, yaml_path: Path | None = None):
    """Collect image paths, their label paths and the dataset metadata.

    Args:
        root: dataset root folder.
        yaml_path: explicit dataset YAML; when None, one is looked up in *root*.

    Returns:
        ``(imgs, lbls, meta)``:
        - image paths, sorted within each image folder;
        - the label path for each image, or ``None`` when it is missing;
        - the parsed YAML.

    Raises:
        FileNotFoundError: if no YAML or no image folder is found.
    """
    if yaml_path is None:
        yaml_path = _find_dataset_yaml(root)
    meta = load_yaml(yaml_path)
    img_dirs = guess_image_dirs(root)
    if not img_dirs:
        raise FileNotFoundError("images/ directory hierarchy missing")
    imgs: List[Path] = []
    lbls: List[Path | None] = []
    for img_dir in img_dirs:
        for p in sorted(img_dir.rglob("*.*")):
            # is_file(): imghdr.what() raises on directories whose name contains a dot.
            # NOTE(review): imghdr is deprecated and removed in Python 3.13.
            if p.is_file() and imghdr.what(p) is not None:
                imgs.append(p)
                # Only this folder's own labels/ tree is consulted, so a same-named
                # label in another split can no longer mask a missing label.
                lbls.append(_label_for(p, img_dir))
    return imgs, lbls, meta
# ───────────────────────────────────────── Quality checks ─────
# Integrity -----------------------------------------------------
def _is_corrupt(path: Path) -> bool:
    """Report whether PIL fails to open *path* or its ``verify()`` check fails."""
    try:
        with Image.open(path) as handle:
            handle.verify()
    except Exception:
        # Any failure (unreadable file, unknown format, structural error) counts as corrupt.
        return True
    return False
def qc_integrity(imgs: List[Path], lbls: List[Path | None]):
    """Score the share of images that open cleanly and have a label file.

    Args:
        imgs: image paths.
        lbls: label path per image, in the same order as *imgs*; ``None``
            marks a missing label file.

    Returns:
        Check-result dict. ``score`` is the percentage of images that are
        neither corrupt nor unlabeled; ``details`` lists the offenders.
    """
    miss_lbl = [img for img, lbl in zip(imgs, lbls) if lbl is None]
    with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
        # ex.map yields results in input order, so the report is deterministic.
        flags = list(tqdm(ex.map(_is_corrupt, imgs), total=len(imgs), desc="integrity", leave=False))
    corrupt = [img for img, is_bad in zip(imgs, flags) if is_bad]
    # An image can be both corrupt and unlabeled; count it once so the penalty is per image.
    bad = len(set(miss_lbl) | set(corrupt))
    score = 100 - bad / max(len(imgs), 1) * 100
    return {
        "name": "Integrity",
        "score": max(score, 0),
        "details": {
            "missing_label_files": [str(p) for p in miss_lbl],
            "corrupt_images": [str(p) for p in corrupt],
        },
    }
# Class balance -------------------------------------------------
def qc_class_balance(lbls: List[Path]):
    """Measure how evenly box annotations are spread across classes.

    The score is ``min(count) / max(count) * 100`` over the classes that occur
    at least once, so 100 means perfectly balanced. Falsy entries in *lbls*
    (no label file) count as images with zero boxes. With no boxes at all,
    the score is 0 and ``details`` is the string ``"No labels"``.
    """
    parsed = [parse_label_file(lbl) if lbl else [] for lbl in lbls]
    per_class = Counter(box[0] for boxes in parsed for box in boxes)
    if not per_class:
        return {"name": "Class balance", "score": 0, "details": "No labels"}
    box_totals = [len(boxes) for boxes in parsed]
    rarest = min(per_class.values())
    commonest = max(per_class.values())
    return {
        "name": "Class balance",
        "score": (rarest / commonest) * 100,
        "details": {
            "class_counts": dict(per_class),
            "boxes_per_image": {
                "min": int(np.min(box_totals)),
                "max": int(np.max(box_totals)),
                "mean": float(np.mean(box_totals)),
            },
        },
    }
# Image quality -------------------------------------------------
def _quality_stat(path: Path, blur_thr: float):
    """Flag one image as blurry, too dark and/or too bright.

    Returns ``(path, is_blurry, is_dark, is_bright)``:
    - blurry means the variance of the grayscale Laplacian is below *blur_thr*;
    - dark/bright means the mean gray level is below 25 or above 230.

    When OpenCV is unavailable or the file cannot be decoded, all three flags
    are False.
    """
    image = None if not cv2 else cv2.imread(str(path))
    if image is None:
        return path, False, False, False
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
    mean_level = gray.mean()
    return path, sharpness < blur_thr, mean_level < 25, mean_level > 230
def qc_image_quality(imgs: List[Path], blur_thr: float = 100.0):
    """Flag blurry, very dark and very bright images using OpenCV.

    Args:
        imgs: image paths to inspect.
        blur_thr: Laplacian-variance threshold below which an image is blurry.

    Returns:
        Check-result dict. ``score`` is the percentage of images with none of
        the three problems. Returns a neutral 100 when OpenCV is not installed.
    """
    if cv2 is None:
        return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
    blurry: list[Path] = []
    dark: list[Path] = []
    bright: list[Path] = []
    with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
        # Work items are pickled to worker processes, so the callable must be a
        # module-level function: the previous lambda raised PicklingError.
        # The threshold is passed as a second iterable instead of being closed over.
        stats = ex.map(_quality_stat, imgs, [blur_thr] * len(imgs))
        for p, is_blur, is_dark, is_bright in tqdm(
            stats,
            total=len(imgs),
            desc="img‑quality",
            leave=False,
        ):
            if is_blur:
                blurry.append(p)
            if is_dark:
                dark.append(p)
            if is_bright:
                bright.append(p)
    # An image with several problems is only counted once.
    bad = len(set(blurry) | set(dark) | set(bright))
    score = 100 - bad / max(len(imgs), 1) * 100
    return {
        "name": "Image quality",
        "score": score,
        "details": {
            "blurry": [str(p) for p in blurry],
            "dark": [str(p) for p in dark],
            "bright": [str(p) for p in bright],
        },
    }
# Duplicate images ---------------------------------------------
def _average_hash(path: Path) -> str | None:
    """Return ``str(imagehash.average_hash(...))`` for *path*, or None if it cannot be read.

    Defined at module level (not nested inside ``qc_duplicates``) because
    ``ProcessPoolExecutor`` must pickle the callable it sends to workers.
    The image file is closed before returning.
    """
    try:
        with Image.open(path) as im:
            return str(imagehash.average_hash(im))
    except Exception:
        # One unreadable image must not abort the whole duplicate check.
        return None


def qc_duplicates(imgs: List[Path]):
    """Estimate the share of duplicate images.

    Uses fastdup when it is installed and there are more than 50 images.
    Otherwise, or when fastdup fails, images with an identical imagehash
    average hash are grouped together.

    The score is ``100 - duplicates / len(imgs) * 100``, where a group of *k*
    identical images contributes *k - 1* duplicates. ``details["groups"]``
    holds at most 50 groups of path strings.
    """
    # Fast path – use fastdup if installed & enough images
    if fastdup is not None and len(imgs) > 50:
        try:
            # Scan the deepest folder that contains *every* image so all splits are
            # covered (the first image's grand-parent only covered one split).
            input_dir = os.path.commonpath([str(p) for p in imgs])
            fd = fastdup.create(input_dir=input_dir, work_dir=str(TMP_ROOT / "fastdup"))
            fd.run()
            # NOTE(review): confirm `get_clusters()` exists in the installed fastdup
            # version; any exception here falls back to hashing below.
            clusters = fd.get_clusters()
            # Stringify members so the report stays JSON-serialisable.
            groups = [[str(x) for x in c] for c in clusters]
            dup = sum(len(g) - 1 for g in groups)
            score = 100 - dup / max(len(imgs), 1) * 100
            return {
                "name": "Duplicates",
                "score": score,
                "details": {"groups": groups[:50]},
            }
        except Exception as e:
            import warnings

            warnings.warn(f"fastdup failed ({e!r}); falling back to average hash", RuntimeWarning)
    if imagehash is None:
        return {"name": "Duplicates", "score": 100, "details": "skipped (deps)"}
    hashes: Dict[str, List[Path]] = defaultdict(list)
    with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
        for h, p in tqdm(
            zip(ex.map(_average_hash, imgs), imgs),
            total=len(imgs),
            desc="hashing",
            leave=False,
        ):
            if h is not None:
                hashes[h].append(p)
    groups = [g for g in hashes.values() if len(g) > 1]
    dup = sum(len(g) - 1 for g in groups)
    score = 100 - dup / max(len(imgs), 1) * 100
    return {
        "name": "Duplicates",
        "score": score,
        "details": {"groups": [[str(p) for p in g] for g in groups[:50]]},
    }
# Model‑assisted QA --------------------------------------------
def _rel_iou(b1, b2):
x1, y1, w1, h1 = b1
x2, y2, w2, h2 = b2
xa1, ya1, xa2, ya2 = x1 - w1 / 2, y1 - h1 / 2, x1 + w1 / 2, y1 + h1 / 2
xb1, yb1, xb2, yb2 = x2 - w2 / 2, y2 - h2 / 2, x2 + w2 / 2, y2 + h2 / 2
ix1, iy1, ix2, iy2 = max(xa1, xb1), max(ya1, yb1), min(xa2, xb2), min(ya2, yb2)
inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
union = w1 * h1 + w2 * h2 - inter
return inter / union if union else 0.0
def qc_model_qa(imgs: List[Path], weights: str | None, lbls: List[Path], iou_thr: float = 0.5):
if weights is None or YOLO is None:
return {"name": "Model QA", "score": 100, "details": "skipped (no weights)"}
model = YOLO(weights)
ious, mism = [], []
for i in range(0, len(imgs), BATCH):
batch_paths = imgs[i : i + BATCH]
results = model.predict(batch_paths, verbose=False)
for p, res in zip(batch_paths, results):
gtb = parse_label_file(p.parent.parent / "labels" / f"{p.stem}.txt")
if not gtb:
continue
for cls, x, y, w, h in gtb:
best = 0.0
for b, c in zip(res.boxes.xywh.cpu().numpy(), res.boxes.cls.cpu().numpy()):
if int(c) != cls:
continue
best = max(best, _rel_iou((x, y, w, h), tuple(b)))
ious.append(best)
if best < iou_thr:
mism.append(str(p))
miou = float(np.mean(ious)) if ious else 1.0
return {
"name": "Model QA",
"score": miou * 100,
"details": {"mean_iou": miou, "mismatched_images": mism[:50]},
}
# Aggregate -----------------------------------------------------
def aggregate(scores):
return sum(DEFAULT_W.get(r["name"], 0) * r["score"] for r in scores)
# ───────────────────────────────────────── Roboflow helpers ────
# Captures (workspace, project, version) from a Roboflow Universe dataset URL prefix.
RF_RE = re.compile(r"https://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")


def download_rf_dataset(url: str, rf_api: "Roboflow", dest: Path) -> Path:
    """Fetch a Roboflow Universe dataset version in ``yolov8`` format.

    *url* must start with ``https://universe.roboflow.com/<ws>/<project>/dataset/<n>``
    (surrounding whitespace is ignored). The data is placed in
    ``dest/<ws>_<project>_v<n>``; if that folder already exists, it is reused
    and nothing is downloaded.

    Raises:
        ValueError: if *url* does not match ``RF_RE``.
    """
    match = RF_RE.match(url.strip())
    if match is None:
        raise ValueError(f"Bad RF URL: {url}")
    workspace, project_slug, version = match.groups()
    target = dest / f"{workspace}_{project_slug}_v{version}"
    if not target.exists():
        project = rf_api.workspace(workspace).project(project_slug)
        project.version(int(version)).download("yolov8", location=str(target))
    return target
# ───────────────────────────────────────── Main logic ──────────
def run_quality(root: Path, yaml_override: Path | None, weights: Path | None):
    """Run every quality check on one dataset and render the results.

    Args:
        root: dataset root folder.
        yaml_override: explicit dataset YAML, or None to look one up in *root*.
        weights: optional YOLO weights for the model-assisted check.

    Returns:
        ``(markdown_report, class_counts_df)``. The dataframe has a single
        ``count`` column indexed by class and is empty when no labels were found.
    """
    imgs, lbls, meta = gather_dataset(root, yaml_override)
    res = [
        qc_integrity(imgs, lbls),
        qc_class_balance(lbls),
        qc_image_quality(imgs),
        qc_duplicates(imgs),
        qc_model_qa(imgs, str(weights) if weights else None, lbls),
    ]
    final = aggregate(res)
    # `meta` comes from a user-supplied YAML and is not guaranteed to be a mapping.
    title = meta.get("name", root.name) if isinstance(meta, dict) else root.name
    md = [f"## **{title}** — Score {final:.1f}/100"]
    for r in res:
        # Non-breaking spaces keep the gap between name and score in rendered markdown.
        md.append(f"### {r['name']}\u00a0\u00a0{r['score']:.1f}")
        md.append("<details><summary>details</summary>\n\n```json")
        # Display-only JSON: stringify anything not natively serialisable.
        md.append(json.dumps(r["details"], indent=2, default=str))
        md.append("```\n</details>\n")
    md_str = "\n".join(md)
    # Class balance reports a plain string ("No labels") instead of a dict when empty.
    balance_details = res[1]["details"]
    cls_counts = balance_details.get("class_counts", {}) if isinstance(balance_details, dict) else {}
    df = pd.DataFrame.from_dict(cls_counts, orient="index", columns=["count"])
    df.index.name = "class"
    return md_str, df
# ───────────────────────────────────────── Gradio UI ───────────
def _upload_path(upload) -> Path | None:
    """Return the local filesystem path of a Gradio upload, or None when absent.

    Accepts both objects exposing ``.name`` (tempfile-style uploads) and plain
    path strings. NOTE(review): which form ``gr.File`` passes depends on the
    installed Gradio version.
    """
    if not upload:
        return None
    return Path(getattr(upload, "name", upload))


def evaluate(
    api_key: str,
    url_txt: gr.File | None,
    zip_file: gr.File | None,
    server_path: str,
    yaml_file: gr.File | None,
    weights: gr.File | None,
):
    """Gradio callback: evaluate every dataset source the user supplied.

    Sources are processed in order: the Roboflow URL list, then the uploaded
    archive, then the server path. A failure in one dataset is reported inline
    in the markdown instead of aborting the whole run.

    Returns:
        ``(markdown_summary, combined_class_counts_df)``.
    """
    if not any([url_txt, zip_file, server_path]):
        return "Upload a .txt of URLs or dataset ZIP/path", pd.DataFrame()
    yaml_path = _upload_path(yaml_file)
    weights_path = _upload_path(weights)
    reports: list[str] = []
    dfs: list[pd.DataFrame] = []

    def run_one(label: str, root: Path, yaml_override: Path | None) -> None:
        # Evaluate a single dataset; a failure becomes a warning section, not an exception.
        try:
            md, df = run_quality(root, yaml_override, weights_path)
        except Exception as e:
            reports.append(f"### {label}\n\n⚠️ {e}")
        else:
            reports.append(md)
            dfs.append(df)

    # Roboflow batch ------------------------------------------
    if url_txt:
        if Roboflow is None:
            return "`roboflow` not installed", pd.DataFrame()
        if not api_key:
            return "Enter Roboflow API key", pd.DataFrame()
        rf = Roboflow(api_key=api_key.strip())
        # utf-8-sig drops a leading BOM, which would otherwise break the first URL.
        for line in _upload_path(url_txt).read_text(encoding="utf-8-sig").splitlines():
            url = line.strip()
            if not url:
                continue
            try:
                ds_root = download_rf_dataset(url, rf, TMP_ROOT)
            except Exception as e:
                reports.append(f"### {url}\n\n⚠️ {e}")
                continue
            run_one(url, ds_root, None)
    # Manual ZIP ----------------------------------------------
    if zip_file:
        archive = _upload_path(zip_file)
        tmp_dir = Path(tempfile.mkdtemp())
        try:
            # NOTE(review): the shutil docs warn against unpacking untrusted archives;
            # for tar uploads consider format="zip" or the `filter=` argument (3.12+).
            shutil.unpack_archive(str(archive), tmp_dir)
        except Exception as e:
            reports.append(f"### {archive.name}\n\n⚠️ {e}")
        else:
            run_one(archive.name, tmp_dir, yaml_path)
        finally:
            # Always remove the extracted copy, even when evaluation failed.
            shutil.rmtree(tmp_dir, ignore_errors=True)
    # Manual path ---------------------------------------------
    if server_path:
        run_one(server_path, Path(server_path), yaml_path)
    summary_md = "\n\n---\n\n".join(reports)
    combined_df = pd.concat(dfs).groupby(level=0).sum() if dfs else pd.DataFrame()
    return summary_md, combined_df
# ───────────────────────────────────────── Launch ────────────
# Build the Gradio UI; `evaluate` receives the six inputs in the order given to `click`.
with gr.Blocks(title="YOLO Dataset Quality Evaluator") as demo:
    gr.Markdown(
        """
# YOLOv8 Dataset Quality Evaluator
### Roboflow batch
1. Paste your **Roboflow API key**
2. Upload a **.txt** file – one `https://universe.roboflow.com/.../dataset/x` per line
### Manual
* Upload a dataset **ZIP** or type a dataset **path** on the server
* Optionally supply a custom **data.yaml** and/or a **YOLO .pt** weights file for model‑assisted QA
"""
    )
    with gr.Row():
        api_in = gr.Textbox(label="Roboflow API key", type="password", placeholder="rf_XXXXXXXXXXXXXXXX")
        url_txt_in = gr.File(label=".txt of RF dataset URLs", file_types=[".txt"])
    with gr.Row():
        zip_in = gr.File(label="Dataset ZIP")
        path_in = gr.Textbox(label="Path on server", placeholder="/data/my_dataset")
    with gr.Row():
        # Accept both common YAML extensions.
        yaml_in = gr.File(label="Custom YAML", file_types=[".yaml", ".yml"])
        weights_in = gr.File(label="YOLO weights (.pt)")
    run_btn = gr.Button("Evaluate")
    out_md = gr.Markdown()
    out_df = gr.Dataframe()
    # Input order must match the positional parameters of `evaluate`.
    run_btn.click(
        evaluate,
        inputs=[api_in, url_txt_in, zip_in, path_in, yaml_in, weights_in],
        outputs=[out_md, out_df],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))