Update app.py
Browse files
app.py
CHANGED
@@ -145,7 +145,7 @@ def parse_roboflow_url(s: str):
|
|
145 |
version = None
|
146 |
if len(p) >= 3:
|
147 |
v = p[2]
|
148 |
-
if v.lower().startswith('v') and v[1:].isdigit():
|
149 |
version = int(v[1:])
|
150 |
elif v.isdigit():
|
151 |
version = int(v)
|
@@ -459,7 +459,7 @@ def _install_supervisely_logger_shim():
|
|
459 |
"""))
|
460 |
return str(root)
|
461 |
|
462 |
-
# ---- [
|
463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
464 |
"""
|
465 |
sitecustomize shim that:
|
@@ -566,23 +566,6 @@ if not _try_patch_now():
|
|
566 |
f.write(code)
|
567 |
return sc_path
|
568 |
|
569 |
-
# ---- Deprecated: on-disk workspace patch (no-op now) -------------------------
|
570 |
-
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
571 |
-
"""
|
572 |
-
Deprecated: we no longer edit third-party files on disk.
|
573 |
-
The shim in sitecustomize.py handles cfg/_pymodule safely.
|
574 |
-
"""
|
575 |
-
return None
|
576 |
-
|
577 |
-
def _unpatch_workspace_create(repo_root: str):
|
578 |
-
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
579 |
-
bak_path = ws_path + ".bak"
|
580 |
-
if os.path.exists(bak_path):
|
581 |
-
try:
|
582 |
-
shutil.copy2(bak_path, ws_path)
|
583 |
-
except Exception:
|
584 |
-
pass
|
585 |
-
|
586 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
587 |
url = CKPT_URLS.get(model_key)
|
588 |
if not url:
|
@@ -731,7 +714,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
731 |
"shuffle": bool(default_shuffle),
|
732 |
"num_workers": 2,
|
733 |
"drop_last": bool(dl_key == "train_dataloader"),
|
734 |
-
"collate_fn": {"type": "BatchImageCollateFunction"},
|
735 |
"total_batch_size": int(batch),
|
736 |
}
|
737 |
cfg[dl_key] = block
|
@@ -751,7 +734,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
751 |
block.setdefault("shuffle", bool(default_shuffle))
|
752 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
753 |
|
754 |
-
# ---- FORCE-FIX collate name even if it existed already
|
755 |
cf = block.get("collate_fn", {})
|
756 |
if isinstance(cf, dict):
|
757 |
t = str(cf.get("type", ""))
|
@@ -805,7 +788,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
805 |
_maybe_set_model_field(cfg, "pretrain", p)
|
806 |
_maybe_set_model_field(cfg, "pretrained", p)
|
807 |
|
808 |
-
# Defensive: if after keeping includes we still don't have a model block, add a stub
|
809 |
if not cfg.get("model"):
|
810 |
cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
|
811 |
|
@@ -961,21 +943,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
961 |
try:
|
962 |
train_cwd = os.path.dirname(train_script)
|
963 |
|
964 |
-
# --- NEW: create a temp dir for sitecustomize and put it FIRST on PYTHONPATH
|
965 |
shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
|
966 |
_install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
|
967 |
|
968 |
env = os.environ.copy()
|
969 |
|
970 |
-
# Supervisely logger shim (can be later in path)
|
971 |
sly_shim_root = _install_supervisely_logger_shim()
|
972 |
|
973 |
-
# Build PYTHONPATH — order matters!
|
974 |
-
# 1) shim_dir (so sitecustomize auto-imports)
|
975 |
-
# 2) train_cwd (belt & suspenders; makes local imports easy)
|
976 |
-
# 3) PY_IMPL_DIR + REPO_DIR (RT-DETRv2 code)
|
977 |
-
# 4) sly_shim_root (optional)
|
978 |
-
# 5) existing PYTHONPATH
|
979 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
980 |
shim_dir,
|
981 |
train_cwd,
|
@@ -987,8 +961,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
987 |
|
988 |
env.setdefault("WANDB_DISABLED", "true")
|
989 |
env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
|
990 |
-
env.setdefault("PYTHONUNBUFFERED", "1")
|
991 |
-
# Optional tiny guard: pick a single visible GPU if available
|
992 |
if torch.cuda.is_available():
|
993 |
env.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
994 |
|
@@ -1147,7 +1120,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
|
|
1147 |
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
|
1148 |
with gr.Row():
|
1149 |
with gr.Column(scale=1):
|
1150 |
-
|
|
|
|
|
1151 |
label="Model (RT-DETRv2)")
|
1152 |
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
|
1153 |
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
|
|
|
145 |
version = None
|
146 |
if len(p) >= 3:
|
147 |
v = p[2]
|
148 |
+
if v.lower().startswith('v') and v[1:].isdigit():
|
149 |
version = int(v[1:])
|
150 |
elif v.isdigit():
|
151 |
version = int(v)
|
|
|
459 |
"""))
|
460 |
return str(root)
|
461 |
|
462 |
+
# ---- [!! CORRECTED !!] robust sitecustomize shim with lazy import hook --------------------
|
463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
464 |
"""
|
465 |
sitecustomize shim that:
|
|
|
566 |
f.write(code)
|
567 |
return sc_path
|
568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
570 |
url = CKPT_URLS.get(model_key)
|
571 |
if not url:
|
|
|
714 |
"shuffle": bool(default_shuffle),
|
715 |
"num_workers": 2,
|
716 |
"drop_last": bool(dl_key == "train_dataloader"),
|
717 |
+
"collate_fn": {"type": "BatchImageCollateFunction"},
|
718 |
"total_batch_size": int(batch),
|
719 |
}
|
720 |
cfg[dl_key] = block
|
|
|
734 |
block.setdefault("shuffle", bool(default_shuffle))
|
735 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
736 |
|
737 |
+
# ---- FORCE-FIX collate name typo even if it existed already
|
738 |
cf = block.get("collate_fn", {})
|
739 |
if isinstance(cf, dict):
|
740 |
t = str(cf.get("type", ""))
|
|
|
788 |
_maybe_set_model_field(cfg, "pretrain", p)
|
789 |
_maybe_set_model_field(cfg, "pretrained", p)
|
790 |
|
|
|
791 |
if not cfg.get("model"):
|
792 |
cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
|
793 |
|
|
|
943 |
try:
|
944 |
train_cwd = os.path.dirname(train_script)
|
945 |
|
|
|
946 |
shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
|
947 |
_install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
|
948 |
|
949 |
env = os.environ.copy()
|
950 |
|
|
|
951 |
sly_shim_root = _install_supervisely_logger_shim()
|
952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
953 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
954 |
shim_dir,
|
955 |
train_cwd,
|
|
|
961 |
|
962 |
env.setdefault("WANDB_DISABLED", "true")
|
963 |
env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
|
964 |
+
env.setdefault("PYTHONUNBUFFERED", "1")
|
|
|
965 |
if torch.cuda.is_available():
|
966 |
env.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
967 |
|
|
|
1120 |
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
|
1121 |
with gr.Row():
|
1122 |
with gr.Column(scale=1):
|
1123 |
+
# [UI IMPROVEMENT] Using (label, value) format for a better user experience
|
1124 |
+
model_dd = gr.Dropdown(choices=[(label, value) for value, label in MODEL_CHOICES],
|
1125 |
+
value=DEFAULT_MODEL_KEY,
|
1126 |
label="Model (RT-DETRv2)")
|
1127 |
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
|
1128 |
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
|