# evaltest / app.py — last commit cf7f887 "Update app.py" by wuhp (verified), 14.8 kB
# (header text copied from the Hugging Face file viewer: raw · history · blame)
# app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator for Hugging Face Spaces
#
# ▸ Prompts for a Roboflow **API key** and a `.txt` list of Universe dataset URLs (one per line)
# ▸ Downloads each dataset automatically in YOLOv8 format to a temp directory
# ▸ Runs a battery of quality checks:
# – integrity / corruption
# – class‑balance stats
# – blur / brightness image‑quality flags
# – exact / near‑duplicate detection
# – optional model‑assisted label QA (needs a YOLO .pt weights file)
# ▸ Still supports manual ZIP / server‑path evaluation
# ▸ Outputs a Markdown report + class‑distribution dataframe
#
# Hugging Face Spaces picks up `app.py` automatically. Dependencies go in `requirements.txt`.
# Spaces injects the port as $PORT – we pass it to demo.launch().
from __future__ import annotations
import imghdr
import json
import os
import re
import shutil
import tempfile
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import gradio as gr
import numpy as np
import pandas as pd
import yaml
from PIL import Image
from tqdm import tqdm
# --------------------------------------------------------------------------- #
# Optional heavy deps – set to None when missing; checks that need them skip  #
# --------------------------------------------------------------------------- #
# Each optional dependency is pre-bound to None and only rebound if its import
# succeeds; the code that needs it tests for None and degrades gracefully.
cv2 = None  # OpenCV – used by qc_image_quality
try:
    import cv2  # type: ignore  # noqa: F811
except ImportError:
    pass

imagehash = None  # perceptual hashing – used by qc_duplicates
try:
    import imagehash  # type: ignore  # noqa: F811
except ImportError:
    pass

YOLO = None  # Ultralytics – used by qc_model_qa
try:
    from ultralytics import YOLO  # type: ignore  # noqa: F811
except ImportError:
    pass

Roboflow = None  # Roboflow SDK – used by evaluate's batch mode
try:
    from roboflow import Roboflow  # type: ignore  # noqa: F811
except ImportError:
    pass
# --------------------------------------------------------------------------- #
# Cache directory for downloaded Roboflow datasets; created eagerly at import time.
TMP_ROOT: Path = Path(tempfile.gettempdir()).joinpath("rf_datasets")
TMP_ROOT.mkdir(exist_ok=True, parents=True)
@dataclass
class DuplicateGroup:
    """Image paths that share one hash value.

    NOTE(review): not referenced by the code visible here; presumably intended to
    model the hash groups that ``qc_duplicates`` builds.
    """

    hash_val: str  # the shared hash, as a string
    paths: List[Path]  # every image that produced ``hash_val``
# --------------------------------------------------------------------------- #
# Generic helpers #
# --------------------------------------------------------------------------- #
def load_yaml(path: Path) -> Dict:
    """Parse a dataset YAML file and return its top-level mapping.

    Uses ``yaml.safe_load``, which does not construct arbitrary Python objects, so
    user-uploaded YAML can be parsed.

    Returns:
        The parsed mapping, or ``{}`` for an empty document (``safe_load`` returns
        ``None`` there). Callers can therefore use ``.get`` unconditionally.

    Raises:
        ValueError: the document's top level is not a mapping (e.g. a list or scalar).
    """
    with path.open(encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at top level, got {type(data).__name__}"
        )
    return data
