wuhp committed (verified)
Commit 3e12066 · 1 Parent(s): 2078af3

Update app.py

Files changed (1):
  1. app.py +257 -115
app.py CHANGED
@@ -17,7 +17,7 @@ REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
17
  REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
18
  PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch") # Supervisely keeps PyTorch impl here
19
 
20
- # Core deps + your requested packages; pinned as lower-bounds to avoid downgrades
21
  COMMON_REQUIREMENTS = [
22
  "gradio>=4.36.1",
23
  "ultralytics>=8.2.0",
@@ -30,9 +30,9 @@ COMMON_REQUIREMENTS = [
30
  "torchvision>=0.15.2",
31
  "pyyaml>=6.0.1",
32
  "Pillow>=10.0.0",
33
- "supervisely>=6.0.0", # <- fixes ModuleNotFoundError from repo trainer
34
- "tensorboard>=2.13.0", # convenience: sometimes used by forks
35
- "pycocotools>=2.0.7", # convenience: ensure wheels are present
36
  ]
37
 
38
  # === bootstrap (clone + pip) ===================================================
@@ -51,15 +51,17 @@ def ensure_repo_and_requirements():
51
  except Exception:
52
  logging.warning("git pull failed; continuing with current checkout")
53
 
54
- # Make sure all our app/runtime deps (incl. supervisely & ultralytics) are present
55
- pip_install(COMMON_REQUIREMENTS)
 
 
56
 
57
- # Then install repo-specific extras (pycocotools/tensorboard etc. if required)
 
58
  req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
59
  if os.path.exists(req_file):
60
  pip_install(["-r", req_file])
61
 
62
- # Double-check supervisely importability; if not, try again explicitly.
63
  try:
64
  import supervisely # noqa: F401
65
  except Exception:
@@ -77,8 +79,10 @@ DEFAULT_MODEL_KEY = "rtdetrv2_s"
77
 
78
  # === utilities ================================================================
79
  def handle_remove_readonly(func, path, exc_info):
80
- try: os.chmod(path, stat.S_IWRITE)
81
- except Exception: pass
 
 
82
  func(path)
83
 
84
  _ROBO_URL_RX = re.compile(r"""
@@ -104,8 +108,10 @@ def parse_roboflow_url(s: str):
104
  version = None
105
  if len(parts) >= 3:
106
  v = parts[2]
107
- if v.lower().startswith('v') and v[1:].isdigit(): version = int(v[1:])
108
- elif v.isdigit(): version = int(v)
 
 
109
  return parts[0], parts[1], version
110
  if '/' in s and 'roboflow' not in s:
111
  p = s.split('/')
@@ -113,8 +119,10 @@ def parse_roboflow_url(s: str):
113
  version = None
114
  if len(p) >= 3:
115
  v = p[2]
116
- if v.lower().startswith('v') and v[1:].isdigit(): version = int(v[1:])
117
- elif v.isdigit(): version = int(v)
 
 
118
  return p[0], p[1], version
119
  return None, None, None
120
 
@@ -132,8 +140,10 @@ def _extract_class_names(data_yaml):
132
  names = data_yaml.get('names', None)
133
  if isinstance(names, dict):
134
  def _k(x):
135
- try: return int(x)
136
- except Exception: return str(x)
 
 
137
  keys = sorted(names.keys(), key=_k)
138
  names_list = [names[k] for k in keys]
139
  elif isinstance(names, list):
@@ -150,7 +160,8 @@ def download_dataset(api_key, workspace, project, version):
150
  ver = proj.version(int(version))
151
  dataset = ver.download("yolov8") # labels in YOLO format (we'll convert to COCO)
152
  data_yaml_path = os.path.join(dataset.location, 'data.yaml')
153
- with open(data_yaml_path, 'r') as f: data_yaml = yaml.safe_load(f)
 
154
  class_names = _extract_class_names(data_yaml)
155
  splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
156
  return dataset.location, class_names, splits, f"{project}-v{version}"
@@ -170,7 +181,8 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
170
  ann_id = 1
171
  img_id = 1
172
  for fname in sorted(os.listdir(split_dir_images)):
173
- if not fname.lower().endswith((".jpg",".jpeg",".png")): continue
 
174
  img_path = os.path.join(split_dir_images, fname)
175
  try:
176
  with Image.open(img_path) as im:
@@ -183,19 +195,28 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
183
  with open(label_file, "r") as f:
184
  for line in f:
185
  parts = line.strip().split()
186
- if len(parts) < 5: continue
187
- cls = int(float(parts[0]))
188
- cx, cy, bw, bh = map(float, parts[1:5])
189
- x = (cx - bw/2.0) * w
190
- y = (cy - bh/2.0) * h
191
- ww = bw * w
192
- hh = bh * h
193
  annotations.append({
194
  "id": ann_id,
195
  "image_id": img_id,
196
  "category_id": cls,
197
- "bbox": [max(0.0,x), max(0.0,y), max(1.0,ww), max(1.0,hh)],
198
- "area": max(1.0, ww*hh),
199
  "iscrowd": 0,
200
  "segmentation": []
201
  })
@@ -203,7 +224,8 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
203
  img_id += 1
204
  coco = {"images": images, "annotations": annotations, "categories": categories}
205
  os.makedirs(os.path.dirname(out_json), exist_ok=True)
206
- with open(out_json, "w") as f: json.dump(coco, f)
 
207
 
208
  def make_coco_annotations(merged_dir, class_names):
209
  ann_dir = os.path.join(merged_dir, "annotations")
@@ -219,28 +241,34 @@ def make_coco_annotations(merged_dir, class_names):
219
 
220
  # === dataset merging ==========================================================
221
  def gather_class_counts(dataset_info, class_mapping):
222
- if not dataset_info: return {}
 
223
  final_names = set(v for v in class_mapping.values() if v is not None)
224
  counts = {name: 0 for name in final_names}
225
  for loc, names, splits, _ in dataset_info:
226
  id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
227
  for split in splits:
228
  labels_dir = os.path.join(loc, split, 'labels')
229
- if not os.path.exists(labels_dir): continue
 
230
  for label_file in os.listdir(labels_dir):
231
- if not label_file.endswith('.txt'): continue
 
232
  found = set()
233
  with open(os.path.join(labels_dir, label_file), 'r') as f:
234
  for line in f:
235
  parts = line.strip().split()
236
- if not parts: continue
 
237
  try:
238
  cls_id = int(parts[0])
239
  mapped = id_to_name.get(cls_id, None)
240
- if mapped: found.add(mapped)
 
241
  except Exception:
242
  continue
243
- for m in found: counts[m] += 1
 
244
  return counts
245
 
246
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
@@ -260,7 +288,8 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
260
  for loc, _, splits, _ in dataset_info:
261
  for split in splits:
262
  img_dir = os.path.join(loc, split, 'images')
263
- if not os.path.exists(img_dir): continue
 
264
  for img_file in os.listdir(img_dir):
265
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
266
  all_images.append((os.path.join(img_dir, img_file), split, loc))
@@ -272,24 +301,30 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
272
 
273
  for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
274
  lbl_path = label_path_for(img_path)
275
- if not os.path.exists(lbl_path): continue
 
276
  source_names = loc_to_names.get(source_loc, [])
277
  image_classes = set()
278
  with open(lbl_path, 'r') as f:
279
  for line in f:
280
  parts = line.strip().split()
281
- if not parts: continue
 
282
  try:
283
  cls_id = int(parts[0])
284
  orig = source_names[cls_id]
285
  mapped = class_mapping.get(orig, orig)
286
- if mapped in active_classes: image_classes.add(mapped)
 
287
  except Exception:
288
  continue
289
- if not image_classes: continue
290
- if any(current_counts[c] >= class_limits[c] for c in image_classes): continue
 
 
291
  selected_images.append((img_path, split))
292
- for c in image_classes: current_counts[c] += 1
 
293
 
294
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
295
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
@@ -300,13 +335,16 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
300
 
301
  source_loc = None
302
  for info in dataset_info:
303
- if img_path.startswith(info[0]): source_loc = info[0]; break
 
 
304
  source_names = loc_to_names.get(source_loc, [])
305
 
306
  with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
307
  for line in f_in:
308
  parts = line.strip().split()
309
- if not parts: continue
 
310
  try:
311
  old_id = int(parts[0])
312
  original_name = source_names[old_id]
@@ -334,10 +372,19 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
334
 
335
  # === entrypoint + config detection/generation =================================
336
  def find_training_script(repo_root):
337
  candidates = []
338
- for pat in ["**/tools/train.py", "**/train.py"]:
339
  candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
340
- candidates.sort(key=lambda p: (0 if "rtdetrv2_pytorch" in p else 1, len(p)))
341
  return candidates[0] if candidates else None
342
 
343
  def find_model_config_template(model_key):
@@ -353,17 +400,37 @@ def find_model_config_template(model_key):
353
  def score(p):
354
  pl = p.lower()
355
  s = 0
356
- if "/rtdetrv2_pytorch/" in pl: s += 4
357
- if "/config" in pl: s += 3
 
 
358
  for token in want_tokens:
359
- if token in os.path.basename(pl): s += 3
360
- if token in pl: s += 2
361
- if "coco" in pl: s += 1
 
 
 
362
  return -s, len(p)
363
 
364
  yamls.sort(key=score)
365
  return yamls[0] if yamls else None
366
 
367
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
368
  epochs, batch, imgsz, lr, optimizer):
369
  if not base_cfg_path or not os.path.exists(base_cfg_path):
@@ -383,7 +450,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
383
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
384
  }
385
 
386
- # dataset block
387
  for root_key in ["dataset", "data"]:
388
  if root_key in cfg and isinstance(cfg[root_key], dict):
389
  ds = cfg[root_key]
@@ -393,21 +460,31 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
393
  ("test", "test_json", "test_img"),
394
  ]:
395
  if split in ds and isinstance(ds[split], dict):
396
- ds[split]["name"] = ds[split].get("name", "coco")
397
- for k in ["ann_file", "ann_path", "annotation", "annotations"]:
398
- if k in ds[split] or k in ["ann_file", "ann_path"]:
399
- ds[split][k] = paths[jf]; break
400
- for k in ["img_prefix", "img_dir", "image_root", "data_root"]:
401
- if k in ds[split] or k in ["img_prefix", "img_dir"]:
402
- ds[split][k] = paths[ip]; break
403
 
404
  # num_classes
405
  def set_num_classes(node, n):
406
- if not isinstance(node, dict): return False
 
407
  if "num_classes" in node:
408
- node["num_classes"] = int(n); return True
 
409
  for k, v in node.items():
410
- if isinstance(v, dict) and set_num_classes(v, n): return True
 
411
  return False
412
 
413
  if "model" in cfg and isinstance(cfg["model"], dict):
@@ -420,17 +497,23 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
420
  updated_epoch = False
421
  for key in ["max_epoch", "epochs", "num_epochs"]:
422
  if key in cfg:
423
- cfg[key] = int(epochs); updated_epoch = True; break
 
 
424
  if "solver" in cfg and isinstance(cfg["solver"], dict):
425
  for key in ["max_epoch", "epochs", "num_epochs"]:
426
  if key in cfg["solver"]:
427
- cfg["solver"][key] = int(epochs); updated_epoch = True; break
 
 
428
  if not updated_epoch:
429
  cfg["max_epoch"] = int(epochs)
430
 
431
  for key in ["input_size", "img_size", "imgsz"]:
432
- if key in cfg: cfg[key] = int(imgsz)
433
- if "input_size" not in cfg: cfg["input_size"] = int(imgsz)
 
 
434
 
435
  # lr / optimizer / batch
436
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
@@ -438,7 +521,8 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
438
  sol = cfg["solver"]
439
  for key in ["base_lr", "lr", "learning_rate"]:
440
  if key in sol:
441
- sol[key] = float(lr); break
 
442
  else:
443
  sol["base_lr"] = float(lr)
444
 
@@ -456,9 +540,11 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
456
  else:
457
  cfg["output_dir"] = paths["out_dir"]
458
 
459
- cfg_out_dir = os.path.join("generated_configs"); os.makedirs(cfg_out_dir, exist_ok=True)
 
460
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
461
- with open(out_path, "w") as f: yaml.safe_dump(cfg, f, sort_keys=False)
 
462
  return out_path
463
 
464
  def find_best_checkpoint(out_dir):
@@ -470,16 +556,21 @@ def find_best_checkpoint(out_dir):
470
  ]
471
  for p in pats:
472
  f = sorted(glob(p, recursive=True))
473
- if f: return f[0]
474
- any_ckpt = sorted(glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) +
475
- glob(os.path.join(out_dir, "**", "*.pth"), recursive=True))
 
 
 
476
  return any_ckpt[-1] if any_ckpt else None
477
 
478
  # === Gradio handlers ==========================================================
479
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
480
  api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
481
- if not api_key: raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
482
- if not url_file: raise gr.Error("Upload a .txt with Roboflow URLs or 'workspace/project[/vN]' lines.")
 
 
483
 
484
  with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
485
  urls = [line.strip() for line in f if line.strip()]
@@ -497,8 +588,10 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
497
  failures.append((raw, f"No latest version for {ws}/{proj}"))
498
  continue
499
  loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
500
- if loc: dataset_info.append((loc, names, splits, name_str))
501
- else: failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
 
 
502
 
503
  if not dataset_info:
504
  msg = "No datasets loaded.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
@@ -510,11 +603,13 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
510
  df = pd.DataFrame([[n, n, counts.get(n, 0), False] for n in all_names],
511
  columns=["Original Name", "Rename To", "Max Images", "Remove"])
512
  status = "Datasets loaded successfully."
513
- if failures: status += f" ({len(dataset_info)} OK, {len(failures)} failed; see logs)."
 
514
  return status, dataset_info, df
515
 
516
  def update_class_counts_handler(class_df, dataset_info):
517
- if class_df is None or not dataset_info: return None
 
518
  class_df = pd.DataFrame(class_df)
519
  mapping = {row["Original Name"]: (None if bool(row["Remove"]) else row["Rename To"])
520
  for _, row in class_df.iterrows()}
@@ -524,41 +619,34 @@ def update_class_counts_handler(class_df, dataset_info):
524
  id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
525
  for split in splits:
526
  labels_dir = os.path.join(loc, split, 'labels')
527
- if not os.path.exists(labels_dir): continue
 
528
  for label_file in os.listdir(labels_dir):
529
- if not label_file.endswith('.txt'): continue
 
530
  found = set()
531
  with open(os.path.join(labels_dir, label_file), 'r') as f:
532
  for line in f:
533
  parts = line.strip().split()
534
- if not parts: continue
 
535
  try:
536
  cls_id = int(parts[0])
537
  mapped = id_to_final.get(cls_id, None)
538
- if mapped: found.add(mapped)
 
539
  except Exception:
540
  continue
541
- for m in found: counts[m] += 1
 
542
  return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
543
 
544
- def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
545
- if not dataset_info: raise gr.Error("Load datasets first in Tab 1.")
546
- if class_df is None: raise gr.Error("Class data is missing.")
547
- class_df = pd.DataFrame(class_df)
548
- class_mapping, class_limits = {}, {}
549
- for _, row in class_df.iterrows():
550
- orig = row["Original Name"]
551
- if bool(row["Remove"]): continue
552
- final_name = row["Rename To"]
553
- class_mapping[orig] = final_name
554
- class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
555
- status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
556
- return status, path
557
-
558
  def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
559
- if not dataset_path: raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
560
 
561
  train_script = find_training_script(REPO_DIR)
 
562
  if not train_script:
563
  raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
564
 
@@ -567,7 +655,8 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
567
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
568
 
569
  data_yaml = os.path.join(dataset_path, "data.yaml")
570
- with open(data_yaml, "r") as f: dy = yaml.safe_load(f)
 
571
  class_names = [str(x) for x in dy.get("names", [])]
572
  make_coco_annotations(dataset_path, class_names)
573
 
@@ -600,7 +689,8 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
600
  proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
601
  stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
602
  bufsize=1, text=True, env=env)
603
- for line in proc.stdout: q.put(line.rstrip())
 
604
  proc.wait()
605
  q.put(f"__EXITCODE__:{proc.returncode}")
606
  except Exception as e:
@@ -610,38 +700,78 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
610
 
611
  log_tail, last_epoch, total_epochs = [], 0, int(epochs)
612
  first_lines = []
 
613
  while True:
614
  line = q.get()
615
  if line.startswith("__EXITCODE__"):
616
- code = int(line.split(":",1)[1])
617
  if code != 0:
618
  head = "\n".join(first_lines[:60])
619
  raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}")
620
  break
621
  if line.startswith("__ERROR__"):
622
- raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
623
 
624
- if len(first_lines) < 120: first_lines.append(line)
625
- log_tail.append(line); log_tail = log_tail[-40:]
 
 
626
 
627
  m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
628
  if m:
629
  try:
630
- last_epoch = int(m.group(1)); total_epochs = max(total_epochs, int(m.group(2)))
631
- except Exception: pass
632
- progress(min(max(last_epoch / max(1,total_epochs),0.0),1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
633
 
634
- fig1 = plt.figure(); plt.title("Loss (see logs)")
635
- fig2 = plt.figure(); plt.title("mAP (see logs)")
636
  yield "\n".join(log_tail), fig1, fig2, None
637
 
638
  ckpt = find_best_checkpoint(out_dir) or find_best_checkpoint("runs")
639
  if not ckpt or not os.path.exists(ckpt):
640
  raise gr.Error("Training finished, but checkpoint file not found. Check logs/output directory.")
641
  yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True)
642
 
643
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
644
- if not model_file: raise gr.Error("No trained model file to upload.")
 
645
  from huggingface_hub import HfApi, HfFolder
646
  hf_status = "Skipped Hugging Face."
647
  if hf_token and hf_repo:
@@ -658,21 +788,27 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
658
  if gh_token and gh_repo:
659
  progress(0.5, desc="Uploading to GitHub...")
660
  try:
661
- if '/' not in gh_repo: raise ValueError("GitHub repo must be 'username/repo'.")
 
662
  username, repo_name = gh_repo.split('/')
663
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
664
  headers = {"Authorization": f"token {gh_token}"}
665
- with open(model_file.name, "rb") as f: content = base64.b64encode(f.read()).decode()
 
666
  get_resp = requests.get(api_url, headers=headers, timeout=30)
667
  sha = get_resp.json().get('sha') if get_resp.ok else None
668
  data = {"message": "Upload trained model from Rolo app", "content": content}
669
- if sha: data["sha"] = sha
 
670
  put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
671
- if put_resp.ok: gh_status = f"Success! {put_resp.json()['content']['html_url']}"
672
- else: gh_status = f"GitHub Error: {put_resp.json().get('message','Unknown')}"
 
 
673
  except Exception as e:
674
  gh_status = f"GitHub Error: {e}"
675
- progress(1); return hf_status, gh_status
 
676
 
677
  # === UI =======================================================================
678
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
@@ -751,4 +887,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
751
 
752
  if __name__ == "__main__":
753
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") # silence stray warnings from other libs
754
  app.launch(debug=True)
 
17
  REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
18
  PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch") # Supervisely keeps PyTorch impl here
19
 
20
+ # Core deps + your requested packages; pinned as lower-bounds to avoid downgrades (local runs only)
21
  COMMON_REQUIREMENTS = [
22
  "gradio>=4.36.1",
23
  "ultralytics>=8.2.0",
 
30
  "torchvision>=0.15.2",
31
  "pyyaml>=6.0.1",
32
  "Pillow>=10.0.0",
33
+ "supervisely>=6.0.0",
34
+ "tensorboard>=2.13.0",
35
+ "pycocotools>=2.0.7",
36
  ]
37
 
38
  # === bootstrap (clone + pip) ===================================================
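Note: the bootstrap below calls a pip_install() helper that falls outside this diff's context. A minimal sketch of what it is assumed to do (the helper name is from the code; its body here is an assumption):

    import subprocess, sys

    def pip_install(args):
        # args is either a list of requirement specs or ["-r", "path/to/requirements.txt"]
        subprocess.check_call([sys.executable, "-m", "pip", "install", *args])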
 
51
  except Exception:
52
  logging.warning("git pull failed; continuing with current checkout")
53
 
54
+ # On HF Spaces: expect requirements.txt to be used at build time; skip heavy runtime installs
55
+ if os.getenv("HF_SPACE") == "1" or os.getenv("SPACE_ID"):
56
+ logging.info("Detected Hugging Face Space — skipping runtime pip installs.")
57
+ return
58
 
59
+ # Local fallback (non-Spaces)
60
+ pip_install(COMMON_REQUIREMENTS)
61
  req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
62
  if os.path.exists(req_file):
63
  pip_install(["-r", req_file])
64
 
 
65
  try:
66
  import supervisely # noqa: F401
67
  except Exception:
 
79
 
80
  # === utilities ================================================================
81
  def handle_remove_readonly(func, path, exc_info):
82
+ try:
83
+ os.chmod(path, stat.S_IWRITE)
84
+ except Exception:
85
+ pass
86
  func(path)
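This is the usual onerror callback for clearing read-only files on Windows before deletion; it is presumably wired up elsewhere in the file along the lines of (illustrative, not shown in this diff):

    shutil.rmtree(REPO_DIR, onerror=handle_remove_readonly)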
87
 
88
  _ROBO_URL_RX = re.compile(r"""
 
108
  version = None
109
  if len(parts) >= 3:
110
  v = parts[2]
111
+ if v.lower().startswith('v') and v[1:].isdigit():
112
+ version = int(v[1:])
113
+ elif v.isdigit():
114
+ version = int(v)
115
  return parts[0], parts[1], version
116
  if '/' in s and 'roboflow' not in s:
117
  p = s.split('/')
 
119
  version = None
120
  if len(p) >= 3:
121
  v = p[2]
122
+ if v.lower().startswith('v') and v[1:].isdigit():
123
+ version = int(v[1:])
124
+ elif v.isdigit():
125
+ version = int(v)
126
  return p[0], p[1], version
127
  return None, None, None
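For the plain 'workspace/project[/vN]' branch shown above (the roboflow.com URL regex branch sits outside this hunk), the parser behaves as follows, with made-up workspace/project names:

    parse_roboflow_url("my-team/traffic-signs/v3")  # -> ("my-team", "traffic-signs", 3)
    parse_roboflow_url("my-team/traffic-signs/3")   # -> ("my-team", "traffic-signs", 3)
    parse_roboflow_url("my-team/traffic-signs")     # -> ("my-team", "traffic-signs", None)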
128
 
 
140
  names = data_yaml.get('names', None)
141
  if isinstance(names, dict):
142
  def _k(x):
143
+ try:
144
+ return int(x)
145
+ except Exception:
146
+ return str(x)
147
  keys = sorted(names.keys(), key=_k)
148
  names_list = [names[k] for k in keys]
149
  elif isinstance(names, list):
 
160
  ver = proj.version(int(version))
161
  dataset = ver.download("yolov8") # labels in YOLO format (we'll convert to COCO)
162
  data_yaml_path = os.path.join(dataset.location, 'data.yaml')
163
+ with open(data_yaml_path, 'r') as f:
164
+ data_yaml = yaml.safe_load(f)
165
  class_names = _extract_class_names(data_yaml)
166
  splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
167
  return dataset.location, class_names, splits, f"{project}-v{version}"
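_extract_class_names (above) accepts both shapes that Roboflow/Ultralytics data.yaml files use for names; illustrative calls, assuming the list branch returns the list unchanged:

    _extract_class_names({"names": ["car", "person"]})        # -> ["car", "person"]
    _extract_class_names({"names": {1: "person", 0: "car"}})  # -> ["car", "person"]  (dict keys sorted numerically)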
 
181
  ann_id = 1
182
  img_id = 1
183
  for fname in sorted(os.listdir(split_dir_images)):
184
+ if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
185
+ continue
186
  img_path = os.path.join(split_dir_images, fname)
187
  try:
188
  with Image.open(img_path) as im:
 
195
  with open(label_file, "r") as f:
196
  for line in f:
197
  parts = line.strip().split()
198
+ if len(parts) < 5:
199
+ continue
200
+ try:
201
+ cls = int(float(parts[0]))
202
+ cx, cy, bw, bh = map(float, parts[1:5])
203
+ except Exception:
204
+ continue
205
+ x = max(0.0, (cx - bw / 2.0) * w)
206
+ y = max(0.0, (cy - bh / 2.0) * h)
207
+ ww = max(1.0, bw * w)
208
+ hh = max(1.0, bh * h)
209
+ # clamp right/bottom to image bounds
210
+ if x + ww > w:
211
+ ww = max(1.0, w - x)
212
+ if y + hh > h:
213
+ hh = max(1.0, h - y)
214
  annotations.append({
215
  "id": ann_id,
216
  "image_id": img_id,
217
  "category_id": cls,
218
+ "bbox": [x, y, ww, hh],
219
+ "area": max(1.0, ww * hh),
220
  "iscrowd": 0,
221
  "segmentation": []
222
  })
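Worked example for the conversion above: on a 640x480 image, the YOLO line "0 0.5 0.5 0.25 0.5" (class, cx, cy, w, h, all normalized) becomes

    x  = (0.5 - 0.25/2) * 640 = 240.0
    y  = (0.5 - 0.5/2)  * 480 = 120.0
    ww = 0.25 * 640 = 160.0
    hh = 0.5  * 480 = 240.0
    # COCO bbox = [240.0, 120.0, 160.0, 240.0], area = 38400.0; the clamps above do not trigger here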
 
224
  img_id += 1
225
  coco = {"images": images, "annotations": annotations, "categories": categories}
226
  os.makedirs(os.path.dirname(out_json), exist_ok=True)
227
+ with open(out_json, "w") as f:
228
+ json.dump(coco, f)
229
 
230
  def make_coco_annotations(merged_dir, class_names):
231
  ann_dir = os.path.join(merged_dir, "annotations")
 
241
 
242
  # === dataset merging ==========================================================
243
  def gather_class_counts(dataset_info, class_mapping):
244
+ if not dataset_info:
245
+ return {}
246
  final_names = set(v for v in class_mapping.values() if v is not None)
247
  counts = {name: 0 for name in final_names}
248
  for loc, names, splits, _ in dataset_info:
249
  id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
250
  for split in splits:
251
  labels_dir = os.path.join(loc, split, 'labels')
252
+ if not os.path.exists(labels_dir):
253
+ continue
254
  for label_file in os.listdir(labels_dir):
255
+ if not label_file.endswith('.txt'):
256
+ continue
257
  found = set()
258
  with open(os.path.join(labels_dir, label_file), 'r') as f:
259
  for line in f:
260
  parts = line.strip().split()
261
+ if not parts:
262
+ continue
263
  try:
264
  cls_id = int(parts[0])
265
  mapped = id_to_name.get(cls_id, None)
266
+ if mapped:
267
+ found.add(mapped)
268
  except Exception:
269
  continue
270
+ for m in found:
271
+ counts[m] += 1
272
  return counts
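Here class_mapping maps each original class name to its final name (None means the class is dropped), and the returned counts tally label files per final class, counting each file at most once. Illustrative, with made-up names and numbers:

    class_mapping = {"car": "vehicle", "truck": "vehicle", "crosswalk": None}
    gather_class_counts(dataset_info, class_mapping)
    # -> {"vehicle": 1523}  i.e. 1523 images contain at least one box that maps to "vehicle"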
273
 
274
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
 
288
  for loc, _, splits, _ in dataset_info:
289
  for split in splits:
290
  img_dir = os.path.join(loc, split, 'images')
291
+ if not os.path.exists(img_dir):
292
+ continue
293
  for img_file in os.listdir(img_dir):
294
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
295
  all_images.append((os.path.join(img_dir, img_file), split, loc))
 
301
 
302
  for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
303
  lbl_path = label_path_for(img_path)
304
+ if not os.path.exists(lbl_path):
305
+ continue
306
  source_names = loc_to_names.get(source_loc, [])
307
  image_classes = set()
308
  with open(lbl_path, 'r') as f:
309
  for line in f:
310
  parts = line.strip().split()
311
+ if not parts:
312
+ continue
313
  try:
314
  cls_id = int(parts[0])
315
  orig = source_names[cls_id]
316
  mapped = class_mapping.get(orig, orig)
317
+ if mapped in active_classes:
318
+ image_classes.add(mapped)
319
  except Exception:
320
  continue
321
+ if not image_classes:
322
+ continue
323
+ if any(current_counts[c] >= class_limits[c] for c in image_classes):
324
+ continue
325
  selected_images.append((img_path, split))
326
+ for c in image_classes:
327
+ current_counts[c] += 1
328
 
329
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
330
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
 
335
 
336
  source_loc = None
337
  for info in dataset_info:
338
+ if img_path.startswith(info[0]):
339
+ source_loc = info[0]
340
+ break
341
  source_names = loc_to_names.get(source_loc, [])
342
 
343
  with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
344
  for line in f_in:
345
  parts = line.strip().split()
346
+ if not parts:
347
+ continue
348
  try:
349
  old_id = int(parts[0])
350
  original_name = source_names[old_id]
 
372
 
373
  # === entrypoint + config detection/generation =================================
374
  def find_training_script(repo_root):
375
+ # Hard-prefer the canonical path widely used in the repo/issues
376
+ canonical = os.path.join(repo_root, "rtdetrv2_pytorch", "tools", "train.py")
377
+ if os.path.exists(canonical):
378
+ return canonical
379
+
380
  candidates = []
381
+ for pat in ["**/tools/train.py", "**/train.py", "**/tools/train_net.py"]:
382
  candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
383
+ # Prefer anything inside rtdetrv2_pytorch, then shorter paths
384
+ def _score(p):
385
+ pl = p.replace("\\", "/").lower()
386
+ return (0 if "rtdetrv2_pytorch" in pl else 1, len(p))
387
+ candidates.sort(key=_score)
388
  return candidates[0] if candidates else None
389
 
390
  def find_model_config_template(model_key):
 
400
  def score(p):
401
  pl = p.lower()
402
  s = 0
403
+ if "/rtdetrv2_pytorch/" in pl:
404
+ s += 4
405
+ if "/config" in pl:
406
+ s += 3
407
  for token in want_tokens:
408
+ if token in os.path.basename(pl):
409
+ s += 3
410
+ if token in pl:
411
+ s += 2
412
+ if "coco" in pl:
413
+ s += 1
414
  return -s, len(p)
415
 
416
  yamls.sort(key=score)
417
  return yamls[0] if yamls else None
418
 
419
+ def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
420
+ """
421
+ If any key from `keys` exists in dict `d`, set the first one found to `value`.
422
+ Otherwise, if `fallback_key` is given, create it with `value`.
423
+ Returns the key that was set, or None.
424
+ """
425
+ for k in keys:
426
+ if k in d:
427
+ d[k] = value
428
+ return k
429
+ if fallback_key:
430
+ d[fallback_key] = value
431
+ return fallback_key
432
+ return None
433
+
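Illustrative call, matching the docstring above:

    d = {"ann_path": "old.json"}
    _set_first_existing_key(d, keys=["ann_file", "ann_path"], value="new.json", fallback_key="ann_file")
    # -> returns "ann_path"; d is now {"ann_path": "new.json"} (fallback_key is only created when no listed key exists)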
434
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
435
  epochs, batch, imgsz, lr, optimizer):
436
  if not base_cfg_path or not os.path.exists(base_cfg_path):
 
450
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
451
  }
452
 
453
+ # dataset block: set an existing alias if present, otherwise add a common key
454
  for root_key in ["dataset", "data"]:
455
  if root_key in cfg and isinstance(cfg[root_key], dict):
456
  ds = cfg[root_key]
 
460
  ("test", "test_json", "test_img"),
461
  ]:
462
  if split in ds and isinstance(ds[split], dict):
463
+ node = ds[split]
464
+ node["name"] = node.get("name", "coco")
465
+ _set_first_existing_key(
466
+ node,
467
+ keys=["ann_file", "ann_path", "annotation", "annotations"],
468
+ value=paths[jf],
469
+ fallback_key="ann_file",
470
+ )
471
+ _set_first_existing_key(
472
+ node,
473
+ keys=["img_prefix", "img_dir", "image_root", "data_root"],
474
+ value=paths[ip],
475
+ fallback_key="img_prefix",
476
+ )
477
 
478
  # num_classes
479
  def set_num_classes(node, n):
480
+ if not isinstance(node, dict):
481
+ return False
482
  if "num_classes" in node:
483
+ node["num_classes"] = int(n)
484
+ return True
485
  for k, v in node.items():
486
+ if isinstance(v, dict) and set_num_classes(v, n):
487
+ return True
488
  return False
489
 
490
  if "model" in cfg and isinstance(cfg["model"], dict):
 
497
  updated_epoch = False
498
  for key in ["max_epoch", "epochs", "num_epochs"]:
499
  if key in cfg:
500
+ cfg[key] = int(epochs)
501
+ updated_epoch = True
502
+ break
503
  if "solver" in cfg and isinstance(cfg["solver"], dict):
504
  for key in ["max_epoch", "epochs", "num_epochs"]:
505
  if key in cfg["solver"]:
506
+ cfg["solver"][key] = int(epochs)
507
+ updated_epoch = True
508
+ break
509
  if not updated_epoch:
510
  cfg["max_epoch"] = int(epochs)
511
 
512
  for key in ["input_size", "img_size", "imgsz"]:
513
+ if key in cfg:
514
+ cfg[key] = int(imgsz)
515
+ if "input_size" not in cfg:
516
+ cfg["input_size"] = int(imgsz)
517
 
518
  # lr / optimizer / batch
519
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
 
521
  sol = cfg["solver"]
522
  for key in ["base_lr", "lr", "learning_rate"]:
523
  if key in sol:
524
+ sol[key] = float(lr)
525
+ break
526
  else:
527
  sol["base_lr"] = float(lr)
528
 
 
540
  else:
541
  cfg["output_dir"] = paths["out_dir"]
542
 
543
+ cfg_out_dir = os.path.join("generated_configs")
544
+ os.makedirs(cfg_out_dir, exist_ok=True)
545
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
546
+ with open(out_path, "w") as f:
547
+ yaml.safe_dump(cfg, f, sort_keys=False)
548
  return out_path
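After patching, a dataset split node ends up looking roughly like this (illustrative paths; the exact keys depend on which aliases the base config already defines):

    cfg["dataset"]["train"] = {
        "name": "coco",
        "ann_file": "<merged_dir>/annotations/train.json",  # or ann_path/annotation/annotations, whichever exists
        "img_prefix": "<merged_dir>/train/images",          # or img_dir/image_root/data_root, whichever exists
        # any other keys from the base config are left untouched
    }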
549
 
550
  def find_best_checkpoint(out_dir):
 
556
  ]
557
  for p in pats:
558
  f = sorted(glob(p, recursive=True))
559
+ if f:
560
+ return f[0]
561
+ any_ckpt = sorted(
562
+ glob(os.path.join(out_dir, "**", "*.pt"), recursive=True)
563
+ + glob(os.path.join(out_dir, "**", "*.pth"), recursive=True)
564
+ )
565
  return any_ckpt[-1] if any_ckpt else None
566
 
567
  # === Gradio handlers ==========================================================
568
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
569
  api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
570
+ if not api_key:
571
+ raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
572
+ if not url_file:
573
+ raise gr.Error("Upload a .txt with Roboflow URLs or 'workspace/project[/vN]' lines.")
574
 
575
  with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
576
  urls = [line.strip() for line in f if line.strip()]
 
588
  failures.append((raw, f"No latest version for {ws}/{proj}"))
589
  continue
590
  loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
591
+ if loc:
592
+ dataset_info.append((loc, names, splits, name_str))
593
+ else:
594
+ failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
595
 
596
  if not dataset_info:
597
  msg = "No datasets loaded.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
 
603
  df = pd.DataFrame([[n, n, counts.get(n, 0), False] for n in all_names],
604
  columns=["Original Name", "Rename To", "Max Images", "Remove"])
605
  status = "Datasets loaded successfully."
606
+ if failures:
607
+ status += f" ({len(dataset_info)} OK, {len(failures)} failed; see logs)."
608
  return status, dataset_info, df
609
 
610
  def update_class_counts_handler(class_df, dataset_info):
611
+ if class_df is None or not dataset_info:
612
+ return None
613
  class_df = pd.DataFrame(class_df)
614
  mapping = {row["Original Name"]: (None if bool(row["Remove"]) else row["Rename To"])
615
  for _, row in class_df.iterrows()}
 
619
  id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
620
  for split in splits:
621
  labels_dir = os.path.join(loc, split, 'labels')
622
+ if not os.path.exists(labels_dir):
623
+ continue
624
  for label_file in os.listdir(labels_dir):
625
+ if not label_file.endswith('.txt'):
626
+ continue
627
  found = set()
628
  with open(os.path.join(labels_dir, label_file), 'r') as f:
629
  for line in f:
630
  parts = line.strip().split()
631
+ if not parts:
632
+ continue
633
  try:
634
  cls_id = int(parts[0])
635
  mapped = id_to_final.get(cls_id, None)
636
+ if mapped:
637
+ found.add(mapped)
638
  except Exception:
639
  continue
640
+ for m in found:
641
+ counts[m] += 1
642
  return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
643
 
644
  def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
645
+ if not dataset_path:
646
+ raise gr.Error("Finalize a dataset in Tab 2 before training.")
647
 
648
  train_script = find_training_script(REPO_DIR)
649
+ logging.info(f"Resolved training script: {train_script}")
650
  if not train_script:
651
  raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
652
 
 
655
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
656
 
657
  data_yaml = os.path.join(dataset_path, "data.yaml")
658
+ with open(data_yaml, "r") as f:
659
+ dy = yaml.safe_load(f)
660
  class_names = [str(x) for x in dy.get("names", [])]
661
  make_coco_annotations(dataset_path, class_names)
662
 
 
689
  proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
690
  stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
691
  bufsize=1, text=True, env=env)
692
+ for line in proc.stdout:
693
+ q.put(line.rstrip())
694
  proc.wait()
695
  q.put(f"__EXITCODE__:{proc.returncode}")
696
  except Exception as e:
 
700
 
701
  log_tail, last_epoch, total_epochs = [], 0, int(epochs)
702
  first_lines = []
703
+ line_no = 0
704
  while True:
705
  line = q.get()
706
  if line.startswith("__EXITCODE__"):
707
+ code = int(line.split(":", 1)[1])
708
  if code != 0:
709
  head = "\n".join(first_lines[:60])
710
  raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}")
711
  break
712
  if line.startswith("__ERROR__"):
713
+ raise gr.Error(f"Training failed: {line.split(':', 1)[1]}")
714
 
715
+ if len(first_lines) < 120:
716
+ first_lines.append(line)
717
+ log_tail.append(line)
718
+ log_tail = log_tail[-40:]
719
 
720
  m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
721
  if m:
722
  try:
723
+ last_epoch = int(m.group(1))
724
+ total_epochs = max(total_epochs, int(m.group(2)))
725
+ except Exception:
726
+ pass
727
+ progress(min(max(last_epoch / max(1, total_epochs), 0.0), 1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
728
+
729
+ # Throttle plotting; close figs after yield to avoid leaks
730
+ line_no += 1
731
+ fig1 = fig2 = None
732
+ if line_no % 80 == 0:
733
+ fig1 = plt.figure()
734
+ plt.title("Loss (see logs)")
735
+ plt.plot([0, last_epoch], [0, 0])
736
+ plt.tight_layout()
737
+
738
+ fig2 = plt.figure()
739
+ plt.title("mAP (see logs)")
740
+ plt.plot([0, last_epoch], [0, 0])
741
+ plt.tight_layout()
742
 
 
 
743
  yield "\n".join(log_tail), fig1, fig2, None
744
 
745
+ if fig1 is not None:
746
+ plt.close(fig1)
747
+ if fig2 is not None:
748
+ plt.close(fig2)
749
+
750
  ckpt = find_best_checkpoint(out_dir) or find_best_checkpoint("runs")
751
  if not ckpt or not os.path.exists(ckpt):
752
  raise gr.Error("Training finished, but checkpoint file not found. Check logs/output directory.")
753
  yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True)
754
 
755
+ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
756
+ if not dataset_info:
757
+ raise gr.Error("Load datasets first in Tab 1.")
758
+ if class_df is None:
759
+ raise gr.Error("Class data is missing.")
760
+ class_df = pd.DataFrame(class_df)
761
+ class_mapping, class_limits = {}, {}
762
+ for _, row in class_df.iterrows():
763
+ orig = row["Original Name"]
764
+ if bool(row["Remove"]):
765
+ continue
766
+ final_name = row["Rename To"]
767
+ class_mapping[orig] = final_name
768
+ class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
769
+ status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
770
+ return status, path
771
+
772
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
773
+ if not model_file:
774
+ raise gr.Error("No trained model file to upload.")
775
  from huggingface_hub import HfApi, HfFolder
776
  hf_status = "Skipped Hugging Face."
777
  if hf_token and hf_repo:
 
788
  if gh_token and gh_repo:
789
  progress(0.5, desc="Uploading to GitHub...")
790
  try:
791
+ if '/' not in gh_repo:
792
+ raise ValueError("GitHub repo must be 'username/repo'.")
793
  username, repo_name = gh_repo.split('/')
794
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
795
  headers = {"Authorization": f"token {gh_token}"}
796
+ with open(model_file.name, "rb") as f:
797
+ content = base64.b64encode(f.read()).decode()
798
  get_resp = requests.get(api_url, headers=headers, timeout=30)
799
  sha = get_resp.json().get('sha') if get_resp.ok else None
800
  data = {"message": "Upload trained model from Rolo app", "content": content}
801
+ if sha:
802
+ data["sha"] = sha
803
  put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
804
+ if put_resp.ok:
805
+ gh_status = f"Success! {put_resp.json()['content']['html_url']}"
806
+ else:
807
+ gh_status = f"GitHub Error: {put_resp.json().get('message','Unknown')}"
808
  except Exception as e:
809
  gh_status = f"GitHub Error: {e}"
810
+ progress(1)
811
+ return hf_status, gh_status
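The GitHub branch above uses the REST contents API (PUT /repos/{owner}/{repo}/contents/{path} with base64-encoded content, plus the existing file's sha when updating). The Hugging Face branch is elided from this hunk; a minimal sketch of what it presumably does with huggingface_hub (repo handling here is an assumption):

    from huggingface_hub import HfApi
    api = HfApi(token=hf_token)
    api.create_repo(repo_id=hf_repo, exist_ok=True)  # assumed; no-op if the repo already exists
    api.upload_file(path_or_fileobj=model_file.name,
                    path_in_repo=os.path.basename(model_file.name),
                    repo_id=hf_repo)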
812
 
813
  # === UI =======================================================================
814
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
 
887
 
888
  if __name__ == "__main__":
889
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") # silence stray warnings from other libs
890
+ # Log training script resolution at startup for quick troubleshooting
891
+ try:
892
+ ts = find_training_script(REPO_DIR)
893
+ logging.info(f"Startup check — training script at: {ts}")
894
+ except Exception as e:
895
+ logging.warning(f"Startup training-script check failed: {e}")
896
  app.launch(debug=True)