Update app.py
app.py
CHANGED
@@ -17,7 +17,7 @@ REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
 PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")  # Supervisely keeps PyTorch impl here
 
-# Core deps + your requested packages; pinned as lower-bounds to avoid downgrades
+# Core deps + your requested packages; pinned as lower-bounds to avoid downgrades (local runs only)
 COMMON_REQUIREMENTS = [
     "gradio>=4.36.1",
     "ultralytics>=8.2.0",
@@ -30,9 +30,9 @@ COMMON_REQUIREMENTS = [
     "torchvision>=0.15.2",
     "pyyaml>=6.0.1",
     "Pillow>=10.0.0",
-    "supervisely>=6.0.0",
-    "tensorboard>=2.13.0",
-    "pycocotools>=2.0.7",
+    "supervisely>=6.0.0",
+    "tensorboard>=2.13.0",
+    "pycocotools>=2.0.7",
 ]
 
 # === bootstrap (clone + pip) ===================================================
@@ -51,15 +51,17 @@ def ensure_repo_and_requirements():
     except Exception:
         logging.warning("git pull failed; continuing with current checkout")
 
+    # On HF Spaces: expect requirements.txt to be used at build time; skip heavy runtime installs
+    if os.getenv("HF_SPACE") == "1" or os.getenv("SPACE_ID"):
+        logging.info("Detected Hugging Face Space — skipping runtime pip installs.")
+        return
 
+    # Local fallback (non-Spaces)
+    pip_install(COMMON_REQUIREMENTS)
     req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
     if os.path.exists(req_file):
         pip_install(["-r", req_file])
 
-    # Double-check supervisely importability; if not, try again explicitly.
     try:
         import supervisely  # noqa: F401
     except Exception:
@@ -77,8 +79,10 @@ DEFAULT_MODEL_KEY = "rtdetrv2_s"
|
|
| 77 |
|
| 78 |
# === utilities ================================================================
|
| 79 |
def handle_remove_readonly(func, path, exc_info):
|
| 80 |
-
try:
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
func(path)
|
| 83 |
|
| 84 |
_ROBO_URL_RX = re.compile(r"""
|
|
@@ -104,8 +108,10 @@ def parse_roboflow_url(s: str):
         version = None
         if len(parts) >= 3:
             v = parts[2]
-            if v.lower().startswith('v') and v[1:].isdigit():
+            if v.lower().startswith('v') and v[1:].isdigit():
+                version = int(v[1:])
+            elif v.isdigit():
+                version = int(v)
         return parts[0], parts[1], version
     if '/' in s and 'roboflow' not in s:
         p = s.split('/')
@@ -113,8 +119,10 @@ def parse_roboflow_url(s: str):
         version = None
         if len(p) >= 3:
             v = p[2]
-            if v.lower().startswith('v') and v[1:].isdigit():
+            if v.lower().startswith('v') and v[1:].isdigit():
+                version = int(v[1:])
+            elif v.isdigit():
+                version = int(v)
         return p[0], p[1], version
     return None, None, None
 
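For reference, a small illustrative check (not part of the commit) of what the version parsing added above is expected to return for plain "workspace/project[/vN]" strings, assuming app.py is importable as a module:

    # Illustrative only; function and module layout as in this diff
    from app import parse_roboflow_url

    print(parse_roboflow_url("myws/myproj/v3"))  # expected: ("myws", "myproj", 3)
    print(parse_roboflow_url("myws/myproj/2"))   # expected: ("myws", "myproj", 2)
    print(parse_roboflow_url("myws/myproj"))     # expected: ("myws", "myproj", None)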
@@ -132,8 +140,10 @@ def _extract_class_names(data_yaml):
     names = data_yaml.get('names', None)
     if isinstance(names, dict):
         def _k(x):
-            try:
+            try:
+                return int(x)
+            except Exception:
+                return str(x)
         keys = sorted(names.keys(), key=_k)
         names_list = [names[k] for k in keys]
     elif isinstance(names, list):
@@ -150,7 +160,8 @@ def download_dataset(api_key, workspace, project, version):
     ver = proj.version(int(version))
     dataset = ver.download("yolov8")  # labels in YOLO format (we'll convert to COCO)
     data_yaml_path = os.path.join(dataset.location, 'data.yaml')
-    with open(data_yaml_path, 'r') as f:
+    with open(data_yaml_path, 'r') as f:
+        data_yaml = yaml.safe_load(f)
     class_names = _extract_class_names(data_yaml)
     splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
     return dataset.location, class_names, splits, f"{project}-v{version}"
@@ -170,7 +181,8 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
     ann_id = 1
     img_id = 1
     for fname in sorted(os.listdir(split_dir_images)):
-        if not fname.lower().endswith((".jpg",".jpeg",".png")):
+        if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
+            continue
         img_path = os.path.join(split_dir_images, fname)
         try:
             with Image.open(img_path) as im:
@@ -183,19 +195,28 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
         with open(label_file, "r") as f:
             for line in f:
                 parts = line.strip().split()
-                if len(parts) < 5:
+                if len(parts) < 5:
+                    continue
+                try:
+                    cls = int(float(parts[0]))
+                    cx, cy, bw, bh = map(float, parts[1:5])
+                except Exception:
+                    continue
+                x = max(0.0, (cx - bw / 2.0) * w)
+                y = max(0.0, (cy - bh / 2.0) * h)
+                ww = max(1.0, bw * w)
+                hh = max(1.0, bh * h)
+                # clamp right/bottom to image bounds
+                if x + ww > w:
+                    ww = max(1.0, w - x)
+                if y + hh > h:
+                    hh = max(1.0, h - y)
                 annotations.append({
                     "id": ann_id,
                     "image_id": img_id,
                     "category_id": cls,
-                    "bbox": [
-                    "area": max(1.0, ww*hh),
+                    "bbox": [x, y, ww, hh],
+                    "area": max(1.0, ww * hh),
                     "iscrowd": 0,
                     "segmentation": []
                 })
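As a sanity check on the bbox conversion added above, here is a standalone sketch of the same arithmetic with illustrative numbers (not part of the commit):

    # YOLO (normalized cx, cy, w, h) -> COCO (absolute x, y, w, h), as in the hunk above
    w, h = 640, 480
    cx, cy, bw, bh = 0.5, 0.5, 0.25, 0.5

    x = max(0.0, (cx - bw / 2.0) * w)   # 240.0
    y = max(0.0, (cy - bh / 2.0) * h)   # 120.0
    ww = max(1.0, bw * w)               # 160.0
    hh = max(1.0, bh * h)               # 240.0
    print([x, y, ww, hh])               # [240.0, 120.0, 160.0, 240.0]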
@@ -203,7 +224,8 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
         img_id += 1
     coco = {"images": images, "annotations": annotations, "categories": categories}
     os.makedirs(os.path.dirname(out_json), exist_ok=True)
-    with open(out_json, "w") as f:
+    with open(out_json, "w") as f:
+        json.dump(coco, f)
 
 def make_coco_annotations(merged_dir, class_names):
     ann_dir = os.path.join(merged_dir, "annotations")
@@ -219,28 +241,34 @@ def make_coco_annotations(merged_dir, class_names):
 
 # === dataset merging ==========================================================
 def gather_class_counts(dataset_info, class_mapping):
-    if not dataset_info:
+    if not dataset_info:
+        return {}
     final_names = set(v for v in class_mapping.values() if v is not None)
     counts = {name: 0 for name in final_names}
     for loc, names, splits, _ in dataset_info:
         id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
-            if not os.path.exists(labels_dir):
+            if not os.path.exists(labels_dir):
+                continue
             for label_file in os.listdir(labels_dir):
-                if not label_file.endswith('.txt'):
+                if not label_file.endswith('.txt'):
+                    continue
                 found = set()
                 with open(os.path.join(labels_dir, label_file), 'r') as f:
                     for line in f:
                         parts = line.strip().split()
-                        if not parts:
+                        if not parts:
+                            continue
                         try:
                             cls_id = int(parts[0])
                             mapped = id_to_name.get(cls_id, None)
-                            if mapped:
+                            if mapped:
+                                found.add(mapped)
                         except Exception:
                             continue
-                for m in found:
+                for m in found:
+                    counts[m] += 1
     return counts
 
 def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
@@ -260,7 +288,8 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
     for loc, _, splits, _ in dataset_info:
         for split in splits:
             img_dir = os.path.join(loc, split, 'images')
-            if not os.path.exists(img_dir):
+            if not os.path.exists(img_dir):
+                continue
             for img_file in os.listdir(img_dir):
                 if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                     all_images.append((os.path.join(img_dir, img_file), split, loc))
@@ -272,24 +301,30 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
     for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
         lbl_path = label_path_for(img_path)
-        if not os.path.exists(lbl_path):
+        if not os.path.exists(lbl_path):
+            continue
         source_names = loc_to_names.get(source_loc, [])
         image_classes = set()
         with open(lbl_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
-                if not parts:
+                if not parts:
+                    continue
                try:
                    cls_id = int(parts[0])
                    orig = source_names[cls_id]
                    mapped = class_mapping.get(orig, orig)
-                    if mapped in active_classes:
+                    if mapped in active_classes:
+                        image_classes.add(mapped)
                except Exception:
                    continue
-        if not image_classes:
+        if not image_classes:
+            continue
+        if any(current_counts[c] >= class_limits[c] for c in image_classes):
+            continue
         selected_images.append((img_path, split))
-        for c in image_classes:
+        for c in image_classes:
+            current_counts[c] += 1
 
     progress(0.6, desc=f"Copying {len(selected_images)} files...")
     for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
@@ -300,13 +335,16 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
         source_loc = None
         for info in dataset_info:
-            if img_path.startswith(info[0]):
+            if img_path.startswith(info[0]):
+                source_loc = info[0]
+                break
         source_names = loc_to_names.get(source_loc, [])
 
         with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
             for line in f_in:
                 parts = line.strip().split()
-                if not parts:
+                if not parts:
+                    continue
                 try:
                     old_id = int(parts[0])
                     original_name = source_names[old_id]
|
@@ -334,10 +372,19 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
|
|
| 334 |
|
| 335 |
# === entrypoint + config detection/generation =================================
|
| 336 |
def find_training_script(repo_root):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
candidates = []
|
| 338 |
-
for pat in ["**/tools/train.py", "**/train.py"]:
|
| 339 |
candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
return candidates[0] if candidates else None
|
| 342 |
|
| 343 |
def find_model_config_template(model_key):
|
|
@@ -353,17 +400,37 @@ def find_model_config_template(model_key):
     def score(p):
         pl = p.lower()
         s = 0
-        if "/rtdetrv2_pytorch/" in pl:
+        if "/rtdetrv2_pytorch/" in pl:
+            s += 4
+        if "/config" in pl:
+            s += 3
         for token in want_tokens:
-            if token in os.path.basename(pl):
+            if token in os.path.basename(pl):
+                s += 3
+            if token in pl:
+                s += 2
+        if "coco" in pl:
+            s += 1
         return -s, len(p)
 
     yamls.sort(key=score)
     return yamls[0] if yamls else None
 
+def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
+    """
+    If any key from `keys` exists in dict `d`, set the first one found to `value`.
+    Otherwise, if `fallback_key` is given, create it with `value`.
+    Returns the key that was set, or None.
+    """
+    for k in keys:
+        if k in d:
+            d[k] = value
+            return k
+    if fallback_key:
+        d[fallback_key] = value
+        return fallback_key
+    return None
+
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                       epochs, batch, imgsz, lr, optimizer):
     if not base_cfg_path or not os.path.exists(base_cfg_path):
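A minimal usage sketch for the new _set_first_existing_key helper (illustrative keys and values, assuming the helper above is in scope, e.g. via `from app import _set_first_existing_key`):

    # Illustrative only; mirrors how patch_base_config uses the helper below
    node = {"ann_path": "old.json", "img_dir": "old_imgs"}

    _set_first_existing_key(node, keys=["ann_file", "ann_path"], value="train.json",
                            fallback_key="ann_file")
    # "ann_path" already exists, so it is updated in place

    _set_first_existing_key(node, keys=["image_root"], value="train_imgs",
                            fallback_key="img_prefix")
    # no listed key exists, so the fallback "img_prefix" is created

    print(node)  # {'ann_path': 'train.json', 'img_dir': 'old_imgs', 'img_prefix': 'train_imgs'}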
@@ -383,7 +450,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # dataset block
+    # dataset block: set an existing alias if present, otherwise add a common key
     for root_key in ["dataset", "data"]:
         if root_key in cfg and isinstance(cfg[root_key], dict):
             ds = cfg[root_key]
@@ -393,21 +460,31 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                 ("test", "test_json", "test_img"),
             ]:
                 if split in ds and isinstance(ds[split], dict):
+                    node = ds[split]
+                    node["name"] = node.get("name", "coco")
+                    _set_first_existing_key(
+                        node,
+                        keys=["ann_file", "ann_path", "annotation", "annotations"],
+                        value=paths[jf],
+                        fallback_key="ann_file",
+                    )
+                    _set_first_existing_key(
+                        node,
+                        keys=["img_prefix", "img_dir", "image_root", "data_root"],
+                        value=paths[ip],
+                        fallback_key="img_prefix",
+                    )
 
     # num_classes
     def set_num_classes(node, n):
-        if not isinstance(node, dict):
+        if not isinstance(node, dict):
+            return False
         if "num_classes" in node:
-            node["num_classes"] = int(n)
+            node["num_classes"] = int(n)
+            return True
         for k, v in node.items():
-            if isinstance(v, dict) and set_num_classes(v, n):
+            if isinstance(v, dict) and set_num_classes(v, n):
+                return True
         return False
 
     if "model" in cfg and isinstance(cfg["model"], dict):
@@ -420,17 +497,23 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     updated_epoch = False
     for key in ["max_epoch", "epochs", "num_epochs"]:
         if key in cfg:
-            cfg[key] = int(epochs)
+            cfg[key] = int(epochs)
+            updated_epoch = True
+            break
     if "solver" in cfg and isinstance(cfg["solver"], dict):
         for key in ["max_epoch", "epochs", "num_epochs"]:
             if key in cfg["solver"]:
-                cfg["solver"][key] = int(epochs)
+                cfg["solver"][key] = int(epochs)
+                updated_epoch = True
+                break
     if not updated_epoch:
         cfg["max_epoch"] = int(epochs)
 
     for key in ["input_size", "img_size", "imgsz"]:
-        if key in cfg:
+        if key in cfg:
+            cfg[key] = int(imgsz)
+    if "input_size" not in cfg:
+        cfg["input_size"] = int(imgsz)
 
     # lr / optimizer / batch
     if "solver" not in cfg or not isinstance(cfg["solver"], dict):
@@ -438,7 +521,8 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     sol = cfg["solver"]
     for key in ["base_lr", "lr", "learning_rate"]:
         if key in sol:
-            sol[key] = float(lr)
+            sol[key] = float(lr)
+            break
     else:
         sol["base_lr"] = float(lr)
 
@@ -456,9 +540,11 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     else:
         cfg["output_dir"] = paths["out_dir"]
 
-    cfg_out_dir = os.path.join("generated_configs")
+    cfg_out_dir = os.path.join("generated_configs")
+    os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
-    with open(out_path, "w") as f:
+    with open(out_path, "w") as f:
+        yaml.safe_dump(cfg, f, sort_keys=False)
     return out_path
 
 def find_best_checkpoint(out_dir):
@@ -470,16 +556,21 @@ def find_best_checkpoint(out_dir):
     ]
     for p in pats:
         f = sorted(glob(p, recursive=True))
-        if f:
+        if f:
+            return f[0]
+    any_ckpt = sorted(
+        glob(os.path.join(out_dir, "**", "*.pt"), recursive=True)
+        + glob(os.path.join(out_dir, "**", "*.pth"), recursive=True)
+    )
     return any_ckpt[-1] if any_ckpt else None
 
 # === Gradio handlers ==========================================================
 def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
     api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
-    if not api_key:
+    if not api_key:
+        raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
+    if not url_file:
+        raise gr.Error("Upload a .txt with Roboflow URLs or 'workspace/project[/vN]' lines.")
 
     with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
         urls = [line.strip() for line in f if line.strip()]
@@ -497,8 +588,10 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
             failures.append((raw, f"No latest version for {ws}/{proj}"))
             continue
         loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
-        if loc:
+        if loc:
+            dataset_info.append((loc, names, splits, name_str))
+        else:
+            failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
 
     if not dataset_info:
         msg = "No datasets loaded.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
@@ -510,11 +603,13 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
     df = pd.DataFrame([[n, n, counts.get(n, 0), False] for n in all_names],
                       columns=["Original Name", "Rename To", "Max Images", "Remove"])
     status = "Datasets loaded successfully."
-    if failures:
+    if failures:
+        status += f" ({len(dataset_info)} OK, {len(failures)} failed; see logs)."
     return status, dataset_info, df
 
 def update_class_counts_handler(class_df, dataset_info):
-    if class_df is None or not dataset_info:
+    if class_df is None or not dataset_info:
+        return None
     class_df = pd.DataFrame(class_df)
     mapping = {row["Original Name"]: (None if bool(row["Remove"]) else row["Rename To"])
                for _, row in class_df.iterrows()}
@@ -524,41 +619,34 @@ def update_class_counts_handler(class_df, dataset_info):
         id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
-            if not os.path.exists(labels_dir):
+            if not os.path.exists(labels_dir):
+                continue
             for label_file in os.listdir(labels_dir):
-                if not label_file.endswith('.txt'):
+                if not label_file.endswith('.txt'):
+                    continue
                 found = set()
                 with open(os.path.join(labels_dir, label_file), 'r') as f:
                     for line in f:
                         parts = line.strip().split()
-                        if not parts:
+                        if not parts:
+                            continue
                         try:
                             cls_id = int(parts[0])
                             mapped = id_to_final.get(cls_id, None)
-                            if mapped:
+                            if mapped:
+                                found.add(mapped)
                         except Exception:
                             continue
-                for m in found:
+                for m in found:
+                    counts[m] += 1
     return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
 
-def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
-    if not dataset_info: raise gr.Error("Load datasets first in Tab 1.")
-    if class_df is None: raise gr.Error("Class data is missing.")
-    class_df = pd.DataFrame(class_df)
-    class_mapping, class_limits = {}, {}
-    for _, row in class_df.iterrows():
-        orig = row["Original Name"]
-        if bool(row["Remove"]): continue
-        final_name = row["Rename To"]
-        class_mapping[orig] = final_name
-        class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
-    status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
-    return status, path
-
 def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
-    if not dataset_path:
+    if not dataset_path:
+        raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
     train_script = find_training_script(REPO_DIR)
+    logging.info(f"Resolved training script: {train_script}")
     if not train_script:
         raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
 
@@ -567,7 +655,8 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
         raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
 
     data_yaml = os.path.join(dataset_path, "data.yaml")
-    with open(data_yaml, "r") as f:
+    with open(data_yaml, "r") as f:
+        dy = yaml.safe_load(f)
     class_names = [str(x) for x in dy.get("names", [])]
     make_coco_annotations(dataset_path, class_names)
 
@@ -600,7 +689,8 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
             proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                                     stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                     bufsize=1, text=True, env=env)
-            for line in proc.stdout:
+            for line in proc.stdout:
+                q.put(line.rstrip())
             proc.wait()
             q.put(f"__EXITCODE__:{proc.returncode}")
         except Exception as e:
@@ -610,38 +700,78 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
 
     log_tail, last_epoch, total_epochs = [], 0, int(epochs)
     first_lines = []
+    line_no = 0
     while True:
         line = q.get()
         if line.startswith("__EXITCODE__"):
-            code = int(line.split(":",1)[1])
+            code = int(line.split(":", 1)[1])
             if code != 0:
                 head = "\n".join(first_lines[:60])
                 raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}")
             break
         if line.startswith("__ERROR__"):
-            raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
+            raise gr.Error(f"Training failed: {line.split(':', 1)[1]}")
 
-        if len(first_lines) < 120:
+        if len(first_lines) < 120:
+            first_lines.append(line)
+        log_tail.append(line)
+        log_tail = log_tail[-40:]
 
         m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
         if m:
             try:
-                last_epoch = int(m.group(1))
+                last_epoch = int(m.group(1))
+                total_epochs = max(total_epochs, int(m.group(2)))
+            except Exception:
+                pass
+        progress(min(max(last_epoch / max(1, total_epochs), 0.0), 1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
+
+        # Throttle plotting; close figs after yield to avoid leaks
+        line_no += 1
+        fig1 = fig2 = None
+        if line_no % 80 == 0:
+            fig1 = plt.figure()
+            plt.title("Loss (see logs)")
+            plt.plot([0, last_epoch], [0, 0])
+            plt.tight_layout()
+
+            fig2 = plt.figure()
+            plt.title("mAP (see logs)")
+            plt.plot([0, last_epoch], [0, 0])
+            plt.tight_layout()
 
-        fig1 = plt.figure(); plt.title("Loss (see logs)")
-        fig2 = plt.figure(); plt.title("mAP (see logs)")
         yield "\n".join(log_tail), fig1, fig2, None
 
+        if fig1 is not None:
+            plt.close(fig1)
+        if fig2 is not None:
+            plt.close(fig2)
+
     ckpt = find_best_checkpoint(out_dir) or find_best_checkpoint("runs")
     if not ckpt or not os.path.exists(ckpt):
         raise gr.Error("Training finished, but checkpoint file not found. Check logs/output directory.")
     yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True)
 
+def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
+    if not dataset_info:
+        raise gr.Error("Load datasets first in Tab 1.")
+    if class_df is None:
+        raise gr.Error("Class data is missing.")
+    class_df = pd.DataFrame(class_df)
+    class_mapping, class_limits = {}, {}
+    for _, row in class_df.iterrows():
+        orig = row["Original Name"]
+        if bool(row["Remove"]):
+            continue
+        final_name = row["Rename To"]
+        class_mapping[orig] = final_name
+        class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
+    status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
+    return status, path
+
 def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
-    if not model_file:
+    if not model_file:
+        raise gr.Error("No trained model file to upload.")
     from huggingface_hub import HfApi, HfFolder
     hf_status = "Skipped Hugging Face."
     if hf_token and hf_repo:
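For context, the epoch-progress parsing added above relies on a single regex; a tiny standalone check with a made-up log line (not part of the commit):

    import re

    m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", "Epoch 3/50  loss=0.42")
    print(m.group(1), m.group(2))  # 3 50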
@@ -658,21 +788,27 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     if gh_token and gh_repo:
         progress(0.5, desc="Uploading to GitHub...")
         try:
-            if '/' not in gh_repo:
+            if '/' not in gh_repo:
+                raise ValueError("GitHub repo must be 'username/repo'.")
             username, repo_name = gh_repo.split('/')
             api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
             headers = {"Authorization": f"token {gh_token}"}
-            with open(model_file.name, "rb") as f:
+            with open(model_file.name, "rb") as f:
+                content = base64.b64encode(f.read()).decode()
             get_resp = requests.get(api_url, headers=headers, timeout=30)
             sha = get_resp.json().get('sha') if get_resp.ok else None
             data = {"message": "Upload trained model from Rolo app", "content": content}
-            if sha:
+            if sha:
+                data["sha"] = sha
             put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
-            if put_resp.ok:
+            if put_resp.ok:
+                gh_status = f"Success! {put_resp.json()['content']['html_url']}"
+            else:
+                gh_status = f"GitHub Error: {put_resp.json().get('message','Unknown')}"
         except Exception as e:
             gh_status = f"GitHub Error: {e}"
-    progress(1)
+    progress(1)
+    return hf_status, gh_status
 
 # === UI =======================================================================
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
@@ -751,4 +887,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
 
 if __name__ == "__main__":
     os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")  # silence stray warnings from other libs
+    # Log training script resolution at startup for quick troubleshooting
+    try:
+        ts = find_training_script(REPO_DIR)
+        logging.info(f"Startup check — training script at: {ts}")
+    except Exception as e:
+        logging.warning(f"Startup training-script check failed: {e}")
     app.launch(debug=True)