Update app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,13 @@
|
|
1 |
# app.py – YOLOv8 Dataset Quality Evaluator for Hugging Face Spaces
|
2 |
"""
|
3 |
-
Gradio application
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
* Keep
|
9 |
-
*
|
10 |
-
*
|
11 |
-
|
12 |
-
Checks implemented
|
13 |
-
------------------
|
14 |
-
1. **Dataset integrity** – verify that every image has a label file (or an allowed empty‑label exemption) and that each
|
15 |
-
label file parses correctly.
|
16 |
-
2. **Class stats / balance** – count instances per class and per‑image instance distribution.
|
17 |
-
3. **Image quality** – flag blurry, too‑dark or over‑bright images using simple OpenCV heuristics.
|
18 |
-
4. **Duplicate & near‑duplicate images** – perceptual‑hash pass (fallback) or FastDup if available.
|
19 |
-
5. **Duplicate boxes** – IoU > 0.9 duplicates in the same image.
|
20 |
-
6. **Optional model‑assisted label QA** – if the user provides a YOLO weights file, run inference and compute IoU‑based
|
21 |
-
agreement metrics plus Cleanlab label‑quality scores when the library is installed.
|
22 |
-
7. **Composite scoring** – combine sub‑scores (with adjustable weights) into a final 0‑100 quality score.
|
23 |
-
|
24 |
-
The code is intentionally modular: each check lives in its own function that returns a `dict` of metrics; adding new
|
25 |
-
checks is as simple as creating another function that follows the same signature and adding it to the `CHECKS` list.
|
26 |
"""
|
27 |
from __future__ import annotations
|
28 |
|
@@ -32,10 +17,10 @@ import os
|
|
32 |
import shutil
|
33 |
import tempfile
|
34 |
from collections import Counter
|
|
|
35 |
from dataclasses import dataclass
|
36 |
from pathlib import Path
|
37 |
from typing import Dict, List, Tuple
|
38 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
39 |
|
40 |
import gradio as gr
|
41 |
import numpy as np
|
@@ -44,7 +29,7 @@ import yaml
|
|
44 |
from PIL import Image
|
45 |
from tqdm import tqdm
|
46 |
|
47 |
-
# Optional
|
48 |
try:
|
49 |
import cv2 # type: ignore
|
50 |
except ImportError:
|
@@ -65,171 +50,150 @@ try:
|
|
65 |
except ImportError:
|
66 |
cl_rank = None
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
# --------------------------------------------------------------------------------------
|
71 |
-
# Utility dataclasses
|
72 |
-
# --------------------------------------------------------------------------------------
|
73 |
-
@dataclass
|
74 |
-
class ImageMetrics:
|
75 |
-
path: Path
|
76 |
-
width: int
|
77 |
-
height: int
|
78 |
-
blur_score: float | None = None
|
79 |
-
brightness: float | None = None
|
80 |
-
|
81 |
-
@property
|
82 |
-
def aspect_ratio(self) -> float:
|
83 |
-
return self.width / self.height if self.height else 0
|
84 |
-
|
85 |
|
|
|
|
|
|
|
86 |
@dataclass
|
87 |
class DuplicateGroup:
|
88 |
hash_val: str
|
89 |
paths: List[Path]
|
90 |
|
91 |
|
92 |
-
#
|
93 |
-
#
|
94 |
-
#
|
95 |
|
96 |
-
def load_yaml(
|
97 |
-
with
|
98 |
return yaml.safe_load(f)
|
99 |
|
100 |
|
101 |
-
def parse_label_file(
|
102 |
-
"""Return list of (class_id, x_center, y_center, width, height)."""
|
103 |
entries: List[Tuple[int, float, float, float, float]] = []
|
104 |
-
with
|
105 |
-
for
|
106 |
-
parts =
|
107 |
if len(parts) != 5:
|
108 |
-
raise ValueError(f"Malformed
|
109 |
-
|
110 |
-
entries.append((int(
|
111 |
return entries
|
112 |
|
113 |
|
114 |
def guess_image_dirs(root: Path) -> List[Path]:
|
115 |
-
|
116 |
-
candidates = [
|
117 |
root / "images",
|
118 |
root / "train" / "images",
|
119 |
root / "valid" / "images",
|
120 |
root / "val" / "images",
|
121 |
root / "test" / "images",
|
122 |
]
|
123 |
-
return [
|
124 |
|
125 |
|
126 |
-
def gather_dataset(root: Path, yaml_path: Path | None = None)
|
127 |
-
"""Return (image_paths, label_paths, yaml_dict)."""
|
128 |
if yaml_path is None:
|
129 |
yaml_candidates = list(root.glob("*.yaml"))
|
130 |
if not yaml_candidates:
|
131 |
-
raise FileNotFoundError("
|
132 |
yaml_path = yaml_candidates[0]
|
133 |
meta = load_yaml(yaml_path)
|
134 |
|
135 |
-
|
136 |
-
if not
|
137 |
-
raise FileNotFoundError("No images directory found under dataset root
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
for
|
142 |
-
|
143 |
-
|
144 |
-
label_paths.append(label_path)
|
145 |
-
return image_paths, label_paths, meta
|
146 |
|
147 |
|
148 |
-
#
|
149 |
-
#
|
150 |
-
#
|
151 |
|
152 |
-
def _is_corrupt(
|
153 |
try:
|
154 |
-
with Image.open(
|
155 |
im.verify()
|
156 |
return False
|
157 |
-
except Exception:
|
158 |
return True
|
159 |
|
160 |
|
161 |
-
def check_integrity(
|
162 |
-
|
163 |
-
|
164 |
-
missing_images = [lbl for lbl in label_paths if lbl.exists() and not lbl.with_name("images").exists()]
|
165 |
|
166 |
-
|
167 |
-
corrupt_images = []
|
168 |
with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as ex:
|
169 |
-
|
170 |
-
for
|
171 |
-
if
|
172 |
-
|
173 |
|
174 |
-
score = 100 - (len(
|
175 |
return {
|
176 |
"name": "Integrity",
|
177 |
"score": max(score, 0),
|
178 |
"details": {
|
179 |
-
"missing_label_files": [str(p) for p in
|
180 |
-
"missing_image_files": [str(p) for p in
|
181 |
-
"corrupt_images": [str(p) for p in
|
182 |
},
|
183 |
}
|
184 |
|
185 |
|
186 |
-
def compute_class_stats(
|
187 |
-
|
188 |
-
|
189 |
-
for
|
190 |
-
if not
|
191 |
continue
|
192 |
-
boxes = parse_label_file(
|
193 |
-
|
194 |
-
|
195 |
-
if not
|
196 |
return {"name": "Class balance", "score": 0, "details": {"message": "No labels found"}}
|
197 |
-
|
198 |
-
balance_score = min_count / max_count * 100 if max_count else 0
|
199 |
return {
|
200 |
"name": "Class balance",
|
201 |
-
"score":
|
202 |
"details": {
|
203 |
-
"class_counts": dict(
|
204 |
"boxes_per_image_stats": {
|
205 |
-
"min": int(np.min(
|
206 |
-
"max": int(np.max(
|
207 |
-
"mean": float(np.mean(
|
208 |
},
|
209 |
},
|
210 |
}
|
211 |
|
212 |
|
213 |
-
def
|
214 |
if cv2 is None:
|
215 |
-
return {"name": "Image quality", "score": 100, "details": {"message": "cv2
|
216 |
blurry, dark, bright = [], [], []
|
217 |
-
for p in tqdm(
|
218 |
-
|
219 |
-
if
|
220 |
continue
|
221 |
-
gray = cv2.cvtColor(
|
222 |
-
|
223 |
-
|
224 |
-
if
|
225 |
blurry.append(p)
|
226 |
-
if
|
227 |
dark.append(p)
|
228 |
-
if
|
229 |
bright.append(p)
|
230 |
-
total = len(image_paths)
|
231 |
bad = len(set(blurry + dark + bright))
|
232 |
-
score = 100 - bad / max(
|
233 |
return {
|
234 |
"name": "Image quality",
|
235 |
"score": score,
|
@@ -241,17 +205,166 @@ def image_quality_metrics(image_paths: List[Path], blur_thresh: float = 100.0) -
|
|
241 |
}
|
242 |
|
243 |
|
244 |
-
def detect_duplicates(
|
|
|
|
|
245 |
if use_fastdup:
|
246 |
-
global FASTDUP_AVAILABLE
|
247 |
try:
|
248 |
import fastdup # type: ignore
|
249 |
|
250 |
FASTDUP_AVAILABLE = True
|
|
|
|
|
|
|
|
|
|
|
251 |
except ImportError:
|
252 |
use_fastdup = False
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# app.py – YOLOv8 Dataset Quality Evaluator for Hugging Face Spaces
|
2 |
"""
|
3 |
+
Gradio application that audits Roboflow/YOLO‑format object‑detection datasets. It computes a suite of quality metrics
|
4 |
+
— integrity, class balance, image quality, duplicate images, and optional model‑assisted label QA — and returns a
|
5 |
+
Markdown report plus a class‑distribution dataframe.
|
6 |
+
|
7 |
+
Ready for **Hugging Face Spaces**:
|
8 |
+
* Keep this file name `app.py` (Spaces detects it automatically).
|
9 |
+
* Spaces sets the webserver port in `$PORT`; we pass it to `demo.launch()`.
|
10 |
+
* Dependencies are listed in `requirements.txt` (see repo root).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"""
|
12 |
from __future__ import annotations
|
13 |
|
|
|
17 |
import shutil
|
18 |
import tempfile
|
19 |
from collections import Counter
|
20 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
21 |
from dataclasses import dataclass
|
22 |
from pathlib import Path
|
23 |
from typing import Dict, List, Tuple
|
|
|
24 |
|
25 |
import gradio as gr
|
26 |
import numpy as np
|
|
|
29 |
from PIL import Image
|
30 |
from tqdm import tqdm
|
31 |
|
32 |
+
# Optional deps ---------------------------------------------------------------
|
33 |
try:
|
34 |
import cv2 # type: ignore
|
35 |
except ImportError:
|
|
|
50 |
except ImportError:
|
51 |
cl_rank = None
|
52 |
|
53 |
+
# ----------------------------------------------------------------------------
|
54 |
+
FASTDUP_AVAILABLE = False # toggled if library present
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
# ----------------------------------------------------------------------------
|
57 |
+
# Dataclasses
|
58 |
+
# ----------------------------------------------------------------------------
|
59 |
@dataclass
class DuplicateGroup:
    """A set of image files that were grouped under one duplicate key."""

    # Key shared by every member of the group (stored as a string).
    hash_val: str
    # Paths of all images that mapped to ``hash_val``.
    paths: List[Path]
|
63 |
|
64 |
|
65 |
+
# ----------------------------------------------------------------------------
|
66 |
+
# Helper functions
|
67 |
+
# ----------------------------------------------------------------------------
|
68 |
|
69 |
+
def load_yaml(path: Path) -> Dict:
    """Parse the YAML file at ``path`` and return the loaded document.

    The file is opened as UTF-8 text and parsed with ``yaml.safe_load``, so
    only plain YAML tags are constructed.
    """
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
|
72 |
|
73 |
|
74 |
+
def parse_label_file(path: Path) -> List[Tuple[int, float, float, float, float]]:
    """Read a YOLO label file into a list of box tuples.

    Each non-blank line must hold exactly five whitespace-separated fields:
    an integer class id followed by four floats (x_center, y_center, width,
    height).

    Args:
        path: Label file to read (UTF-8 text).

    Returns:
        One ``(class_id, x_center, y_center, width, height)`` tuple per
        non-blank line, in file order. An empty file yields ``[]``.

    Raises:
        ValueError: If a non-blank line does not have exactly five fields or
            a field cannot be converted to ``int`` / ``float``.
    """
    entries: List[Tuple[int, float, float, float, float]] = []
    with path.open("r", encoding="utf-8") as f:
        for ln in f:
            parts = ln.split()
            if not parts:
                # Blank / whitespace-only lines (typically a trailing newline)
                # carry no annotation and must not be reported as malformed.
                continue
            if len(parts) != 5:
                raise ValueError(f"Malformed line in {path}: {ln}")
            cid, *coords = parts
            entries.append((int(cid), *map(float, coords)))
    return entries
|
84 |
|
85 |
|
86 |
def guess_image_dirs(root: Path) -> List[Path]:
    """Return the conventional image folders that exist under ``root``.

    The probed locations are ``images`` directly under ``root`` and an
    ``images`` folder inside each of ``train``, ``valid``, ``val`` and
    ``test``. Results keep that probing order.
    """
    layouts = (
        ("images",),
        ("train", "images"),
        ("valid", "images"),
        ("val", "images"),
        ("test", "images"),
    )
    found: List[Path] = []
    for parts in layouts:
        candidate = root.joinpath(*parts)
        if candidate.exists():
            found.append(candidate)
    return found
|
95 |
|
96 |
|
97 |
+
def gather_dataset(root: Path, yaml_path: Path | None = None):
    """Collect image paths, their expected label paths and the dataset YAML.

    Args:
        root: Dataset root directory.
        yaml_path: Explicit dataset YAML. When ``None`` the first ``*.yaml``
            file (in sorted order) directly under ``root`` is used.

    Returns:
        ``(imgs, lbls, meta)`` where ``imgs`` is the sorted list of image
        files found under the detected image folders, ``lbls[i]`` is the
        label path expected for ``imgs[i]`` (sibling ``labels`` folder, same
        stem, ``.txt`` suffix — the file may not exist), and ``meta`` is the
        parsed YAML document.

    Raises:
        FileNotFoundError: If no YAML is given/found or no image folder
            exists under ``root``.
    """
    if yaml_path is None:
        # Sort so the chosen file does not depend on filesystem listing order.
        yaml_candidates = sorted(root.glob("*.yaml"))
        if not yaml_candidates:
            raise FileNotFoundError("YAML not found — provide one or place it in dataset root")
        yaml_path = yaml_candidates[0]
    meta = load_yaml(yaml_path)

    img_dirs = guess_image_dirs(root)
    if not img_dirs:
        raise FileNotFoundError("No images directory found under dataset root")

    # ``rglob("*.*")`` also matches directories whose name contains a dot;
    # ``imghdr.what`` would try to open them, so keep regular files only.
    # NOTE(review): ``imghdr`` is deprecated since Python 3.11 and removed in
    # 3.13 — this sniffing step needs a replacement before upgrading.
    imgs = sorted(
        p
        for d in img_dirs
        for p in d.rglob("*.*")
        if p.is_file() and imghdr.what(p) is not None
    )
    lbls: List[Path] = [p.parent.parent / "labels" / f"{p.stem}.txt" for p in imgs]
    return imgs, lbls, meta
|
|
|
|
|
114 |
|
115 |
|
116 |
+
# ----------------------------------------------------------------------------
|
117 |
+
# Quality checks
|
118 |
+
# ----------------------------------------------------------------------------
|
119 |
|
120 |
+
def _is_corrupt(p: Path) -> bool:
    """Report whether Pillow fails to open or verify the image at ``p``."""
    try:
        with Image.open(p) as handle:
            handle.verify()
    except Exception:
        # Any failure while opening/verifying is treated as corruption.
        return True
    else:
        return False
|
127 |
|
128 |
|
129 |
+
def check_integrity(imgs: List[Path], lbls: List[Path]) -> Dict:
    """Score how many image/label pairs are complete and readable.

    Args:
        imgs: Image paths.
        lbls: Label paths, aligned index-by-index with ``imgs``.

    Returns:
        A result dict with ``name`` ("Integrity"), ``score`` (0-100, the
        share of images without a reported problem) and ``details`` listing
        images without a label file, labels without their image file, and
        images that failed verification.
    """
    pairs = list(zip(imgs, lbls))
    miss_lbl = [img for img, lbl in pairs if not lbl.exists()]
    # A label is orphaned when the image it is paired with is absent. The
    # previous check probed ``images/<stem>.txt`` (label suffix inside the
    # image folder), which never exists and flagged every valid label.
    miss_img = [lbl for img, lbl in pairs if lbl.exists() and not img.exists()]

    corrupt: List[Path] = []
    # Verification is file-I/O heavy, so fan it out over a thread pool.
    with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as ex:
        futs = {ex.submit(_is_corrupt, p): p for p in imgs}
        for fu in tqdm(as_completed(futs), total=len(futs), desc="Integrity", leave=False):
            if fu.result():
                corrupt.append(futs[fu])

    # ``max(..., 1)`` avoids dividing by zero for an empty dataset.
    score = 100 - (len(miss_lbl) + len(miss_img) + len(corrupt)) / max(len(imgs), 1) * 100
    return {
        "name": "Integrity",
        "score": max(score, 0),
        "details": {
            "missing_label_files": [str(p) for p in miss_lbl],
            "missing_image_files": [str(p) for p in miss_img],
            "corrupt_images": [str(p) for p in corrupt],
        },
    }
|
150 |
|
151 |
|
152 |
+
def compute_class_stats(lbls: List[Path]) -> Dict:
    """Summarise class frequencies and boxes-per-image over existing labels.

    Label paths that do not exist are ignored. The score is the ratio of the
    rarest class count to the most frequent class count, scaled to 0-100;
    when no box is found at all the score is 0 with an explanatory message.
    """
    per_class = Counter()
    per_image = []
    for label_path in lbls:
        if label_path.exists():
            annotations = parse_label_file(label_path)
            per_image.append(len(annotations))
            per_class.update([row[0] for row in annotations])

    if not per_class:
        return {"name": "Class balance", "score": 0, "details": {"message": "No labels found"}}

    rarest = min(per_class.values())
    most_common = max(per_class.values())
    balance = rarest / most_common * 100

    box_stats = {
        "min": int(np.min(per_image) if per_image else 0),
        "max": int(np.max(per_image) if per_image else 0),
        "mean": float(np.mean(per_image) if per_image else 0),
    }
    return {
        "name": "Class balance",
        "score": balance,
        "details": {
            "class_counts": dict(per_class),
            "boxes_per_image_stats": box_stats,
        },
    }
|
176 |
|
177 |
|
178 |
+
def image_quality(imgs: List[Path], blur_thresh: float = 100.0) -> Dict:
|
179 |
if cv2 is None:
|
180 |
+
return {"name": "Image quality", "score": 100, "details": {"message": "cv2 missing"}}
|
181 |
blurry, dark, bright = [], [], []
|
182 |
+
for p in tqdm(imgs, desc="Image quality", leave=False):
|
183 |
+
im = cv2.imread(str(p))
|
184 |
+
if im is None:
|
185 |
continue
|
186 |
+
gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
|
187 |
+
lapv = cv2.Laplacian(gray, cv2.CV_64F).var()
|
188 |
+
bright = np.mean(gray)
|
189 |
+
if lapv < blur_thresh:
|
190 |
blurry.append(p)
|
191 |
+
if bright < 25:
|
192 |
dark.append(p)
|
193 |
+
if bright > 230:
|
194 |
bright.append(p)
|
|
|
195 |
bad = len(set(blurry + dark + bright))
|
196 |
+
score = 100 - bad / max(len(imgs), 1) * 100
|
197 |
return {
|
198 |
"name": "Image quality",
|
199 |
"score": score,
|
|
|
205 |
}
|
206 |
|
207 |
|
208 |
+
def detect_duplicates(imgs: List[Path], use_fastdup: bool = False) -> Dict:
    """Group duplicate images and score the share of redundant files.

    Args:
        imgs: Image paths to compare.
        use_fastdup: Try the optional ``fastdup`` package first; if it cannot
            be imported the average-hash fallback is used instead.

    Returns:
        A result dict with ``name`` ("Duplicates"), ``score`` (100 minus the
        percentage of images that are extra copies within a group) and
        ``details`` — either the duplicate groups as lists of path strings,
        or a message when ``imagehash`` is unavailable for the fallback.
    """
    global FASTDUP_AVAILABLE
    groups: List[DuplicateGroup] = []
    # ``imgs[0]`` below needs at least one image; with an empty list there is
    # nothing to cluster, so the fastdup branch is skipped entirely.
    if use_fastdup and imgs:
        try:
            import fastdup  # type: ignore

            FASTDUP_AVAILABLE = True
            fd = fastdup.create(input_dir=str(imgs[0].parent.parent), work_dir="fastdup_work")
            fd.run(num_images=0)
            for h, lst in fd.clusters.items():  # type: ignore[attr-defined]
                if len(lst) > 1:
                    groups.append(DuplicateGroup(h, [Path(p) for p in lst]))
        except ImportError:
            use_fastdup = False
    if not use_fastdup:
        if imagehash is None:
            return {"name": "Duplicates", "score": 100, "details": {"message": "imagehash not installed"}}
        hashes: Dict[str, List[Path]] = {}
        for p in tqdm(imgs, desc="Hashing", leave=False):
            # Close each image right after hashing so file handles do not
            # accumulate across a large dataset.
            with Image.open(p) as im:
                h = str(imagehash.average_hash(im))
            hashes.setdefault(h, []).append(p)
        groups = [DuplicateGroup(h, v) for h, v in hashes.items() if len(v) > 1]

    # Every member beyond the first in a group counts as one redundant image.
    dup_count = sum(len(g.paths) - 1 for g in groups)
    score = 100 - dup_count / max(len(imgs), 1) * 100
    return {
        "name": "Duplicates",
        "score": score,
        "details": {"groups": [[str(p) for p in g.paths] for g in groups]},
    }
|
239 |
+
|
240 |
+
|
241 |
+
# -- Model‑assisted QA --------------------------------------------------------
|
242 |
+
|
243 |
+
def _rel_iou(box1, box2):
    """Intersection-over-union of two centre-format ``(x, y, w, h)`` boxes.

    Both boxes must be expressed in the same coordinate system. Returns 0
    when the union area is zero.
    """
    cx_a, cy_a, w_a, h_a = box1
    cx_b, cy_b, w_b, h_b = box2

    # Convert centre/size to corner coordinates.
    left_a, top_a = cx_a - w_a / 2, cy_a - h_a / 2
    right_a, bottom_a = cx_a + w_a / 2, cy_a + h_a / 2
    left_b, top_b = cx_b - w_b / 2, cy_b - h_b / 2
    right_b, bottom_b = cx_b + w_b / 2, cy_b + h_b / 2

    # Clamp at zero so disjoint boxes yield an empty intersection.
    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    inter = overlap_w * overlap_h

    union = w_a * h_a + w_b * h_b - inter
    if not union:
        return 0
    return inter / union
|
254 |
+
|
255 |
+
|
256 |
+
def model_qa(imgs: List[Path], lbls: List[Path], weights: str | None, iou_thr: float = 0.5) -> Dict:
    """Compare ground-truth boxes with model predictions via best-match IoU.

    Args:
        imgs: Image paths to run inference on.
        lbls: Label paths (kept for interface compatibility; the label for
            each image is located in the sibling ``labels`` folder).
        weights: Path to YOLO weights, or ``None`` to skip the check.
        iou_thr: A ground-truth box whose best same-class IoU is below this
            value marks its image as mismatched.

    Returns:
        A result dict with ``name`` ("Model QA"), ``score`` (mean best IoU
        times 100; 100 when nothing could be evaluated) and ``details``
        holding the mean IoU and up to 50 mismatched image paths.
    """
    if weights is None or YOLO is None:
        return {"name": "Model QA", "score": 100, "details": {"message": "weights or YOLO not available"}}
    model = YOLO(weights)
    ious: List[float] = []
    mism: List[Path] = []
    # batch inference for speed
    for i in tqdm(range(0, len(imgs), 16), desc="Model QA", leave=False):
        batch = imgs[i : i + 16]
        preds = model.predict(batch, verbose=False)
        for pth, pred in zip(batch, preds):
            label_path = pth.parent.parent / "labels" / f"{pth.stem}.txt"
            if not label_path.exists():
                # Unlabelled images have no ground truth to compare against;
                # opening the missing file would abort the whole check.
                continue
            gt = parse_label_file(label_path)
            # Label files store normalised coordinates, so predictions must
            # be read in normalised form (``xywhn``) as well; pixel-space
            # ``xywh`` would make every IoU collapse towards zero.
            pred_boxes = [tuple(b) for b in pred.boxes.xywhn.tolist()]  # type: ignore[attr-defined]
            pred_classes = [int(c) for c in pred.boxes.cls.tolist()]  # type: ignore[attr-defined]
            flagged = False
            for (cls, x, y, w, h) in gt:
                best = max(
                    (
                        _rel_iou((x, y, w, h), pb)
                        for pb, pc in zip(pred_boxes, pred_classes)
                        if pc == cls
                    ),
                    default=0,
                )
                ious.append(best)
                if best < iou_thr:
                    flagged = True
            if flagged:
                # Record each image once, however many of its boxes disagree.
                mism.append(pth)
    miou = float(np.mean(ious)) if ious else 1.0
    return {
        "name": "Model QA",
        "score": miou * 100,
        "details": {"mean_iou": miou, "mismatched_images": [str(p) for p in mism[:50]]},
    }
|
284 |
+
|
285 |
+
|
286 |
+
# ----------------------------------------------------------------------------
|
287 |
+
# Scoring aggregation
|
288 |
+
# ----------------------------------------------------------------------------
|
289 |
+
# Relative contribution of each check (keyed by the check's "name") to the
# overall score.
DEFAULT_WEIGHTS = {
    "Integrity": 0.3,
    "Class balance": 0.15,
    "Image quality": 0.15,
    "Duplicates": 0.1,
    "Model QA": 0.3,
}


def aggregate(res):
    """Combine per-check results into one weighted score.

    Each entry of ``res`` contributes ``score`` multiplied by the weight
    registered for its ``name``; names without a weight contribute nothing.
    """
    contributions = [
        DEFAULT_WEIGHTS.get(entry["name"], 0) * entry["score"]
        for entry in res
    ]
    return sum(contributions)
|
300 |
+
|
301 |
+
|
302 |
+
# ----------------------------------------------------------------------------
|
303 |
+
# Gradio interface
|
304 |
+
# ----------------------------------------------------------------------------
|
305 |
+
|
306 |
+
def evaluate(dataset_zip: gr.File | None, dataset_path: str, yaml_file: gr.File | None, weights_file: gr.File | None):
    """Run every quality check on a dataset and build the report.

    Args:
        dataset_zip: Uploaded archive containing the dataset, if any. Takes
            precedence over ``dataset_path``.
        dataset_path: Path of a dataset directory on the server.
        yaml_file: Optional uploaded dataset YAML.
        weights_file: Optional uploaded YOLO weights for the model check.

    Returns:
        ``(markdown, dataframe)``: the Markdown report (or a short message
        when no dataset was supplied or it could not be loaded) and a
        dataframe of per-class box counts.
    """
    if not dataset_zip and not dataset_path:
        return "Please upload a dataset zip or enter a path", pd.DataFrame()

    tmp: Path | None = None
    root: Path
    try:
        if dataset_zip:
            tmp = Path(tempfile.mkdtemp())
            # NOTE(review): the archive is user-supplied; confirm that
            # extraction cannot write outside ``tmp`` for every format
            # ``shutil.unpack_archive`` accepts.
            shutil.unpack_archive(dataset_zip.name, tmp)
            root = tmp
        else:
            # NOTE(review): this lets the caller point at any server path.
            root = Path(dataset_path)

        yaml_path = Path(yaml_file.name) if yaml_file else None

        try:
            imgs, lbls, _ = gather_dataset(root, yaml_path)
        except FileNotFoundError as err:
            # Report a malformed dataset the same way as missing input.
            return f"Dataset could not be loaded: {err}", pd.DataFrame()

        results = [
            check_integrity(imgs, lbls),
            compute_class_stats(lbls),
            image_quality(imgs),
            detect_duplicates(imgs),
            model_qa(imgs, lbls, weights_file.name if weights_file else None),
        ]
    finally:
        # Always remove the extraction directory, even when a check raises;
        # the results only hold plain strings/numbers, so nothing below
        # needs the files.
        if tmp is not None:
            shutil.rmtree(tmp, ignore_errors=True)

    final = aggregate(results)

    # Markdown summary
    lines = [f"# Dataset Quality Report\n\n**Overall score:** {final:.1f}/100\n"]
    for r in results:
        lines.append(f"## {r['name']} — {r['score']:.1f}")
        if r["details"]:
            lines.append("<details><summary>Details</summary>\n\n```json")
            lines.append(json.dumps(r["details"], indent=2))
            lines.append("```\n</details>\n")
    md = "\n".join(lines)

    # results[1] is the class-balance check; it has no "class_counts" key
    # when no labels were found, hence the default.
    class_counts = results[1]["details"].get("class_counts", {})  # type: ignore[index]
    df = pd.DataFrame.from_dict(class_counts, orient="index", columns=["count"])

    return md, df
|
348 |
+
|
349 |
+
|
350 |
+
# Build the Gradio interface at import time; ``demo`` is the module-level app
# object that the entry point below launches.
with gr.Blocks(title="YOLO Dataset Quality Evaluator") as demo:
    gr.Markdown("""## YOLOv8 Dataset Quality Evaluator
Upload a Roboflow‑exported (or generic YOLO) dataset and get a quick quality report.
* Provide either a ZIP file or a server path.
* Optionally add trained weights to enable model‑assisted checks.
""")
    # Dataset source: an uploaded archive or a directory path on the server.
    with gr.Row():
        zip_in = gr.File(label="Dataset ZIP")
        path_in = gr.Textbox(label="dataset path on server", placeholder="/data/my_dataset")
    # Optional extras: a dataset YAML and model weights.
    with gr.Row():
        yaml_in = gr.File(label="custom YAML", file_types=[".yaml"])
        weights_in = gr.File(label="YOLO weights (.pt)")
    btn = gr.Button("Evaluate")
    out_md = gr.Markdown()
    out_df = gr.Dataframe()

    # Input order must match the parameter order of ``evaluate``; its two
    # return values fill the Markdown report and the dataframe.
    btn.click(evaluate, inputs=[zip_in, path_in, yaml_in, weights_in], outputs=[out_md, out_df])


if __name__ == "__main__":
    # Listen on all interfaces; the port comes from ``$PORT`` when set and
    # falls back to 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
|