import os
import json
import random
import shutil
import tempfile
import time
from urllib.parse import urlparse

import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow

def parse_roboflow_url(url: str):
    """Extract (workspace, project, version) from a Roboflow dataset URL."""
    parsed = urlparse(url)
    parts = parsed.path.strip('/').split('/')
    workspace = parts[0]
    project = parts[1]
    try:
        version = int(parts[-1])
    except ValueError:
        # some URLs end with a non-numeric segment; fall back one step
        version = int(parts[-2])
    return workspace, project, version
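# Example (hypothetical URL):
#   parse_roboflow_url("https://universe.roboflow.com/acme/roads-seg/2")
#   -> ("acme", "roads-seg", 2)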

def convert_seg_to_bbox(api_key: str, dataset_url: str):
    """
    1) Download the segmentation dataset from Roboflow
    2) Detect whether the export is COCO-JSON or PNG-mask based
    3) Convert each polygon/mask to its bounding box (YOLO format)
    4) Preserve the original train/valid/test splits
    5) Return before/after visuals + (dataset_path, detection_slug)
    """
    rf = Roboflow(api_key=api_key)
    ws, proj_name, ver = parse_roboflow_url(dataset_url)
    version_obj = rf.workspace(ws).project(proj_name).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location
    # scan the download for any .json annotation files
    all_json = []
    for dp, _, files in os.walk(root):
        for f in files:
            if f.lower().endswith(".json"):
                all_json.append(os.path.join(dp, f))
    if len(all_json) >= 3 and any("train" in os.path.basename(p).lower() for p in all_json):
        # --- COCO-JSON export branch ---
        # locate the train/valid/test JSONs by filename
        json_splits = {}
        for path in all_json:
            fn = os.path.basename(path).lower()
            if "train" in fn:
                json_splits["train"] = path
            elif "val" in fn:  # matches both "val" and "valid"
                json_splits["valid"] = path
            elif "test" in fn:
                json_splits["test"] = path
        if any(s not in json_splits for s in ("train", "valid", "test")):
            raise RuntimeError(f"Missing one of train/valid/test JSONs: {json_splits}")
        # build category id -> contiguous class index from train.json
        # (assumes all splits share the same category set)
        with open(json_splits["train"], "r") as fh:
            train_coco = json.load(fh)
        cat_ids = sorted(c["id"] for c in train_coco.get("categories", []))
        id2idx = {cid: i for i, cid in enumerate(cat_ids)}
        # aggregate image metadata and YOLO annotation lines across all splits
        images_info = {}
        annos = {}
        for split, jf in json_splits.items():
            with open(jf, "r") as fh:
                coco = json.load(fh)
            for img in coco["images"]:
                images_info[img["id"]] = img
            for a in coco["annotations"]:
                seg = a.get("segmentation")
                if not seg or not isinstance(seg, list):
                    continue  # skip empty or RLE-encoded segmentations
                # use the first polygon only; points are stored [x1, y1, x2, y2, ...]
                xs = seg[0][0::2]
                ys = seg[0][1::2]
                xmin, xmax = min(xs), max(xs)
                ymin, ymax = min(ys), max(ys)
                w, h = xmax - xmin, ymax - ymin
                cx, cy = xmin + w / 2, ymin + h / 2
                iw = images_info[a["image_id"]]["width"]
                ih = images_info[a["image_id"]]["height"]
                line = (
                    f"{id2idx[a['category_id']]} "
                    f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
                )
                annos.setdefault(a["image_id"], []).append(line)
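        # Worked example: a polygon with xs=[10, 50], ys=[20, 60] in a 100x100
        # image gives bbox (xmin=10, ymin=20, w=40, h=40), center (30, 40),
        # i.e. the YOLO line "<idx> 0.300000 0.400000 0.400000 0.400000".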
        # map each annotated file name to its path on disk
        name2id = {img["file_name"]: img["id"] for img in images_info.values()}
        filemap = {
            f: os.path.join(dp, f)
            for dp, _, files in os.walk(root)
            for f in files
            if f in name2id
        }
        # write out per-split images/ and labels/ folders
        out_root = tempfile.mkdtemp(prefix="yolov8_")
        for split in ("train", "valid", "test"):
            with open(json_splits[split], "r") as fh:
                coco = json.load(fh)
            img_dir = os.path.join(out_root, split, "images")
            lbl_dir = os.path.join(out_root, split, "labels")
            os.makedirs(img_dir, exist_ok=True)
            os.makedirs(lbl_dir, exist_ok=True)
            for img in coco["images"]:
                fn = img["file_name"]
                src = filemap.get(fn)
                if src is None:
                    continue  # listed in the JSON but missing on disk
                dst = os.path.join(img_dir, fn)
                txtp = os.path.join(lbl_dir, fn.rsplit(".", 1)[0] + ".txt")
                shutil.copy(src, dst)
                with open(txtp, "w") as f:
                    f.write("\n".join(annos.get(img["id"], [])))
    else:
        # --- PNG segmentation-mask export branch ---
        splits = ["train", "valid", "test"]
        # possible names for the masks subfolder
        mask_names = ("masks", "mask", "labels")
        out_root = tempfile.mkdtemp(prefix="yolov8_")
        for split in splits:
            img_dir_src = os.path.join(root, split, "images")
            # find which subdir holds the PNG masks
            mdir = None
            for m in mask_names:
                candidate = os.path.join(root, split, m)
                if os.path.isdir(candidate):
                    mdir = candidate
                    break
            if mdir is None:
                raise RuntimeError(f"No masks folder found under {split}/ (checked {mask_names})")
            img_dir_dst = os.path.join(out_root, split, "images")
            lbl_dir_dst = os.path.join(out_root, split, "labels")
            os.makedirs(img_dir_dst, exist_ok=True)
            os.makedirs(lbl_dir_dst, exist_ok=True)
            for fn in os.listdir(img_dir_src):
                if not fn.lower().endswith((".jpg", ".jpeg", ".png")):
                    continue
                src_img = os.path.join(img_dir_src, fn)
                # the mask may share the image's exact name; otherwise try a
                # .png with the same stem (export layouts vary)
                src_mask = os.path.join(mdir, fn)
                if not os.path.isfile(src_mask):
                    src_mask = os.path.join(mdir, fn.rsplit(".", 1)[0] + ".png")
                img = cv2.imread(src_img)
                mask = cv2.imread(src_mask, cv2.IMREAD_GRAYSCALE)
                if img is None or mask is None:
                    continue  # unreadable image or missing mask
                h, w = img.shape[:2]
                # binarize the mask, then take the bbox of all nonzero pixels
                _, binm = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
                ys, xs = np.nonzero(binm)
                if len(xs) == 0:
                    lines = []
                else:
                    xmin, xmax = xs.min(), xs.max()
                    ymin, ymax = ys.min(), ys.max()
                    bw, bh = xmax - xmin, ymax - ymin
                    cx, cy = xmin + bw / 2, ymin + bh / 2
                    # single class assumed (index 0); note that all mask blobs
                    # are merged into one enclosing box
                    lines = [f"0 {cx/w:.6f} {cy/h:.6f} {bw/w:.6f} {bh/h:.6f}"]
                # copy image + write YOLO label file
                dst_img = os.path.join(img_dir_dst, fn)
                dst_txt = os.path.join(lbl_dir_dst, fn.rsplit(".", 1)[0] + ".txt")
                shutil.copy(src_img, dst_img)
                with open(dst_txt, "w") as f:
                    f.write("\n".join(lines))
    # --- prepare before/after galleries (random sample across out_root) ---
    before, after = [], []
    all_imgs = []
    for split in ("train", "valid", "test"):
        for fn in os.listdir(os.path.join(out_root, split, "images")):
            all_imgs.append(os.path.join(out_root, split, "images", fn))
    sample = random.sample(all_imgs, min(5, len(all_imgs)))
    for img_path in sample:
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        # "before" shows the plain image; the original segmentation overlay
        # is not reconstructed here
        seg_vis = img.copy()
        box_vis = img.copy()
        # overlay every bbox from the matching label file
        txtp = img_path.replace("/images/", "/labels/").rsplit(".", 1)[0] + ".txt"
        h, w = img.shape[:2]
        with open(txtp) as fh:
            for line in fh:
                if not line.strip():
                    continue
                _, cxn, cyn, wnorm, hnorm = map(float, line.split())
                bw, bh = int(wnorm * w), int(hnorm * h)
                x0 = int(cxn * w - bw / 2)
                y0 = int(cyn * h - bh / 2)
                cv2.rectangle(box_vis, (x0, y0), (x0 + bw, y0 + bh), (0, 255, 0), 2)
        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))
    detection_slug = proj_name + "-detection"
    return before, after, out_root, detection_slug
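# Example (hypothetical values):
#   before, after, ds_path, slug = convert_seg_to_bbox(
#       "YOUR_API_KEY", "https://universe.roboflow.com/acme/roads-seg/2")
#   slug would be "roads-seg-detection"; ds_path is a temp YOLOv8 folder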

def upload_and_train_detection(
    api_key: str,
    detection_slug: str,
    dataset_path: str,
    project_license: str = "MIT",
    project_type: str = "object-detection",
):
    rf = Roboflow(api_key=api_key)
    ws = rf.workspace()
    # fetch the project, creating it if it does not exist yet
    try:
        proj = ws.project(detection_slug)
    except Exception as e:
        if "does not exist" in str(e).lower():
            proj = ws.create_project(
                detection_slug,
                annotation=project_type,
                project_type=project_type,
                project_license=project_license,
            )
        else:
            raise
    # upload the converted dataset and kick off training;
    # proj.id is "workspace/slug", so keep only the slug part
    _, real_slug = proj.id.rsplit("/", 1)
    ws.upload_dataset(dataset_path, real_slug,
                      project_license=project_license,
                      project_type=project_type)
    try:
        version_num = proj.generate_version(settings={"augmentation": {}, "preprocessing": {}})
    except RuntimeError as e:
        msg = str(e).lower()
        if "unsupported request" in msg or "does not exist" in msg:
            # fall back to a fresh project with a bumped slug
            new_slug = real_slug + "-v2"
            proj = ws.create_project(
                new_slug,
                annotation=project_type,
                project_type=project_type,
                project_license=project_license,
            )
            ws.upload_dataset(dataset_path, new_slug,
                              project_license=project_license,
                              project_type=project_type)
            version_num = proj.generate_version(settings={"augmentation": {}, "preprocessing": {}})
        else:
            raise
    # poll until version generation finishes, then train
    for _ in range(20):
        try:
            model = proj.version(str(version_num)).train()
            break
        except RuntimeError as e:
            if "still generating" in str(e).lower():
                time.sleep(5)
                continue
            raise
    else:
        raise RuntimeError("Version generation timed out, try again later.")
    return f"{model['base_url']}{model['id']}?api_key={api_key}"

# --- Gradio UI ---
with gr.Blocks() as app:
    gr.Markdown("## 🔄 Seg→BBox + Auto-Upload/Train")
    api_input = gr.Textbox(label="Roboflow API Key", type="password")
    url_input = gr.Textbox(label="Segmentation Dataset URL")
    run_btn = gr.Button("Convert to BBoxes")
    before_g = gr.Gallery(columns=5, label="Before")
    after_g = gr.Gallery(columns=5, label="After")
    ds_state = gr.Textbox(visible=False, label="Dataset Path")
    slug_state = gr.Textbox(visible=False, label="Detection Slug")
    run_btn.click(
        convert_seg_to_bbox,
        inputs=[api_input, url_input],
        outputs=[before_g, after_g, ds_state, slug_state],
    )
    gr.Markdown("## 🚀 Upload & Train Detection Model")
    train_btn = gr.Button("Upload & Train")
    url_out = gr.Textbox(label="Hosted Model URL")
    train_btn.click(
        upload_and_train_detection,
        inputs=[api_input, slug_state, ds_state],
        outputs=[url_out],
    )

if __name__ == "__main__":
    app.launch()