def parse_label_file(path: Path) -> List[Tuple[int, float, float, float, float]]:
    """Parse a YOLO label file into ``(class_id, x_center, y_center, width, height)``.

    Two line shapes are understood:

    * 5 tokens, ``class cx cy w h``: a detection box, returned unchanged.
    * ``1 + 2k`` tokens with k >= 3, ``class x1 y1 x2 y2 ...``: a segmentation
      polygon. It is reduced to its axis-aligned bounding box, so class statistics
      and model QA also work on segmentation exports.

    Blank lines, lines of any other length and lines with non-numeric fields are
    skipped. One bad line therefore does not abort the whole dataset evaluation.
    A missing file yields an empty list.
    """
    out: List[Tuple[int, float, float, float, float]] = []
    if not path.exists():
        return out
    with path.open(encoding="utf-8") as f:
        for ln in f:
            parts = ln.split()
            n = len(parts)
            if n != 5 and (n < 7 or n % 2 == 0):
                continue
            try:
                cid = int(parts[0])
                vals = [float(v) for v in parts[1:]]
            except ValueError:
                continue  # malformed line
            if n == 5:
                out.append((cid, *vals))
                continue
            xs, ys = vals[0::2], vals[1::2]
            x0, x1, y0, y1 = min(xs), max(xs), min(ys), max(ys)
            out.append((cid, (x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0))
    return out
def guess_image_dirs(root: Path) -> List[Path]:
    """Return the conventional YOLO image directories that exist under *root*.

    Candidates are checked in a fixed order: ``images/``, then the ``images/``
    subfolder of the ``train``, ``valid``, ``val`` and ``test`` splits.
    """
    layouts = (
        ("images",),
        ("train", "images"),
        ("valid", "images"),
        ("val", "images"),
        ("test", "images"),
    )
    found: List[Path] = []
    for parts in layouts:
        candidate = root.joinpath(*parts)
        if candidate.exists():
            found.append(candidate)
    return found
# File extensions treated as images (a subset of the formats Ultralytics accepts).
_IMG_SUFFIXES = frozenset(
    {".bmp", ".dng", ".jpeg", ".jpg", ".mpo", ".png", ".tif", ".tiff", ".webp"}
)


def _label_path_for(img: Path) -> Path:
    """Return the YOLO label path for *img*.

    The last ``images`` directory component becomes ``labels`` and the suffix becomes
    ``.txt``. So ``train/images/a.jpg`` maps to ``train/labels/a.txt`` and
    ``images/train/a.jpg`` maps to ``labels/train/a.txt``. If the path has no
    ``images`` component, it falls back to ``../labels/<stem>.txt``.
    """
    parts = list(img.parts)
    for i in range(len(parts) - 2, -1, -1):  # skip the file name itself
        if parts[i] == "images":
            parts[i] = "labels"
            return Path(*parts).with_suffix(".txt")
    return img.parent.parent / "labels" / f"{img.stem}.txt"


def gather_dataset(root: Path, yaml_path: Path | None = None):
    """Collect the images, expected label paths and YAML metadata of a YOLO dataset.

    Args:
        root: dataset root containing ``images/`` or split folders (``train/images``, ...).
        yaml_path: explicit dataset YAML. When omitted, ``data.yaml``/``data.yml`` in
            *root* is preferred; otherwise the first ``*.yaml``/``*.yml`` by name is used.

    Returns:
        ``(imgs, lbls, meta)``:
        * ``imgs``: sorted image paths.
        * ``lbls``: the label path for each image, as a parallel list. The file may not
          exist.
        * ``meta``: the parsed YAML (``{}`` if it is empty).

    Raises:
        FileNotFoundError: no YAML was found, or none of the expected image directories exist.
    """
    if yaml_path is None:
        candidates = sorted(root.glob("*.yaml")) + sorted(root.glob("*.yml"))
        if not candidates:
            raise FileNotFoundError("Dataset YAML not found")
        preferred = [p for p in candidates if p.name in ("data.yaml", "data.yml")]
        yaml_path = (preferred or candidates)[0]
    meta = load_yaml(yaml_path) or {}
    img_dirs = guess_image_dirs(root)
    if not img_dirs:
        raise FileNotFoundError("images/ directory hierarchy missing")
    # Select images by extension rather than sniffing content with imghdr. imghdr misses
    # some valid JPEG variants (e.g. ICC-profile headers) and is removed in Python 3.13.
    # A broken file with an image extension should also reach the integrity check as
    # "corrupt" instead of silently disappearing.
    imgs = sorted(
        p
        for d in img_dirs
        for p in d.rglob("*")
        if p.suffix.lower() in _IMG_SUFFIXES and p.is_file()
    )
    lbls = [_label_path_for(p) for p in imgs]
    return imgs, lbls, meta
# --------------------------------------------------------------------------- #
# Quality‑check stages #
# --------------------------------------------------------------------------- #
def _is_corrupt(path: Path) -> bool:
    """Return True if Pillow cannot open and fully decode the image at *path*.

    ``Image.verify()`` checks the file structure without decoding pixel data, and the
    image object cannot be used afterwards. The file is therefore reopened and
    ``load()``-ed as well, which also catches truncated or partially written files.
    """
    try:
        with Image.open(path) as im:
            im.verify()
        with Image.open(path) as im:
            im.load()
    except Exception:  # any decoder/IO error means the image is unusable
        return True
    return False
def qc_integrity(imgs: List[Path], lbls: List[Path]) -> Dict:
    """Check that images have labels, labels have images, and images decode.

    Args:
        imgs: image files of the dataset.
        lbls: expected label path for each image, parallel to ``imgs``.

    Returns:
        A result dict. ``score`` is 100 minus the number of problems as a percentage
        of ``len(imgs)``, floored at 0. ``details`` lists the offending paths:
        * ``missing_label_files``: images without a label file.
        * ``missing_image_files``: ``.txt`` files in the label directories that no
          image maps to.
        * ``corrupt_images``: files that fail ``_is_corrupt``.
    """
    miss_lbl = [img for img, lbl in zip(imgs, lbls) if not lbl.exists()]

    # ``lbls`` has one entry per image, so every entry already has an image. Orphan
    # labels can only be found by listing the label directories themselves.
    # (The old check built ``images/<stem>.txt`` and so flagged every label.)
    expected = set(lbls)
    miss_img = [
        txt
        for label_dir in sorted({lbl.parent for lbl in lbls})
        if label_dir.is_dir()
        for txt in sorted(label_dir.glob("*.txt"))
        if txt not in expected
    ]

    # Decode in parallel; ``map`` keeps the results aligned with ``imgs``.
    with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as ex:
        flags = list(
            tqdm(ex.map(_is_corrupt, imgs), total=len(imgs), desc="integrity", leave=False)
        )
    corrupt = [p for p, bad in zip(imgs, flags) if bad]

    score = 100 - (len(miss_lbl) + len(miss_img) + len(corrupt)) / max(len(imgs), 1) * 100
    return {
        "name": "Integrity",
        "score": max(score, 0),
        "details": {
            "missing_label_files": [str(p) for p in miss_lbl],
            "missing_image_files": [str(p) for p in miss_img],
            "corrupt_images": [str(p) for p in corrupt],
        },
    }
def qc_class_balance(lbls: List[Path]) -> Dict:
    """Score how evenly box instances are spread across the classes that occur.

    Score = (instances of the rarest class / instances of the most common class) * 100.
    Only classes present in the labels are considered. If no boxes are found at all,
    the score is 0 and ``details`` is the string ``"No labels"``.
    """
    per_image = [parse_label_file(lbl) for lbl in lbls]
    counts = Counter(box[0] for boxes in per_image for box in boxes)
    if not counts:
        return {"name": "Class balance", "score": 0, "details": "No labels"}
    box_totals = [len(boxes) for boxes in per_image]
    return {
        "name": "Class balance",
        "score": min(counts.values()) / max(counts.values()) * 100,
        "details": {
            "class_counts": dict(counts),
            "boxes_per_image": {
                "min": int(np.min(box_totals)),
                "max": int(np.max(box_totals)),
                "mean": float(np.mean(box_totals)),
            },
        },
    }
def qc_image_quality(
    imgs: List[Path],
    blur_thr: float = 100.0,
    dark_thr: float = 25.0,
    bright_thr: float = 230.0,
) -> Dict:
    """Flag blurry, very dark and very bright images.

    Each image is read with OpenCV and converted to grayscale, then flagged as:

    * blurry: the variance of its Laplacian is below ``blur_thr`` (little edge detail);
    * dark: its mean gray level is below ``dark_thr``;
    * bright: its mean gray level is above ``bright_thr``.

    The score is the percentage of images that carry none of the three flags. Images
    OpenCV cannot decode are listed under ``unreadable`` and not flagged; the
    integrity check accounts for those. Without OpenCV the check is skipped with a
    score of 100.
    """
    if cv2 is None:
        return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
    blurry: List[Path] = []
    dark: List[Path] = []
    bright: List[Path] = []
    unreadable: List[Path] = []
    for p in tqdm(imgs, desc="img‑quality", leave=False):
        im = cv2.imread(str(p))
        if im is None:
            unreadable.append(p)
            continue
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        lap = cv2.Laplacian(gray, cv2.CV_64F).var()
        br = np.mean(gray)
        if lap < blur_thr:
            blurry.append(p)
        if br < dark_thr:
            dark.append(p)
        if br > bright_thr:
            bright.append(p)
    bad = len(set(blurry + dark + bright))
    score = 100 - bad / max(len(imgs), 1) * 100
    return {
        "name": "Image quality",
        "score": score,
        "details": {
            "blurry": [str(p) for p in blurry],
            "dark": [str(p) for p in dark],
            "bright": [str(p) for p in bright],
            "unreadable": [str(p) for p in unreadable],
        },
    }
def qc_duplicates(imgs: List[Path]) -> Dict:
    """Find groups of images whose average hashes are identical.

    The score is 100 minus the percentage of redundant images; a group of k images
    sharing a hash contributes k - 1. Images Pillow cannot open or hash are listed
    under ``unreadable`` instead of aborting the whole check. The check is skipped
    (score 100) when ``imagehash`` is not installed.
    """
    if imagehash is None:
        return {"name": "Duplicates", "score": 100, "details": "imagehash not installed"}
    hashes: Dict[str, List[Path]] = defaultdict(list)
    unreadable: List[Path] = []
    for p in tqdm(imgs, desc="hashing", leave=False):
        try:
            # ``with`` closes each file promptly so large datasets do not exhaust
            # file descriptors.
            with Image.open(p) as im:
                h = str(imagehash.average_hash(im))
        except Exception:  # corrupt/unsupported file – also reported by the integrity check
            unreadable.append(p)
            continue
        hashes[h].append(p)
    groups = [g for g in hashes.values() if len(g) > 1]
    dup = sum(len(g) - 1 for g in groups)
    score = 100 - dup / max(len(imgs), 1) * 100
    return {
        "name": "Duplicates",
        "score": score,
        "details": {
            "groups": [[str(p) for p in g] for g in groups],
            "unreadable": [str(p) for p in unreadable],
        },
    }
def _rel_iou(b1, b2):
    """Intersection-over-union of two centre-format boxes ``(cx, cy, w, h)``.

    Both boxes must use the same units (for example both normalised to 0-1).
    Returns 0 when the union area is zero.
    """
    (ax, ay, aw, ah), (bx, by, bw, bh) = b1, b2
    left = max(ax - aw / 2, bx - bw / 2)
    top = max(ay - ah / 2, by - bh / 2)
    right = min(ax + aw / 2, bx + bw / 2)
    bottom = min(ay + ah / 2, by + bh / 2)
    inter = max(right - left, 0) * max(bottom - top, 0)
    union = aw * ah + bw * bh - inter
    if not union:
        return 0
    return inter / union
def qc_model_qa(imgs: List[Path], weights: str | None, lbls: List[Path], iou_thr: float = 0.5) -> Dict:
    """Compare ground-truth boxes with a YOLO model's predictions.

    For every labelled box, the best IoU with any same-class prediction is recorded
    (0 if there is none). The score is the mean of those IoUs times 100. An image is
    listed as mismatched (first 50 only) when any of its boxes falls below ``iou_thr``.

    Args:
        imgs: images to run the model on.
        weights: path to YOLO weights. ``None`` (or Ultralytics missing) skips the
            check with a score of 100.
        lbls: label file for each image, parallel to ``imgs``. If the lengths differ,
            label paths are derived from the image paths (``../labels/<stem>.txt``).
        iou_thr: per-box IoU below which an image counts as mismatched.
    """
    if weights is None or YOLO is None:
        return {"name": "Model QA", "score": 100, "details": "weights or YOLO unavailable"}
    if len(lbls) != len(imgs):
        lbls = [p.parent.parent / "labels" / f"{p.stem}.txt" for p in imgs]
    model = YOLO(weights)
    ious: List[float] = []
    mism: List[Path] = []
    for p, lbl in tqdm(zip(imgs, lbls), total=len(imgs), desc="model‑QA", leave=False):
        gtb = parse_label_file(lbl)
        if not gtb:
            continue
        res = model.predict(str(p), verbose=False)[0]
        boxes = res.boxes
        # Label files hold normalised xywh, so compare with ``xywhn``. ``xywh`` is in
        # pixels and would make every IoU close to 0.
        preds = list(zip(boxes.cls.tolist(), boxes.xywhn.tolist())) if boxes is not None else []
        image_bad = False
        for cls, x, y, w, h in gtb:
            best = max(
                (_rel_iou((x, y, w, h), tuple(box)) for c, box in preds if int(c) == cls),
                default=0.0,
            )
            ious.append(best)
            if best < iou_thr:
                image_bad = True
        if image_bad:
            mism.append(p)  # once per image, not once per bad box
    miou = float(np.mean(ious)) if ious else 1.0
    return {
        "name": "Model QA",
        "score": miou * 100,
        "details": {"mean_iou": miou, "mismatched_images": [str(p) for p in mism[:50]]},
    }
# --------------------------------------------------------------------------- #
# Relative weight of each check in the overall score (the five weights sum to 1.0).
DEFAULT_W = {
    "Integrity": 0.30,
    "Class balance": 0.15,
    "Image quality": 0.15,
    "Duplicates": 0.10,
    "Model QA": 0.30,
}


def aggregate(scores, weights: Dict[str, float] | None = None) -> float:
    """Combine per-check results into one 0-100 score.

    Args:
        scores: iterable of result dicts with ``"name"`` and ``"score"`` (0-100).
        weights: weight per check name; defaults to ``DEFAULT_W``. Unknown names
            get weight 0.

    Returns:
        The weighted mean score, normalised by the total weight of the checks
        actually present. A subset of checks therefore still maps onto 0-100.
        With all five default checks, the weights sum to 1.0, so this equals the plain
        weighted sum. Returns 0.0 when no weighted check is present.
    """
    w = DEFAULT_W if weights is None else weights
    weighted = 0.0
    total_w = 0.0
    for r in scores:
        wi = w.get(r["name"], 0)
        weighted += wi * r["score"]
        total_w += wi
    return weighted / total_w if total_w else 0.0
# --------------------------------------------------------------------------- #
# Roboflow helpers #
# --------------------------------------------------------------------------- #
# Roboflow Universe dataset URL: (workspace, project, version). The scheme is optional
# so pasted "universe.roboflow.com/..." links work too. Trailing path parts such as
# "/download" are ignored because the pattern is anchored only at the start (re.match).
RF_RE = re.compile(r"(?:https?://)?universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")


def download_rf_dataset(url: str, rf_api: "Roboflow", dest: Path) -> Path:
    """Download one Roboflow Universe dataset version in YOLOv8 format, with caching.

    Args:
        url: Universe dataset URL; surrounding whitespace is ignored.
        rf_api: authenticated ``Roboflow`` client. It is not touched on a cache hit.
        dest: parent directory for downloads. Each dataset lands in
            ``<workspace>_<project>_v<version>``.

    Returns:
        The dataset directory.

    Raises:
        ValueError: *url* does not look like a Universe dataset URL.
    """
    m = RF_RE.match(url.strip())
    if not m:
        raise ValueError(f"Bad RF URL: {url}")
    ws, proj, ver = m.groups()
    ds_dir = dest / f"{ws}_{proj}_v{ver}"
    # Only a directory that already holds a top-level YAML counts as a finished
    # download. Anything else is most likely left over from an interrupted run and
    # is removed, so the fresh download does not mix with partial data.
    if ds_dir.is_dir() and any(ds_dir.glob("*.yaml")):
        return ds_dir
    if ds_dir.exists():
        shutil.rmtree(ds_dir, ignore_errors=True)
    project = rf_api.workspace(ws).project(proj)
    project.version(int(ver)).download("yolov8", location=str(ds_dir))
    return ds_dir
# --------------------------------------------------------------------------- #
# Main evaluation logic #
# --------------------------------------------------------------------------- #
def run_quality(root: Path, yaml_override: Path | None, weights: Path | None):
    """Run every quality check on the dataset at *root* and render the results.

    Args:
        root: dataset root directory.
        yaml_override: explicit dataset YAML, or ``None`` to auto-detect it in *root*.
        weights: YOLO weights for model-assisted QA, or ``None`` to skip that check.

    Returns:
        ``(markdown, dataframe)``:
        * ``markdown``: the report, with the weighted overall score and each check's
          JSON details.
        * ``dataframe``: a one-column (``count``) table of box counts indexed by
          class id.

    Exceptions from ``gather_dataset`` (e.g. ``FileNotFoundError``) propagate.
    """
    imgs, lbls, meta = gather_dataset(root, yaml_override)
    if not isinstance(meta, dict):  # empty or non-mapping YAML – fall back to defaults
        meta = {}
    res = [
        qc_integrity(imgs, lbls),
        qc_class_balance(lbls),
        qc_image_quality(imgs),
        qc_duplicates(imgs),
        qc_model_qa(imgs, str(weights) if weights else None, lbls),
    ]
    final = aggregate(res)

    md = [f"## **{meta.get('name', root.name)}**  —  Score {final:.1f}/100"]
    for r in res:
        md.append(f"### {r['name']}  {r['score']:.1f}")
        md.append("<details><summary>details</summary>\n\n```json")
        md.append(json.dumps(r["details"], indent=2))
        md.append("```\n</details>\n")
    md_str = "\n".join(md)

    # Look the class-balance result up by name rather than list position. Its
    # "details" is a plain string when the dataset has no labels, so guard .get().
    balance = next((r for r in res if r["name"] == "Class balance"), {})
    details = balance.get("details")
    cls_counts = details.get("class_counts", {}) if isinstance(details, dict) else {}
    df = pd.DataFrame.from_dict(cls_counts, orient="index", columns=["count"])
    df.index.name = "class"
    return md_str, df
# --------------------------------------------------------------------------- #
# Gradio interface #
# --------------------------------------------------------------------------- #
# Folder names that are part of a YOLO dataset layout itself. An archive whose only
# entry is one of these is already at the dataset root.
_LAYOUT_DIRS = frozenset({"images", "labels", "train", "valid", "val", "test"})


def _upload_path(f) -> Path | None:
    """Return the filesystem path of a Gradio upload, or None if nothing was uploaded.

    Depending on the Gradio version/config, ``gr.File`` hands the callback either an
    object exposing ``.name`` (a temp-file wrapper) or a plain path string. Both are
    accepted.
    """
    if not f:
        return None
    return Path(getattr(f, "name", f))


def _archive_root(extracted: Path) -> Path:
    """Descend into the single wrapper folder that many ZIP tools put around a dataset."""
    entries = [p for p in extracted.iterdir() if p.name != "__MACOSX"]
    if len(entries) == 1 and entries[0].is_dir() and entries[0].name not in _LAYOUT_DIRS:
        return entries[0]
    return extracted


def evaluate(
    api_key: str,
    url_txt: gr.File | None,
    zip_file: gr.File | None,
    server_path: str,
    yaml_file: gr.File | None,
    weights: gr.File | None,
):
    """Gradio callback: evaluate every supplied dataset and merge the results.

    The three input modes are processed independently: the Roboflow URL list, the
    uploaded archive and the server path. A failure in one dataset is reported
    inline (``### <source>`` plus the error) instead of aborting the others. A
    Roboflow setup problem (missing SDK or key) is likewise reported without
    skipping the manual inputs.

    Returns:
        ``(markdown, dataframe)``:
        * ``markdown``: the per-dataset reports, separated by horizontal rules.
        * ``dataframe``: class counts summed across datasets by class id.
    """
    server_path = (server_path or "").strip()
    if not any([url_txt, zip_file, server_path]):
        return "Upload a .txt of URLs or dataset ZIP/path", pd.DataFrame()

    yaml_path = _upload_path(yaml_file)
    weights_path = _upload_path(weights)
    reports: List[str] = []
    dfs: List[pd.DataFrame] = []

    def record(title: str, run) -> None:
        """Run one evaluation thunk; keep its report/df, or an inline error on failure."""
        try:
            md, df = run()
        except Exception as e:  # UI boundary: surface the error and keep going
            reports.append(f"### {title}\n\n⚠️ {e}")
        else:
            reports.append(md)
            dfs.append(df)

    # ---- Roboflow batch mode ----
    if url_txt:
        key = (api_key or "").strip()
        if Roboflow is None:
            reports.append("`roboflow` not installed")
        elif not key:
            reports.append("Enter Roboflow API key")
        else:
            rf = Roboflow(api_key=key)
            # utf-8-sig also drops a BOM that some editors prepend.
            text = _upload_path(url_txt).read_text(encoding="utf-8-sig")
            for line in text.splitlines():
                url = line.strip()
                if not url or url.startswith("#"):
                    continue
                record(
                    url,
                    lambda url=url: run_quality(
                        download_rf_dataset(url, rf, TMP_ROOT), None, weights_path
                    ),
                )

    # ---- Manual archive ----
    if zip_file:
        archive = _upload_path(zip_file)
        tmp_dir = Path(tempfile.mkdtemp())

        def run_archive():
            # NOTE(security): the archive is user-supplied. shutil's ZIP extractor skips
            # absolute and ".." members, but tar extraction on older Pythons is not
            # filtered. Consider restricting to ZIP or passing filter="data" (3.12+).
            shutil.unpack_archive(str(archive), tmp_dir)
            return run_quality(_archive_root(tmp_dir), yaml_path, weights_path)

        try:
            record(archive.name, run_archive)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    # ---- Manual path ----
    if server_path:
        # NOTE(security): on a public Space this lets any visitor scan arbitrary server
        # directories; consider restricting it to an allow-listed root.
        record(server_path, lambda: run_quality(Path(server_path), yaml_path, weights_path))

    summary_md = "\n\n---\n\n".join(reports)
    combined_df = pd.concat(dfs).groupby(level=0).sum() if dfs else pd.DataFrame()
    return summary_md, combined_df
with gr.Blocks(title="YOLO Dataset Quality Evaluator") as demo:
    gr.Markdown(
        """
# YOLOv8 Dataset Quality Evaluator
### Roboflow batch
1. Paste your **Roboflow API key**
2. Upload a **.txt** file – one `https://universe.roboflow.com/.../dataset/x` per line
### Manual
* Upload a dataset **ZIP** or type a dataset **path** on the server
* Optionally supply a custom **data.yaml** and/or a **YOLO .pt** weights file for model‑assisted QA
"""
    )
    # Roboflow batch inputs
    with gr.Row():
        api_in = gr.Textbox(label="Roboflow API key", type="password", placeholder="rf_XXXXXXXXXXXXXXXX")
        url_txt_in = gr.File(label=".txt of RF dataset URLs", file_types=[".txt"])
    # Manual dataset inputs
    with gr.Row():
        zip_in = gr.File(label="Dataset ZIP")
        path_in = gr.Textbox(label="Path on server", placeholder="/data/my_dataset")
    # Optional extras: the custom YAML applies to the ZIP/path modes; the weights apply
    # to every dataset. Both .yaml and .yml are accepted, since dataset YAMLs use either.
    with gr.Row():
        yaml_in = gr.File(label="Custom YAML", file_types=[".yaml", ".yml"])
        weights_in = gr.File(label="YOLO weights (.pt)")
    run_btn = gr.Button("Evaluate")
    out_md = gr.Markdown()
    out_df = gr.Dataframe()
    run_btn.click(
        evaluate,
        inputs=[api_in, url_txt_in, zip_in, path_in, yaml_in, weights_in],
        outputs=[out_md, out_df],
    )
if __name__ == "__main__":
    # Spaces provides the port in $PORT; fall back to 7860 when running locally.
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port)