|
|
|
import os, sys, subprocess, shutil, stat, re, random, logging, json, base64, time, pathlib, tempfile, textwrap
import yaml, requests, gradio as gr
|
from urllib.parse import urlparse |
|
from glob import glob |
|
from threading import Thread |
|
from queue import Queue |
|
|
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from roboflow import Roboflow |
|
from PIL import Image |
|
import torch |
|
from string import Template |
|
|
|
|
|
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") |
|
os.environ.setdefault("WANDB_DISABLED", "true") |
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") |
|
|
|
REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2" |
|
REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2") |
|
PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch") |
|
|
|
|
|
COMMON_REQUIREMENTS = [ |
|
"gradio>=4.36.1", |
|
"roboflow>=1.1.28", |
|
"requests>=2.31.0", |
|
"huggingface_hub>=0.22.0", |
|
"pandas>=2.0.0", |
|
"matplotlib>=3.7.0", |
|
"torch>=2.0.1", |
|
"torchvision>=0.15.2", |
|
"pyyaml>=6.0.1", |
|
"Pillow>=10.0.0", |
|
"supervisely>=6.0.0", |
|
"tensorboard>=2.13.0", |
|
"pycocotools>=2.0.7", |
|
] |
|
|
|
|
|
def pip_install(args): |
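    """Run `pip install <args>` with the current Python interpreter."""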
|
logging.info(f"pip install {' '.join(args)}") |
|
subprocess.check_call([sys.executable, "-m", "pip", "install"] + args) |
|
|
|
def ensure_repo_and_requirements(): |
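    """Clone (or fast-forward) the RT-DETRv2 repo, then install Python requirements.

    Runtime pip installs are skipped when running inside a Hugging Face Space.
    """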
|
os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True) |
|
if not os.path.exists(REPO_DIR): |
|
logging.info(f"Cloning RT-DETRv2 repo to {REPO_DIR} ...") |
|
subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR]) |
|
else: |
|
try: |
|
subprocess.check_call(["git", "-C", REPO_DIR, "pull", "--ff-only"]) |
|
except Exception: |
|
logging.warning("git pull failed; continuing with current checkout") |
|
|
|
|
|
if os.getenv("HF_SPACE") == "1" or os.getenv("SPACE_ID"): |
|
logging.info("Detected Hugging Face Space — skipping runtime pip installs.") |
|
return |
|
|
|
|
|
pip_install(COMMON_REQUIREMENTS) |
|
req_file = os.path.join(PY_IMPL_DIR, "requirements.txt") |
|
if os.path.exists(req_file): |
|
pip_install(["-r", req_file]) |
|
|
|
try: |
|
import supervisely |
|
except Exception: |
|
logging.warning("supervisely not importable after first pass; retrying install…") |
|
pip_install(["supervisely>=6.0.0"]) |
|
|
|
try: |
|
ensure_repo_and_requirements() |
|
except Exception: |
|
logging.exception("Bootstrap failed, UI will still load so you can see errors") |
|
|
|
|
|
MODEL_CHOICES = [ |
|
("rtdetrv2_s", "S (r18vd, 120e) — default"), |
|
("rtdetrv2_m", "M (r34vd, 120e)"), |
|
("rtdetrv2_msp", "M* (r50vd_m, 7x)"), |
|
("rtdetrv2_l", "L (r50vd, 6x)"), |
|
("rtdetrv2_x", "X (r101vd, 6x)"), |
|
] |
|
DEFAULT_MODEL_KEY = "rtdetrv2_s" |
|
|
|
CONFIG_PATHS = { |
|
"rtdetrv2_s": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml", |
|
"rtdetrv2_m": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r34vd_120e_coco.yml", |
|
"rtdetrv2_msp": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml", |
|
"rtdetrv2_l": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml", |
|
"rtdetrv2_x": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_coco.yml", |
|
} |
|
|
|
CKPT_URLS = { |
|
"rtdetrv2_s": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth", |
|
"rtdetrv2_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth", |
|
"rtdetrv2_msp": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_m_7x_coco_ema.pth", |
|
"rtdetrv2_l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth", |
|
"rtdetrv2_x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth", |
|
} |
|
|
|
|
|
def handle_remove_readonly(func, path, exc_info): |
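    """shutil.rmtree error callback: clear the read-only flag and retry the failed operation."""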
|
try: |
|
os.chmod(path, stat.S_IWRITE) |
|
except Exception: |
|
pass |
|
func(path) |
|
|
|
_ROBO_URL_RX = re.compile(r""" |
|
^(?: |
|
(?:https?://)?(?:universe|app|www)?\.?roboflow\.com/ |
|
(?P<ws>[A-Za-z0-9\-_]+)/(?P<proj>[A-Za-z0-9\-_]+)/?(?:(?:dataset/[^/]+/)?(?:v?(?P<ver>\d+))?)? |
|
| |
|
(?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))? |
|
)$ |
|
""", re.VERBOSE | re.IGNORECASE) |
|
|
|
def parse_roboflow_url(s: str): |
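    """Parse a Roboflow reference into (workspace, project, version).

    Accepts roboflow.com URLs as well as 'workspace/project[/vN]' strings; the version is
    None when absent, and (None, None, None) is returned if nothing can be resolved.
    """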
|
s = s.strip() |
|
m = _ROBO_URL_RX.match(s) |
|
if m: |
|
ws = m.group('ws') or m.group('ws2') |
|
proj = m.group('proj') or m.group('proj2') |
|
ver = m.group('ver') or m.group('ver2') |
|
return ws, proj, (int(ver) if ver else None) |
|
parsed = urlparse(s) |
|
parts = [p for p in parsed.path.strip('/').split('/') if p] |
|
if len(parts) >= 2: |
|
version = None |
|
if len(parts) >= 3: |
|
v = parts[2] |
|
if v.lower().startswith('v') and v[1:].isdigit(): |
|
version = int(v[1:]) |
|
elif v.isdigit(): |
|
version = int(v) |
|
return parts[0], parts[1], version |
|
if '/' in s and 'roboflow' not in s: |
|
p = s.split('/') |
|
if len(p) >= 2: |
|
version = None |
|
if len(p) >= 3: |
|
v = p[2] |
|
if v.lower().startswith('v') and v[1:].isdigit(): |
|
version = int(v[1:]) |
|
elif v.isdigit(): |
|
version = int(v) |
|
return p[0], p[1], version |
|
return None, None, None |
|
|
|
def get_latest_version(api_key, workspace, project): |
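    """Return the highest version number of a Roboflow project, or None on failure."""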
|
try: |
|
rf = Roboflow(api_key=api_key) |
|
proj = rf.workspace(workspace).project(project) |
|
        # v.version may be a bare number or a full id like "workspace/project/3" depending on the SDK version.
        versions = sorted([int(str(v.version).split("/")[-1]) for v in proj.versions()], reverse=True)
|
return versions[0] if versions else None |
|
except Exception as e: |
|
logging.error(f"Could not get latest version for {workspace}/{project}: {e}") |
|
return None |
|
|
|
def _extract_class_names(data_yaml): |
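    """Return class names from a data.yaml dict, handling 'names' as list or dict (falls back to class_<i>)."""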
|
names = data_yaml.get('names', None) |
|
if isinstance(names, dict): |
|
def _k(x): |
|
try: |
|
return int(x) |
|
except Exception: |
|
return str(x) |
|
keys = sorted(names.keys(), key=_k) |
|
names_list = [names[k] for k in keys] |
|
elif isinstance(names, list): |
|
names_list = names |
|
else: |
|
nc = int(data_yaml.get('nc', 0) or 0) |
|
names_list = [f"class_{i}" for i in range(nc)] |
|
return [str(x) for x in names_list] |
|
|
|
def download_dataset(api_key, workspace, project, version): |
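    """Download a Roboflow dataset version in YOLOv8 format.

    Returns (location, class_names, available_splits, display_name), or (None, [], [], None) on error.
    """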
|
try: |
|
rf = Roboflow(api_key=api_key) |
|
proj = rf.workspace(workspace).project(project) |
|
ver = proj.version(int(version)) |
|
dataset = ver.download("yolov8") |
|
data_yaml_path = os.path.join(dataset.location, 'data.yaml') |
|
with open(data_yaml_path, 'r', encoding="utf-8") as f: |
|
data_yaml = yaml.safe_load(f) |
|
class_names = _extract_class_names(data_yaml) |
|
splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))] |
|
return dataset.location, class_names, splits, f"{project}-v{version}" |
|
except Exception as e: |
|
logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}") |
|
return None, [], [], None |
|
|
|
def label_path_for(img_path: str) -> str: |
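    """Map an image path like <split>/images/x.jpg to its YOLO label path <split>/labels/x.txt."""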
|
split_dir = os.path.dirname(os.path.dirname(img_path)) |
|
base = os.path.splitext(os.path.basename(img_path))[0] + '.txt' |
|
return os.path.join(split_dir, 'labels', base) |
|
|
|
|
|
def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json): |
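    """Convert one YOLO-format split (image folder + txt labels) into a COCO instances JSON at out_json.

    Category ids are kept 0-based so they match the YOLO class ids written in the label files.
    """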
|
images, annotations = [], [] |
|
categories = [{"id": i, "name": n} for i, n in enumerate(class_names)] |
|
ann_id = 1 |
|
img_id = 1 |
|
for fname in sorted(os.listdir(split_dir_images)): |
|
if not fname.lower().endswith((".jpg", ".jpeg", ".png")): |
|
continue |
|
img_path = os.path.join(split_dir_images, fname) |
|
try: |
|
with Image.open(img_path) as im: |
|
w, h = im.size |
|
except Exception: |
|
continue |
|
images.append({"id": img_id, "file_name": fname, "width": w, "height": h}) |
|
label_file = os.path.join(split_dir_labels, os.path.splitext(fname)[0] + ".txt") |
|
if os.path.exists(label_file): |
|
with open(label_file, "r", encoding="utf-8") as f: |
|
for line in f: |
|
parts = line.strip().split() |
|
if len(parts) < 5: |
|
continue |
|
try: |
|
cls = int(float(parts[0])) |
|
cx, cy, bw, bh = map(float, parts[1:5]) |
|
except Exception: |
|
continue |
|
x = max(0.0, (cx - bw / 2.0) * w) |
|
y = max(0.0, (cy - bh / 2.0) * h) |
|
ww = max(1.0, bw * w) |
|
hh = max(1.0, bh * h) |
|
if x + ww > w: |
|
ww = max(1.0, w - x) |
|
if y + hh > h: |
|
hh = max(1.0, h - y) |
|
annotations.append({ |
|
"id": ann_id, |
|
"image_id": img_id, |
|
"category_id": cls, |
|
"bbox": [x, y, ww, hh], |
|
"area": max(1.0, ww * hh), |
|
"iscrowd": 0, |
|
"segmentation": [] |
|
}) |
|
ann_id += 1 |
|
img_id += 1 |
|
coco = {"images": images, "annotations": annotations, "categories": categories} |
|
os.makedirs(os.path.dirname(out_json), exist_ok=True) |
|
with open(out_json, "w", encoding="utf-8") as f: |
|
json.dump(coco, f) |
|
|
|
def make_coco_annotations(merged_dir, class_names): |
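    """Write COCO annotation JSONs for each non-empty split of the merged dataset; return the annotations dir."""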
|
ann_dir = os.path.join(merged_dir, "annotations") |
|
os.makedirs(ann_dir, exist_ok=True) |
|
mapping = {"train": "instances_train.json", "valid": "instances_val.json", "test": "instances_test.json"} |
|
for split, outname in mapping.items(): |
|
img_dir = os.path.join(merged_dir, split, "images") |
|
lbl_dir = os.path.join(merged_dir, split, "labels") |
|
out_json = os.path.join(ann_dir, outname) |
|
if os.path.exists(img_dir) and os.listdir(img_dir): |
|
yolo_to_coco(img_dir, lbl_dir, class_names, out_json) |
|
return ann_dir |
|
|
|
|
|
def gather_class_counts(dataset_info, class_mapping): |
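    """Count, per final class name, how many images contain at least one instance of that class."""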
|
if not dataset_info: |
|
return {} |
|
final_names = set(v for v in class_mapping.values() if v is not None) |
|
counts = {name: 0 for name in final_names} |
|
for loc, names, splits, _ in dataset_info: |
|
id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)} |
|
for split in splits: |
|
labels_dir = os.path.join(loc, split, 'labels') |
|
if not os.path.exists(labels_dir): |
|
continue |
|
for label_file in os.listdir(labels_dir): |
|
if not label_file.endswith('.txt'): |
|
continue |
|
found = set() |
|
with open(os.path.join(labels_dir, label_file), 'r', encoding="utf-8") as f: |
|
for line in f: |
|
parts = line.strip().split() |
|
if not parts: |
|
continue |
|
try: |
|
cls_id = int(parts[0]) |
|
mapped = id_to_name.get(cls_id, None) |
|
if mapped: |
|
found.add(mapped) |
|
except Exception: |
|
continue |
|
for m in found: |
|
counts[m] += 1 |
|
return counts |
|
|
|
def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()): |
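    """Build 'rolo_merged_dataset' from the loaded datasets, applying renames, removals and per-class image caps."""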
|
merged_dir = 'rolo_merged_dataset' |
|
if os.path.exists(merged_dir): |
|
shutil.rmtree(merged_dir, onerror=handle_remove_readonly) |
|
|
|
progress(0, desc="Creating directories...") |
|
for split in ['train', 'valid', 'test']: |
|
os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True) |
|
os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True) |
|
|
|
active_classes = sorted({cls for cls, limit in class_limits.items() if limit > 0}) |
|
final_class_map = {name: i for i, name in enumerate(active_classes)} |
|
|
|
all_images = [] |
|
for loc, _, splits, _ in dataset_info: |
|
for split in splits: |
|
img_dir = os.path.join(loc, split, 'images') |
|
if not os.path.exists(img_dir): |
|
continue |
|
for img_file in os.listdir(img_dir): |
|
if img_file.lower().endswith(('.jpg', '.jpeg', '.png')): |
|
all_images.append((os.path.join(img_dir, img_file), split, loc)) |
|
random.shuffle(all_images) |
|
|
|
progress(0.2, desc="Selecting images based on limits...") |
|
selected_images, current_counts = [], {cls: 0 for cls in active_classes} |
|
loc_to_names = {info[0]: info[1] for info in dataset_info} |
|
|
|
for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"): |
|
lbl_path = label_path_for(img_path) |
|
if not os.path.exists(lbl_path): |
|
continue |
|
source_names = loc_to_names.get(source_loc, []) |
|
image_classes = set() |
|
with open(lbl_path, 'r', encoding="utf-8") as f: |
|
for line in f: |
|
parts = line.strip().split() |
|
if not parts: |
|
continue |
|
try: |
|
cls_id = int(parts[0]) |
|
orig = source_names[cls_id] |
|
mapped = class_mapping.get(orig, orig) |
|
if mapped in active_classes: |
|
image_classes.add(mapped) |
|
except Exception: |
|
continue |
|
if not image_classes: |
|
continue |
|
if any(current_counts[c] >= class_limits[c] for c in image_classes): |
|
continue |
|
selected_images.append((img_path, split)) |
|
for c in image_classes: |
|
current_counts[c] += 1 |
|
|
|
progress(0.6, desc=f"Copying {len(selected_images)} files...") |
|
for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"): |
|
lbl_path = label_path_for(img_path) |
|
out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path)) |
|
out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path)) |
|
shutil.copy(img_path, out_img) |
|
|
|
source_loc = None |
|
for info in dataset_info: |
|
if img_path.startswith(info[0]): |
|
source_loc = info[0] |
|
break |
|
source_names = loc_to_names.get(source_loc, []) |
|
|
|
with open(lbl_path, 'r', encoding="utf-8") as f_in, open(out_lbl, 'w', encoding="utf-8") as f_out: |
|
for line in f_in: |
|
parts = line.strip().split() |
|
if not parts: |
|
continue |
|
try: |
|
old_id = int(parts[0]) |
|
original_name = source_names[old_id] |
|
mapped_name = class_mapping.get(original_name, original_name) |
|
if mapped_name in final_class_map: |
|
new_id = final_class_map[mapped_name] |
|
f_out.write(f"{new_id} {' '.join(parts[1:])}\n") |
|
except Exception: |
|
continue |
|
|
|
progress(0.9, desc="Writing data.yaml + COCO annotations...") |
|
with open(os.path.join(merged_dir, 'data.yaml'), 'w', encoding="utf-8") as f: |
|
yaml.dump({ |
|
'path': os.path.abspath(merged_dir), |
|
'train': 'train/images', |
|
'val': 'valid/images', |
|
'test': 'test/images', |
|
'nc': len(active_classes), |
|
'names': active_classes |
|
}, f) |
|
|
|
ann_dir = make_coco_annotations(merged_dir, active_classes) |
|
progress(0.98, desc="Finalizing...") |
|
return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir) |
|
|
|
|
|
def find_training_script(repo_root): |
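    """Locate the RT-DETRv2 training entry point, preferring rtdetrv2_pytorch/tools/train.py."""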
|
canonical = os.path.join(repo_root, "rtdetrv2_pytorch", "tools", "train.py") |
|
if os.path.exists(canonical): |
|
return canonical |
|
candidates = [] |
|
for pat in ["**/tools/train.py", "**/train.py", "**/tools/train_net.py"]: |
|
candidates.extend(glob(os.path.join(repo_root, pat), recursive=True)) |
|
def _score(p): |
|
pl = p.replace("\\", "/").lower() |
|
return (0 if "rtdetrv2_pytorch" in pl else 1, len(p)) |
|
candidates.sort(key=_score) |
|
return candidates[0] if candidates else None |
|
|
|
def find_model_config_template(model_key): |
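    """Return the absolute path of the repo config template for model_key, or None if it is missing."""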
|
rel = CONFIG_PATHS.get(model_key) |
|
if not rel: |
|
return None |
|
path = os.path.join(REPO_DIR, rel) |
|
return path if os.path.exists(path) else None |
|
|
|
def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None): |
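    """Set the first key from `keys` that already exists in d; otherwise use fallback_key, if given."""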
|
for k in keys: |
|
if k in d: |
|
d[k] = value |
|
return k |
|
if fallback_key: |
|
d[fallback_key] = value |
|
return fallback_key |
|
return None |
|
|
|
def _set_first_existing_key_deep(cfg: dict, keys: list, value): |
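    """Set the first matching key found in cfg, cfg['model'] or cfg['solver']; otherwise create it under cfg['model']."""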
|
for scope in [cfg, cfg.get("model", {}), cfg.get("solver", {})]: |
|
if isinstance(scope, dict): |
|
for k in keys: |
|
if k in scope: |
|
scope[k] = value |
|
return True |
|
if "model" not in cfg or not isinstance(cfg["model"], dict): |
|
cfg["model"] = {} |
|
cfg["model"][keys[0]] = value |
|
return True |
|
|
|
def _install_supervisely_logger_shim(): |
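    """Write a no-op supervisely.nn.training.train_logger shim under a temp dir and return its root path.

    The returned root is added to PYTHONPATH of the training subprocess.
    """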
|
root = pathlib.Path(tempfile.gettempdir()) / "sly_shim_pkg" |
|
pkg_training = root / "supervisely" / "nn" / "training" |
|
pkg_training.mkdir(parents=True, exist_ok=True) |
|
|
|
for p in [root / "supervisely", root / "supervisely" / "nn", pkg_training]: |
|
init_file = p / "__init__.py" |
|
if not init_file.exists(): |
|
init_file.write_text("") |
|
|
|
(pkg_training / "__init__.py").write_text(textwrap.dedent(""" |
|
class _TrainLogger: |
|
def __init__(self): pass |
|
def reset(self): pass |
|
def log_metrics(self, metrics: dict, step: int | None = None): pass |
|
def log_artifacts(self, *a, **k): pass |
|
def log_image(self, *a, **k): pass |
|
train_logger = _TrainLogger() |
|
""")) |
|
return str(root) |
|
|
|
|
|
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"): |
|
""" |
|
    Write a sitecustomize.py shim that:
|
- patches workspace.create to handle dict-based component definitions, |
|
- ensures cfg is a dict, |
|
- injects cfg['_pymodule'] as a *module object*, |
|
even if the target module is imported after sitecustomize runs. |
|
""" |
|
os.makedirs(dest_dir, exist_ok=True) |
|
sc_path = os.path.join(dest_dir, "sitecustomize.py") |
|
|
|
tmpl = Template(r""" |
|
import os, sys, importlib, importlib.abc, importlib.util, importlib.machinery, types |
|
MOD_DEFAULT = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default" |
|
TARGET = "rtdetrv2_pytorch.src.core.workspace" |
|
|
|
def _ensure_pymodule_object(cfg: dict): |
|
pm = cfg.get("_pymodule", None) |
|
if isinstance(pm, types.ModuleType): |
|
return pm |
|
name = (pm or "").strip() if isinstance(pm, str) else MOD_DEFAULT |
|
if not name: |
|
name = MOD_DEFAULT |
|
try: |
|
mod = importlib.import_module(name) |
|
except Exception: |
|
mod = importlib.import_module(MOD_DEFAULT) |
|
cfg["_pymodule"] = mod |
|
return mod |
|
|
|
def _patch_ws(ws_mod): |
|
if getattr(ws_mod, "__rolo_patched__", False): |
|
return |
|
_orig_create = ws_mod.create |
|
|
|
    # Replacement create() that tolerates dict-based component definitions and always passes a merged cfg dict.
|
def create(name, *args, **kwargs): |
|
# Unify all config sources into one dictionary. The main config is often the second arg. |
|
cfg = {} |
|
if args and isinstance(args[0], dict): |
|
cfg.update(args[0]) |
|
if 'cfg' in kwargs and isinstance(kwargs['cfg'], dict): |
|
cfg.update(kwargs['cfg']) |
|
|
|
_ensure_pymodule_object(cfg) |
|
|
|
# The core of the fix: handle when the component itself is passed as a dict. |
|
# This is what happens when the library tries to create the model. |
|
if isinstance(name, dict): |
|
component_params = name.copy() |
|
type_name = component_params.pop('type', None) |
|
if type_name is None: |
|
# If no 'type' key, we can't proceed. Fall back to original to get the original error. |
|
return _orig_create(name, *args, **kwargs) |
|
|
|
# Merge the component's own parameters (like num_classes) into the main config. |
|
cfg.update(component_params) |
|
|
|
# Now, call the original `create` function the way it expects: |
|
# with the component name as a string, and the full config. |
|
return _orig_create(type_name, cfg=cfg) |
|
|
|
# If 'name' was already a string (the normal case for solvers, etc.), proceed as expected. |
|
return _orig_create(name, cfg=cfg) |
|
|
|
ws_mod.create = create |
|
ws_mod.__rolo_patched__ = True |
|
|
|
def _try_patch_now(): |
|
try: |
|
ws_mod = importlib.import_module(TARGET) |
|
_patch_ws(ws_mod) |
|
return True |
|
except Exception: |
|
return False |
|
|
|
if not _try_patch_now(): |
|
class _RoloFinder(importlib.abc.MetaPathFinder): |
|
        _in_find = False  # re-entrancy guard for the importlib.util.find_spec call below

        def find_spec(self, fullname, path, target=None):
            if fullname != TARGET or _RoloFinder._in_find:
                return None
            # importlib.util.find_spec walks sys.meta_path again (including this finder),
            # so set the guard to avoid infinite recursion while resolving the real spec.
            _RoloFinder._in_find = True
            try:
                origin_spec = importlib.util.find_spec(fullname)
            finally:
                _RoloFinder._in_find = False
            if origin_spec is None or origin_spec.loader is None:
                return None
|
loader = origin_spec.loader |
|
class _RoloLoader(importlib.abc.Loader): |
|
def create_module(self, spec): |
|
if hasattr(loader, "create_module"): |
|
return loader.create_module(spec) |
|
return None |
|
def exec_module(self, module): |
|
loader.exec_module(module) |
|
try: |
|
_patch_ws(module) |
|
except Exception: |
|
pass |
|
spec = importlib.machinery.ModuleSpec(fullname, _RoloLoader(), origin=origin_spec.origin) |
|
spec.submodule_search_locations = origin_spec.submodule_search_locations |
|
return spec |
|
sys.meta_path.insert(0, _RoloFinder()) |
|
""") |
|
code = tmpl.substitute(module_default=module_default) |
|
with open(sc_path, "w", encoding="utf-8") as f: |
|
f.write(code) |
|
return sc_path |
|
|
|
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None: |
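    """Download (and cache) the pretrained COCO checkpoint for model_key into out_dir; return its path or None."""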
|
url = CKPT_URLS.get(model_key) |
|
if not url: |
|
return None |
|
os.makedirs(out_dir, exist_ok=True) |
|
fname = os.path.join(out_dir, os.path.basename(url)) |
|
if os.path.exists(fname) and os.path.getsize(fname) > 0: |
|
return fname |
|
logging.info(f"Downloading pretrained checkpoint for {model_key} from {url}") |
|
try: |
|
with requests.get(url, stream=True, timeout=60) as r: |
|
r.raise_for_status() |
|
with open(fname, "wb") as f: |
|
for chunk in r.iter_content(chunk_size=1024 * 1024): |
|
if chunk: |
|
f.write(chunk) |
|
return fname |
|
except Exception as e: |
|
logging.warning(f"Could not fetch checkpoint: {e}") |
|
try: |
|
if os.path.exists(fname): |
|
os.remove(fname) |
|
except Exception: |
|
pass |
|
return None |
|
|
|
|
|
def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE", "BASE_YAML", |
|
"includes", "include", "BASES", "__include__")): |
|
def _absify(s: str) -> str: |
|
if os.path.isabs(s): |
|
return s |
|
if s.startswith("../") or s.endswith((".yml", ".yaml")): |
|
return os.path.abspath(os.path.join(base_dir, s)) |
|
return s |
|
|
|
if isinstance(node, dict): |
|
for k in list(node.keys()): |
|
v = node[k] |
|
if k in include_keys: |
|
if isinstance(v, str): |
|
node[k] = _absify(v) |
|
elif isinstance(v, list): |
|
node[k] = [_absify(x) if isinstance(x, str) else x for x in v] |
|
for k, v in list(node.items()): |
|
if isinstance(v, (dict, list)): |
|
_absify_any_paths_deep(v, base_dir, include_keys) |
|
elif isinstance(v, str): |
|
node[k] = _absify(v) |
|
elif isinstance(node, list): |
|
for i, v in enumerate(list(node)): |
|
if isinstance(v, (dict, list)): |
|
_absify_any_paths_deep(v, base_dir, include_keys) |
|
elif isinstance(v, str): |
|
node[i] = _absify(v) |
|
|
|
|
|
def _set_num_classes_safely(cfg: dict, n: int): |
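    """Set num_classes on the model block (following a string reference if needed); fall back to the top level."""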
|
def set_num_classes(node): |
|
if not isinstance(node, dict): |
|
return False |
|
if "num_classes" in node: |
|
node["num_classes"] = int(n) |
|
return True |
|
for k, v in node.items(): |
|
if isinstance(v, dict) and set_num_classes(v): |
|
return True |
|
return False |
|
|
|
m = cfg.get("model", None) |
|
if isinstance(m, dict): |
|
if not set_num_classes(m): |
|
m["num_classes"] = int(n) |
|
return |
|
|
|
if isinstance(m, str): |
|
block = cfg.get(m, None) |
|
if isinstance(block, dict): |
|
if not set_num_classes(block): |
|
block["num_classes"] = int(n) |
|
return |
|
|
|
cfg["num_classes"] = int(n) |
|
|
|
def _maybe_set_model_field(cfg: dict, key: str, value): |
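    """Set key=value on the model block (or the block a string 'model' entry points to); else on the top-level cfg."""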
|
m = cfg.get("model", None) |
|
if isinstance(m, dict): |
|
m[key] = value |
|
return |
|
if isinstance(m, str) and isinstance(cfg.get(m), dict): |
|
cfg[m][key] = value |
|
return |
|
cfg[key] = value |
|
|
|
|
|
def patch_base_config(base_cfg_path, merged_dir, class_count, run_name, |
|
epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None): |
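    """Load the selected config template and rewrite it for the merged dataset and the chosen hyper-parameters.

    Dataloaders, class count, epochs, image size, learning rate, optimizer, output dir and the
    pretrained checkpoint are patched in; the result is written to generated/<run_name>.yaml
    next to the template and its path is returned.
    """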
|
if not base_cfg_path or not os.path.exists(base_cfg_path): |
|
raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.") |
|
|
|
template_dir = os.path.dirname(base_cfg_path) |
|
|
|
|
|
with open(base_cfg_path, "r", encoding="utf-8") as f: |
|
cfg = yaml.safe_load(f) |
|
_absify_any_paths_deep(cfg, template_dir) |
|
|
|
|
|
cfg["task"] = cfg.get("task", "detection") |
|
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") |
|
|
|
|
|
cfg["sync_bn"] = False |
|
cfg.setdefault("device", "") |
|
cfg["find_unused_parameters"] = False |
|
|
|
ann_dir = os.path.join(merged_dir, "annotations") |
|
paths = { |
|
"train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")), |
|
"val_json": os.path.abspath(os.path.join(ann_dir, "instances_val.json")), |
|
"test_json": os.path.abspath(os.path.join(ann_dir, "instances_test.json")), |
|
"train_img": os.path.abspath(os.path.join(merged_dir, "train", "images")), |
|
"val_img": os.path.abspath(os.path.join(merged_dir, "valid", "images")), |
|
"test_img": os.path.abspath(os.path.join(merged_dir, "test", "images")), |
|
"out_dir": os.path.abspath(os.path.join("runs", "train", run_name)), |
|
} |
|
|
|
def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle): |
|
block = cfg.get(dl_key) |
|
if not isinstance(block, dict): |
|
block = { |
|
"type": "DataLoader", |
|
"dataset": { |
|
"type": "CocoDetection", |
|
"img_folder": paths[img_key], |
|
"ann_file": paths[json_key], |
|
"return_masks": False, |
|
"transforms": { |
|
"type": "Compose", |
|
"ops": [ |
|
{"type": "Resize", "size": [int(imgsz), int(imgsz)]}, |
|
{"type": "ConvertPILImage", "dtype": "float32", "scale": True}, |
|
], |
|
}, |
|
}, |
|
"shuffle": bool(default_shuffle), |
|
"num_workers": 2, |
|
"drop_last": bool(dl_key == "train_dataloader"), |
|
"collate_fn": {"type": "BatchImageCollateFunction"}, |
|
"total_batch_size": int(batch), |
|
} |
|
cfg[dl_key] = block |
|
|
|
ds = block.get("dataset", {}) |
|
if isinstance(ds, dict): |
|
ds["img_folder"] = paths[img_key] |
|
ds["ann_file"] = paths[json_key] |
|
for k in ("img_dir", "image_root", "data_root"): |
|
if k in ds: ds[k] = paths[img_key] |
|
for k in ("ann_path", "annotation", "annotations"): |
|
if k in ds: ds[k] = paths[json_key] |
|
block["dataset"] = ds |
|
|
|
block["total_batch_size"] = int(batch) |
|
block.setdefault("num_workers", 2) |
|
block.setdefault("shuffle", bool(default_shuffle)) |
|
block.setdefault("drop_last", bool(dl_key == "train_dataloader")) |
|
|
|
|
|
cf = block.get("collate_fn", {}) |
|
if isinstance(cf, dict): |
|
t = str(cf.get("type", "")) |
|
if t.lower() == "batchimagecollatefuncion" or "Funcion" in t: |
|
cf["type"] = "BatchImageCollateFunction" |
|
block["collate_fn"] = cf |
|
else: |
|
block["collate_fn"] = {"type": "BatchImageCollateFunction"} |
|
|
|
ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True) |
|
ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False) |
|
|
|
_set_num_classes_safely(cfg, int(class_count)) |
|
|
|
applied_epoch = False |
|
for key in ("epoches", "max_epoch", "epochs", "num_epochs"): |
|
if key in cfg: |
|
cfg[key] = int(epochs) |
|
applied_epoch = True |
|
break |
|
if "solver" in cfg and isinstance(cfg["solver"], dict): |
|
for key in ("epoches", "max_epoch", "epochs", "num_epochs"): |
|
if key in cfg["solver"]: |
|
cfg["solver"][key] = int(epochs) |
|
applied_epoch = True |
|
break |
|
if not applied_epoch: |
|
cfg["epoches"] = int(epochs) |
|
cfg["input_size"] = int(imgsz) |
|
|
|
if "solver" not in cfg or not isinstance(cfg["solver"], dict): |
|
cfg["solver"] = {} |
|
sol = cfg["solver"] |
|
for key in ("base_lr", "lr", "learning_rate"): |
|
if key in sol: |
|
sol[key] = float(lr) |
|
break |
|
else: |
|
sol["base_lr"] = float(lr) |
|
sol["optimizer"] = str(optimizer).lower() |
|
if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict): |
|
sol["batch_size"] = int(batch) |
|
|
|
if "output_dir" in cfg: |
|
cfg["output_dir"] = paths["out_dir"] |
|
else: |
|
sol["output_dir"] = paths["out_dir"] |
|
|
|
if pretrained_path: |
|
p = os.path.abspath(pretrained_path) |
|
_maybe_set_model_field(cfg, "pretrain", p) |
|
_maybe_set_model_field(cfg, "pretrained", p) |
|
|
|
if not cfg.get("model"): |
|
cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)} |
|
|
|
cfg_out_dir = os.path.join(template_dir, "generated") |
|
os.makedirs(cfg_out_dir, exist_ok=True) |
|
out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml") |
|
|
|
class _NoFlowDumper(yaml.SafeDumper): ... |
|
def _repr_list_block(dumper, data): |
|
return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False) |
|
_NoFlowDumper.add_representer(list, _repr_list_block) |
|
|
|
with open(out_path, "w", encoding="utf-8") as f: |
|
yaml.dump(cfg, f, Dumper=_NoFlowDumper, sort_keys=False, allow_unicode=True) |
|
return out_path |
|
|
|
def find_best_checkpoint(out_dir): |
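    """Return a best*/model_best* checkpoint under out_dir if present, else the last *.pt/*.pth in sorted order, else None."""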
|
pats = [ |
|
os.path.join(out_dir, "**", "best*.pt"), |
|
os.path.join(out_dir, "**", "best*.pth"), |
|
os.path.join(out_dir, "**", "model_best*.pt"), |
|
os.path.join(out_dir, "**", "model_best*.pth"), |
|
] |
|
for p in pats: |
|
f = sorted(glob(p, recursive=True)) |
|
if f: |
|
return f[0] |
|
any_ckpt = sorted( |
|
glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) |
|
+ glob(os.path.join(out_dir, "**", "*.pth"), recursive=True) |
|
) |
|
return any_ckpt[-1] if any_ckpt else None |
|
|
|
|
|
def load_datasets_handler(api_key, url_file, progress=gr.Progress()): |
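    """Gradio handler: download every dataset referenced in the uploaded .txt and build the class-config table."""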
|
api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "") |
|
if not api_key: |
|
raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).") |
|
if not url_file: |
|
raise gr.Error("Upload a .txt with Roboflow URLs or 'workspace/project[/vN]' lines.") |
|
|
|
    # gr.File may pass a plain file path (Gradio 4 default) or an object exposing .name.
    url_path = url_file if isinstance(url_file, str) else url_file.name
    with open(url_path, 'r', encoding='utf-8', errors='ignore') as f:
|
urls = [line.strip() for line in f if line.strip()] |
|
|
|
dataset_info, failures = [], [] |
|
for i, raw in enumerate(urls): |
|
progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}") |
|
ws, proj, ver = parse_roboflow_url(raw) |
|
if not (ws and proj): |
|
failures.append((raw, "ParseError: could not resolve workspace/project")) |
|
continue |
|
if ver is None: |
|
ver = get_latest_version(api_key, ws, proj) |
|
if ver is None: |
|
failures.append((raw, f"No latest version for {ws}/{proj}")) |
|
continue |
|
loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver)) |
|
if loc: |
|
dataset_info.append((loc, names, splits, name_str)) |
|
else: |
|
failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}")) |
|
|
|
if not dataset_info: |
|
msg = "No datasets loaded.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]]) |
|
raise gr.Error(msg) |
|
|
|
all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names}) |
|
class_map = {name: name for name in all_names} |
|
counts = gather_class_counts(dataset_info, class_map) |
|
df = pd.DataFrame([[n, n, counts.get(n, 0), False] for n in all_names], |
|
columns=["Original Name", "Rename To", "Max Images", "Remove"]) |
|
status = "Datasets loaded successfully." |
|
if failures: |
|
status += f" ({len(dataset_info)} OK, {len(failures)} failed; see logs)." |
|
return status, dataset_info, df |
|
|
|
def update_class_counts_handler(class_df, dataset_info): |
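    """Gradio handler: recompute estimated per-class image counts from the current rename/remove table."""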
|
if class_df is None or not dataset_info: |
|
return None |
|
class_df = pd.DataFrame(class_df) |
|
mapping = {row["Original Name"]: (None if bool(row["Remove"]) else row["Rename To"]) |
|
for _, row in class_df.iterrows()} |
|
final_names = sorted(set(v for v in mapping.values() if v)) |
|
counts = {k: 0 for k in final_names} |
|
for loc, names, splits, _ in dataset_info: |
|
id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)} |
|
for split in splits: |
|
labels_dir = os.path.join(loc, split, 'labels') |
|
if not os.path.exists(labels_dir): |
|
continue |
|
for label_file in os.listdir(labels_dir): |
|
if not label_file.endswith('.txt'): |
|
continue |
|
found = set() |
|
with open(os.path.join(labels_dir, label_file), 'r', encoding="utf-8") as f: |
|
for line in f: |
|
parts = line.strip().split() |
|
if not parts: |
|
continue |
|
try: |
|
cls_id = int(parts[0]) |
|
mapped = id_to_final.get(cls_id, None) |
|
if mapped: |
|
found.add(mapped) |
|
except Exception: |
|
continue |
|
for m in found: |
|
counts[m] += 1 |
|
return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"]) |
|
|
|
def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()): |
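    """Gradio generator handler: build a patched config, run training in a subprocess and stream its log tail."""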
|
if not dataset_path: |
|
raise gr.Error("Finalize a dataset in Tab 2 before training.") |
|
|
|
train_script = find_training_script(REPO_DIR) |
|
logging.info(f"Resolved training script: {train_script}") |
|
if not train_script: |
|
raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).") |
|
|
|
base_cfg = find_model_config_template(model_key) |
|
if not base_cfg: |
|
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).") |
|
|
|
data_yaml = os.path.join(dataset_path, "data.yaml") |
|
with open(data_yaml, "r", encoding="utf-8") as f: |
|
dy = yaml.safe_load(f) |
|
class_names = [str(x) for x in dy.get("names", [])] |
|
make_coco_annotations(dataset_path, class_names) |
|
|
|
out_dir = os.path.abspath(os.path.join("runs", "train", run_name)) |
|
os.makedirs(out_dir, exist_ok=True) |
|
|
|
pretrained_path = _ensure_checkpoint(model_key, out_dir) |
|
|
|
cfg_path = patch_base_config( |
|
base_cfg_path=base_cfg, |
|
merged_dir=dataset_path, |
|
class_count=len(class_names), |
|
run_name=run_name, |
|
epochs=epochs, |
|
batch=batch, |
|
imgsz=imgsz, |
|
lr=lr, |
|
optimizer=opt, |
|
pretrained_path=pretrained_path, |
|
) |
|
|
|
cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)] |
|
logging.info(f"Training command: {' '.join(cmd)}") |
|
|
|
q = Queue() |
|
def run_train(): |
|
try: |
|
train_cwd = os.path.dirname(train_script) |
|
|
|
shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_") |
|
_install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src") |
|
|
|
env = os.environ.copy() |
|
|
|
sly_shim_root = _install_supervisely_logger_shim() |
|
|
|
env["PYTHONPATH"] = os.pathsep.join(filter(None, [ |
|
shim_dir, |
|
train_cwd, |
|
PY_IMPL_DIR, |
|
REPO_DIR, |
|
sly_shim_root, |
|
env.get("PYTHONPATH", "") |
|
])) |
|
|
|
env.setdefault("WANDB_DISABLED", "true") |
|
env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src") |
|
env.setdefault("PYTHONUNBUFFERED", "1") |
|
if torch.cuda.is_available(): |
|
env.setdefault("CUDA_VISIBLE_DEVICES", "0") |
|
|
|
proc = subprocess.Popen(cmd, cwd=train_cwd, |
|
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, |
|
bufsize=1, text=True, env=env) |
|
for line in proc.stdout: |
|
q.put(line.rstrip()) |
|
proc.wait() |
|
q.put(f"__EXITCODE__:{proc.returncode}") |
|
except Exception as e: |
|
q.put(f"__ERROR__:{e}") |
|
|
|
Thread(target=run_train, daemon=True).start() |
|
|
|
log_tail, last_epoch, total_epochs = [], 0, int(epochs) |
|
    recent_lines = []  # rolling buffer of recent subprocess output, used for error reporting
|
line_no = 0 |
|
while True: |
|
line = q.get() |
|
if line.startswith("__EXITCODE__"): |
|
code = int(line.split(":", 1)[1]) |
|
if code != 0: |
|
head = "\n".join(first_lines[-200:]) |
|
raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}") |
|
break |
|
if line.startswith("__ERROR__"): |
|
raise gr.Error(f"Training failed: {line.split(':', 1)[1]}") |
|
|
|
        recent_lines.append(line)
        if len(recent_lines) > 400:
            recent_lines.pop(0)
|
log_tail.append(line) |
|
log_tail = log_tail[-40:] |
|
|
|
m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line) |
|
if m: |
|
try: |
|
last_epoch = int(m.group(1)) |
|
total_epochs = max(total_epochs, int(m.group(2))) |
|
except Exception: |
|
pass |
|
progress(min(max(last_epoch / max(1, total_epochs), 0.0), 1.0), desc=f"Epoch {last_epoch}/{total_epochs}") |
|
|
|
line_no += 1 |
|
fig1 = fig2 = None |
|
if line_no % 80 == 0: |
|
fig1 = plt.figure() |
|
plt.title("Loss (see logs)") |
|
plt.plot([0, last_epoch], [0, 0]) |
|
plt.tight_layout() |
|
|
|
fig2 = plt.figure() |
|
plt.title("mAP (see logs)") |
|
plt.plot([0, last_epoch], [0, 0]) |
|
plt.tight_layout() |
|
|
|
yield "\n".join(log_tail), fig1, fig2, None |
|
|
|
if fig1 is not None: |
|
plt.close(fig1) |
|
if fig2 is not None: |
|
plt.close(fig2) |
|
|
|
ckpt = find_best_checkpoint(out_dir) or find_best_checkpoint("runs") |
|
if not ckpt or not os.path.exists(ckpt): |
|
raise gr.Error("Training finished, but checkpoint file not found. Check logs/output directory.") |
|
yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True) |
|
|
|
def finalize_handler(dataset_info, class_df, progress=gr.Progress()): |
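    """Gradio handler: turn the class table into a mapping plus per-class limits and build the merged dataset."""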
|
if not dataset_info: |
|
raise gr.Error("Load datasets first in Tab 1.") |
|
if class_df is None: |
|
raise gr.Error("Class data is missing.") |
|
class_df = pd.DataFrame(class_df) |
|
class_mapping, class_limits = {}, {} |
|
for _, row in class_df.iterrows(): |
|
orig = row["Original Name"] |
|
if bool(row["Remove"]): |
|
continue |
|
final_name = row["Rename To"] |
|
class_mapping[orig] = final_name |
|
class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"]) |
|
status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress) |
|
return status, path |
|
|
|
def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()): |
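    """Gradio handler: optionally push the trained checkpoint to Hugging Face Hub and/or GitHub."""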
|
    if not model_file:
        raise gr.Error("No trained model file to upload.")
    # gr.File may pass a plain file path (Gradio 4 default) or an object exposing .name.
    model_path = model_file if isinstance(model_file, str) else model_file.name
    from huggingface_hub import HfApi, HfFolder
|
hf_status = "Skipped Hugging Face." |
|
if hf_token and hf_repo: |
|
progress(0, desc="Uploading to Hugging Face...") |
|
try: |
|
            api = HfApi()
            HfFolder.save_token(hf_token)
            repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
            # huggingface_hub's upload_file takes keyword-only arguments.
            api.upload_file(path_or_fileobj=model_path, path_in_repo=os.path.basename(model_path),
                            repo_id=hf_repo, token=hf_token)
|
hf_status = f"Success! {repo_url}" |
|
except Exception as e: |
|
hf_status = f"Hugging Face Error: {e}" |
|
|
|
gh_status = "Skipped GitHub." |
|
if gh_token and gh_repo: |
|
progress(0.5, desc="Uploading to GitHub...") |
|
try: |
|
if '/' not in gh_repo: |
|
raise ValueError("GitHub repo must be 'username/repo'.") |
|
            username, repo_name = gh_repo.split('/', 1)
            api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_path)}"
|
headers = {"Authorization": f"token {gh_token}"} |
|
with open(model_file.name, "rb") as f: |
|
content = base64.b64encode(f.read()).decode() |
|
get_resp = requests.get(api_url, headers=headers, timeout=30) |
|
sha = get_resp.json().get('sha') if get_resp.ok else None |
|
data = {"message": "Upload trained model from Rolo app", "content": content} |
|
if sha: |
|
data["sha"] = sha |
|
put_resp = requests.put(api_url, headers=headers, json=data, timeout=60) |
|
if put_resp.ok: |
|
gh_status = f"Success! {put_resp.json()['content']['html_url']}" |
|
else: |
|
gh_status = f"GitHub Error: {put_resp.json().get('message','Unknown')}" |
|
except Exception as e: |
|
gh_status = f"GitHub Error: {e}" |
|
progress(1) |
|
return hf_status, gh_status |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app: |
|
gr.Markdown("# Rolo — RT-DETRv2 Trainer (Supervisely repo only)") |
|
|
|
dataset_info_state = gr.State([]) |
|
final_dataset_path_state = gr.State(None) |
|
|
|
with gr.Tabs(): |
|
with gr.TabItem("1. Prepare Datasets"): |
|
gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` per line. We’ll pull and merge them.") |
|
with gr.Row(): |
|
rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2) |
|
rf_url_file = gr.File(label="Roboflow URLs (.txt)", file_types=[".txt"], scale=1) |
|
load_btn = gr.Button("Load Datasets", variant="primary") |
|
dataset_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
with gr.TabItem("2. Manage & Merge"): |
|
gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.") |
|
with gr.Row(): |
|
class_df = gr.DataFrame(headers=["Original Name","Rename To","Max Images","Remove"], |
|
datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3) |
|
with gr.Column(scale=1): |
|
class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview", |
|
headers=["Final Class Name","Est. Total Images"], interactive=False) |
|
update_counts_btn = gr.Button("Update Counts") |
|
finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary") |
|
finalize_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
with gr.TabItem("3. Configure & Train"): |
|
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.") |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
|
|
model_dd = gr.Dropdown(choices=[(label, value) for value, label in MODEL_CHOICES], |
|
value=DEFAULT_MODEL_KEY, |
|
label="Model (RT-DETRv2)") |
|
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1") |
|
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs") |
|
batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size") |
|
imgsz_num = gr.Number(label="Image Size", value=640) |
|
lr_num = gr.Number(label="Learning Rate", value=0.001) |
|
opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer") |
|
train_btn = gr.Button("Start Training", variant="primary") |
|
with gr.Column(scale=2): |
|
train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12) |
|
loss_plot = gr.Plot(label="Loss") |
|
map_plot = gr.Plot(label="mAP") |
|
final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False) |
|
|
|
with gr.TabItem("4. Upload Model"): |
|
gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.") |
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("**Hugging Face**") |
|
hf_token = gr.Textbox(label="HF Token", type="password") |
|
hf_repo = gr.Textbox(label="HF Repo (user/repo)") |
|
with gr.Column(): |
|
gr.Markdown("**GitHub**") |
|
gh_token = gr.Textbox(label="GitHub PAT", type="password") |
|
gh_repo = gr.Textbox(label="GitHub Repo (user/repo)") |
|
upload_btn = gr.Button("Upload", variant="primary") |
|
with gr.Row(): |
|
hf_status = gr.Textbox(label="Hugging Face Status", interactive=False) |
|
gh_status = gr.Textbox(label="GitHub Status", interactive=False) |
|
|
|
load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file], |
|
[dataset_status, dataset_info_state, class_df]) |
|
update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state], |
|
[class_count_summary_df]) |
|
finalize_btn.click(finalize_handler, [dataset_info_state, class_df], |
|
[finalize_status, final_dataset_path_state]) |
|
train_btn.click(training_handler, |
|
[final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd], |
|
[train_status, loss_plot, map_plot, final_model_file]) |
|
upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo], |
|
[hf_status, gh_status]) |
|
|
|
if __name__ == "__main__": |
|
try: |
|
ts = find_training_script(REPO_DIR) |
|
logging.info(f"Startup check — training script at: {ts}") |
|
except Exception as e: |
|
logging.warning(f"Startup training-script check failed: {e}") |
|
app.launch(debug=True) |