wuhp committed on
Commit
a5f6137
·
verified ·
1 Parent(s): 2716e64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -33
app.py CHANGED
@@ -145,7 +145,7 @@ def parse_roboflow_url(s: str):
145
  version = None
146
  if len(p) >= 3:
147
  v = p[2]
148
- if v.lower().startswith('v') and v[1:].isdigit(): # <-- FIXED
149
  version = int(v[1:])
150
  elif v.isdigit():
151
  version = int(v)
@@ -459,7 +459,7 @@ def _install_supervisely_logger_shim():
459
  """))
460
  return str(root)
461
 
462
- # ---- [UPDATED] robust sitecustomize shim with lazy import hook --------------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
  sitecustomize shim that:
@@ -566,23 +566,6 @@ if not _try_patch_now():
566
  f.write(code)
567
  return sc_path
568
 
569
- # ---- Deprecated: on-disk workspace patch (no-op now) -------------------------
570
- def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
571
- """
572
- Deprecated: we no longer edit third-party files on disk.
573
- The shim in sitecustomize.py handles cfg/_pymodule safely.
574
- """
575
- return None
576
-
577
- def _unpatch_workspace_create(repo_root: str):
578
- ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
579
- bak_path = ws_path + ".bak"
580
- if os.path.exists(bak_path):
581
- try:
582
- shutil.copy2(bak_path, ws_path)
583
- except Exception:
584
- pass
585
-
586
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
587
  url = CKPT_URLS.get(model_key)
588
  if not url:
@@ -731,7 +714,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
731
  "shuffle": bool(default_shuffle),
732
  "num_workers": 2,
733
  "drop_last": bool(dl_key == "train_dataloader"),
734
- "collate_fn": {"type": "BatchImageCollateFunction"}, # correct spelling
735
  "total_batch_size": int(batch),
736
  }
737
  cfg[dl_key] = block
@@ -751,7 +734,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
751
  block.setdefault("shuffle", bool(default_shuffle))
752
  block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
753
 
754
- # ---- FORCE-FIX collate name even if it existed already
755
  cf = block.get("collate_fn", {})
756
  if isinstance(cf, dict):
757
  t = str(cf.get("type", ""))
@@ -805,7 +788,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
805
  _maybe_set_model_field(cfg, "pretrain", p)
806
  _maybe_set_model_field(cfg, "pretrained", p)
807
 
808
- # Defensive: if after keeping includes we still don't have a model block, add a stub
809
  if not cfg.get("model"):
810
  cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
811
 
@@ -961,21 +943,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
961
  try:
962
  train_cwd = os.path.dirname(train_script)
963
 
964
- # --- NEW: create a temp dir for sitecustomize and put it FIRST on PYTHONPATH
965
  shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
966
  _install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
967
 
968
  env = os.environ.copy()
969
 
970
- # Supervisely logger shim (can be later in path)
971
  sly_shim_root = _install_supervisely_logger_shim()
972
 
973
- # Build PYTHONPATH — order matters!
974
- # 1) shim_dir (so sitecustomize auto-imports)
975
- # 2) train_cwd (belt & suspenders; makes local imports easy)
976
- # 3) PY_IMPL_DIR + REPO_DIR (RT-DETRv2 code)
977
- # 4) sly_shim_root (optional)
978
- # 5) existing PYTHONPATH
979
  env["PYTHONPATH"] = os.pathsep.join(filter(None, [
980
  shim_dir,
981
  train_cwd,
@@ -987,8 +961,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
987
 
988
  env.setdefault("WANDB_DISABLED", "true")
989
  env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
990
- env.setdefault("PYTHONUNBUFFERED", "1") # nicer real-time logs
991
- # Optional tiny guard: pick a single visible GPU if available
992
  if torch.cuda.is_available():
993
  env.setdefault("CUDA_VISIBLE_DEVICES", "0")
994
 
@@ -1147,7 +1120,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
1147
  gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
1148
  with gr.Row():
1149
  with gr.Column(scale=1):
1150
- model_dd = gr.Dropdown(choices=[(label, k) for k, label in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
 
 
1151
  label="Model (RT-DETRv2)")
1152
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
1153
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
 
145
  version = None
146
  if len(p) >= 3:
147
  v = p[2]
148
+ if v.lower().startswith('v') and v[1:].isdigit():
149
  version = int(v[1:])
150
  elif v.isdigit():
151
  version = int(v)
 
459
  """))
460
  return str(root)
461
 
462
+ # ---- [!! CORRECTED !!] robust sitecustomize shim with lazy import hook --------------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
  sitecustomize shim that:
 
566
  f.write(code)
567
  return sc_path
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
570
  url = CKPT_URLS.get(model_key)
571
  if not url:
 
714
  "shuffle": bool(default_shuffle),
715
  "num_workers": 2,
716
  "drop_last": bool(dl_key == "train_dataloader"),
717
+ "collate_fn": {"type": "BatchImageCollateFunction"},
718
  "total_batch_size": int(batch),
719
  }
720
  cfg[dl_key] = block
 
734
  block.setdefault("shuffle", bool(default_shuffle))
735
  block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
736
 
737
+ # ---- FORCE-FIX collate name typo even if it existed already
738
  cf = block.get("collate_fn", {})
739
  if isinstance(cf, dict):
740
  t = str(cf.get("type", ""))
 
788
  _maybe_set_model_field(cfg, "pretrain", p)
789
  _maybe_set_model_field(cfg, "pretrained", p)
790
 
 
791
  if not cfg.get("model"):
792
  cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
793
 
 
943
  try:
944
  train_cwd = os.path.dirname(train_script)
945
 
 
946
  shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
947
  _install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
948
 
949
  env = os.environ.copy()
950
 
 
951
  sly_shim_root = _install_supervisely_logger_shim()
952
 
 
 
 
 
 
 
953
  env["PYTHONPATH"] = os.pathsep.join(filter(None, [
954
  shim_dir,
955
  train_cwd,
 
961
 
962
  env.setdefault("WANDB_DISABLED", "true")
963
  env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
964
+ env.setdefault("PYTHONUNBUFFERED", "1")
 
965
  if torch.cuda.is_available():
966
  env.setdefault("CUDA_VISIBLE_DEVICES", "0")
967
 
 
1120
  gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
1121
  with gr.Row():
1122
  with gr.Column(scale=1):
1123
+ # [UI IMPROVEMENT] Using (label, value) format for a better user experience
1124
+ model_dd = gr.Dropdown(choices=[(label, value) for value, label in MODEL_CHOICES],
1125
+ value=DEFAULT_MODEL_KEY,
1126
  label="Model (RT-DETRv2)")
1127
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
1128
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")