# segtodetect / app.py (committed by wuhp)
# Last commit: 8ece6c4 "Update app.py" (verified)
import os
import json
import random
import shutil
import tempfile
import time
from urllib.parse import urlparse
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow
def parse_roboflow_url(url: str):
    """Split a Roboflow dataset URL into ``(workspace, project, version)``.

    The workspace and project are the first two path segments. The version is the
    last path segment if it is an integer, otherwise the second-to-last one (so a
    trailing view name such as ``.../7/images`` is tolerated).

    Raises:
        ValueError: if the path has fewer than three segments or neither of the last
            two segments is an integer version.
    """
    parsed = urlparse(url.strip())
    # Drop empty segments produced by leading/trailing or doubled slashes.
    parts = [segment for segment in parsed.path.split("/") if segment]
    if len(parts) < 3:
        raise ValueError(
            f"Cannot parse Roboflow URL {url!r}: expected "
            "/<workspace>/<project>/.../<version>"
        )
    workspace, project = parts[0], parts[1]
    for candidate in (parts[-1], parts[-2]):
        try:
            return workspace, project, int(candidate)
        except ValueError:
            continue
    raise ValueError(f"Cannot find a numeric dataset version in Roboflow URL {url!r}")
# Split folder names, in output order.
_SPLITS = ("train", "valid", "test")
# Folder names that identify a split when they directly contain a COCO JSON file.
_SPLIT_DIR_ALIASES = {"train": "train", "valid": "valid", "val": "valid", "test": "test"}
# Image extensions copied from a mask export.
_IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
# Sub-folder names that may hold the PNG masks of a mask export.
_MASK_DIR_NAMES = ("masks", "mask", "labels")


def _load_json(path):
    """Read and return a JSON file, closing the handle afterwards."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _label_name(image_name):
    """Return the YOLO label file name paired with *image_name*."""
    return os.path.splitext(os.path.basename(image_name))[0] + ".txt"


def _make_split_dirs(out_root, split):
    """Create and return ``(images_dir, labels_dir)`` for *split* under *out_root*."""
    img_dir = os.path.join(out_root, split, "images")
    lbl_dir = os.path.join(out_root, split, "labels")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(lbl_dir, exist_ok=True)
    return img_dir, lbl_dir


def _yolo_line(cls_idx, xmin, ymin, xmax, ymax, width, height):
    """Format a pixel-space box as a normalized YOLO line, or return None if it is empty.

    The box is clipped to the image first so normalized values stay within [0, 1].
    """
    xmin, xmax = max(0.0, float(xmin)), min(float(width), float(xmax))
    ymin, ymax = max(0.0, float(ymin)), min(float(height), float(ymax))
    bw, bh = xmax - xmin, ymax - ymin
    if bw <= 0 or bh <= 0:
        return None
    cx, cy = xmin + bw / 2, ymin + bh / 2
    return f"{cls_idx} {cx / width:.6f} {cy / height:.6f} {bw / width:.6f} {bh / height:.6f}"


def _split_of(json_path):
    """Return the split ('train'/'valid'/'test') a COCO JSON belongs to, or None.

    The parent folder name is checked first, because COCO exports typically put one
    ``_annotations.coco.json`` in each split folder, and that file name contains no
    split word. The file name itself is checked as a fallback.
    """
    parent = os.path.basename(os.path.dirname(json_path)).lower()
    if parent in _SPLIT_DIR_ALIASES:
        return _SPLIT_DIR_ALIASES[parent]
    name = os.path.basename(json_path).lower()
    if "train" in name:
        return "train"
    if "val" in name:
        return "valid"
    if "test" in name:
        return "test"
    return None


def _find_coco_splits(root):
    """Map each split to the first COCO JSON found for it under *root* (sorted walk)."""
    json_splits = {}
    for dirpath, _, files in sorted(os.walk(root)):
        for name in sorted(files):
            if name.lower().endswith(".json"):
                path = os.path.join(dirpath, name)
                split = _split_of(path)
                if split is not None:
                    json_splits.setdefault(split, path)
    return json_splits


def _coco_polygons(ann):
    """Return an annotation's polygon parts as float (N, 2) arrays.

    Returns [] for RLE segmentations (a dict) or missing/empty segmentations.
    """
    seg = ann.get("segmentation")
    if not isinstance(seg, list):
        return []
    polys = []
    for part in seg:
        if isinstance(part, (list, tuple)) and len(part) >= 2:
            even = part[: len(part) // 2 * 2]  # flat [x0, y0, x1, y1, ...]
            polys.append(np.asarray(even, dtype=np.float64).reshape(-1, 2))
    return polys


def _convert_coco_export(root, json_splits, out_root):
    """Write YOLO images/labels for every split in *json_splits*.

    Returns {copied image path: [int32 polygon arrays]} for the preview overlays.
    """
    cocos = {split: _load_json(path) for split, path in json_splits.items()}
    # Class index = position of the category id in ascending order. The union over
    # all splits is used so an id that only appears in valid/test cannot KeyError.
    cat_ids = sorted({c["id"] for coco in cocos.values() for c in coco.get("categories", [])})
    id2idx = {cid: i for i, cid in enumerate(cat_ids)}

    fallback = None  # basename -> path anywhere under root; built only if needed
    overlays = {}
    for split in _SPLITS:
        if split not in cocos:
            continue  # a split may be absent from the export
        coco = cocos[split]
        src_dir = os.path.dirname(json_splits[split])
        img_dir, lbl_dir = _make_split_dirs(out_root, split)

        # Image ids are only unique within one JSON file, so group per split.
        anns_by_image = {}
        for ann in coco.get("annotations", []):
            anns_by_image.setdefault(ann["image_id"], []).append(ann)

        for img in coco.get("images", []):
            file_name = img["file_name"]
            src = os.path.join(src_dir, file_name)
            if not os.path.isfile(src):
                if fallback is None:
                    fallback = {
                        f: os.path.join(dp, f) for dp, _, files in os.walk(root) for f in files
                    }
                src = fallback.get(os.path.basename(file_name))
                if src is None:
                    raise FileNotFoundError(
                        f"Image {file_name!r} listed in {json_splits[split]} not found under {root}"
                    )
            dst = os.path.join(img_dir, os.path.basename(file_name))
            shutil.copy(src, dst)

            iw, ih = img["width"], img["height"]
            lines, outlines = [], []
            for ann in anns_by_image.get(img["id"], []):
                polys = _coco_polygons(ann)
                if polys:
                    # Use every polygon part, not just the first, for the extent.
                    pts = np.concatenate(polys)
                    xmin, ymin = pts.min(axis=0)
                    xmax, ymax = pts.max(axis=0)
                elif len(ann.get("bbox") or []) == 4:
                    # RLE / polygon-less annotation: fall back to the COCO bbox [x, y, w, h].
                    x, y, bw, bh = ann["bbox"]
                    xmin, ymin, xmax, ymax = x, y, x + bw, y + bh
                else:
                    continue
                line = _yolo_line(id2idx[ann["category_id"]], xmin, ymin, xmax, ymax, iw, ih)
                if line is not None:
                    lines.append(line)
                outlines.extend(np.round(p).astype(np.int32) for p in polys if len(p) >= 2)

            with open(os.path.join(lbl_dir, _label_name(file_name)), "w") as fh:
                fh.write("\n".join(lines))
            overlays[dst] = outlines
    return overlays


def _find_mask(mask_dir, image_name):
    """Return the mask for *image_name* (same name, ``.png`` or ``_mask.png``), or None."""
    stem = os.path.splitext(image_name)[0]
    for candidate in (image_name, stem + ".png", stem + "_mask.png"):
        path = os.path.join(mask_dir, candidate)
        if os.path.isfile(path):
            return path
    return None


def _convert_mask_export(root, out_root):
    """Convert a PNG-mask export; each connected foreground blob becomes a class-0 box.

    Any non-zero mask pixel counts as foreground, so both 0/255 masks and
    class-index masks (values 1..N) are handled.
    Returns {copied image path: [contour arrays]} for the preview overlays.
    """
    overlays = {}
    found_split = False
    for split in _SPLITS:
        split_dir = os.path.join(root, split)
        img_dir_src = os.path.join(split_dir, "images")
        if not os.path.isdir(img_dir_src):
            continue  # a split may be absent from the export
        found_split = True
        mask_dir = next(
            (
                os.path.join(split_dir, m)
                for m in _MASK_DIR_NAMES
                if os.path.isdir(os.path.join(split_dir, m))
            ),
            None,
        )
        if mask_dir is None:
            raise RuntimeError(f"No masks folder found under {split}/ (checked {_MASK_DIR_NAMES})")
        img_dir_dst, lbl_dir_dst = _make_split_dirs(out_root, split)

        for fn in sorted(os.listdir(img_dir_src)):
            if not fn.lower().endswith(_IMAGE_EXTS):
                continue
            src_img = os.path.join(img_dir_src, fn)
            img = cv2.imread(src_img)
            if img is None:
                raise RuntimeError(f"Could not read image {src_img}")
            h, w = img.shape[:2]

            mask_path = _find_mask(mask_dir, fn)
            if mask_path is None:
                raise FileNotFoundError(f"No mask for {fn} in {mask_dir}")
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise RuntimeError(f"Could not read mask {mask_path}")
            if mask.shape[:2] != (h, w):
                # Keep label values intact while matching the image resolution.
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            binm = (mask > 0).astype(np.uint8)

            lines = []
            n_labels, _, stats, _ = cv2.connectedComponentsWithStats(binm, connectivity=8)
            for i in range(1, n_labels):  # label 0 is the background
                x, y, bw, bh = (int(v) for v in stats[i, :4])
                line = _yolo_line(0, x, y, x + bw, y + bh, w, h)  # single class -> index 0
                if line is not None:
                    lines.append(line)
            # [-2] picks the contour list under both OpenCV 3 and 4 return signatures.
            contours = list(cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2])

            dst_img = os.path.join(img_dir_dst, fn)
            shutil.copy(src_img, dst_img)
            with open(os.path.join(lbl_dir_dst, _label_name(fn)), "w") as fh:
                fh.write("\n".join(lines))
            overlays[dst_img] = contours

    if not found_split:
        raise RuntimeError(f"No train/valid/test images folders found under {root}")
    return overlays


def _render_previews(out_root, overlays, count=5):
    """Draw source outlines ("before") and YOLO boxes ("after") on a random image sample."""
    all_imgs = []
    for split in _SPLITS:
        img_dir = os.path.join(out_root, split, "images")
        if os.path.isdir(img_dir):
            all_imgs.extend(os.path.join(img_dir, fn) for fn in sorted(os.listdir(img_dir)))

    before, after = [], []
    for img_path in random.sample(all_imgs, min(count, len(all_imgs))):
        bgr = cv2.imread(img_path)
        if bgr is None:
            continue  # unreadable file: skip it rather than break the whole preview
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        seg_vis, box_vis = img.copy(), img.copy()

        outlines = [
            p.reshape(-1, 1, 2).astype(np.int32)
            for p in overlays.get(img_path, [])
            if len(p) >= 2
        ]
        if outlines:
            cv2.polylines(seg_vis, outlines, True, (255, 0, 0), 2)

        label_path = os.path.join(
            os.path.dirname(os.path.dirname(img_path)), "labels", _label_name(img_path)
        )
        if os.path.isfile(label_path):
            with open(label_path, "r") as fh:
                for line in fh:
                    fields = line.split()
                    if len(fields) != 5:
                        continue
                    _, cxn, cyn, wn, hn = map(float, fields)
                    bw, bh = int(wn * w), int(hn * h)
                    x0, y0 = int(cxn * w - bw / 2), int(cyn * h - bh / 2)
                    cv2.rectangle(box_vis, (x0, y0), (x0 + bw, y0 + bh), (0, 255, 0), 2)

        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))
    return before, after


def convert_seg_to_bbox(api_key: str, dataset_url: str):
    """Convert a Roboflow segmentation dataset into a YOLO bounding-box dataset.

    1) Download the dataset version in ``coco-segmentation`` format.
    2) If per-split COCO JSON files (including train) are found, convert every
       annotation: the extent of all polygon parts, or the COCO ``bbox`` for RLE
       annotations. Otherwise treat the download as a PNG-mask export.
    3) Write ``<split>/images`` and ``<split>/labels`` under a fresh temp dir for
       each of train/valid/test that exists.
    4) Return ``(before, after, dataset_path, detection_slug)``. *before* and
       *after* are up to five PIL previews with source outlines and resulting boxes.

    Raises:
        ValueError: if the API key or URL is empty, or the URL cannot be parsed.
        RuntimeError, FileNotFoundError: if the download layout is not recognised
            or a referenced image/mask is missing.
    """
    if not api_key or not api_key.strip():
        raise ValueError("A Roboflow API key is required.")
    if not dataset_url or not dataset_url.strip():
        raise ValueError("A segmentation dataset URL is required.")

    # Parse the URL first so a malformed link is reported before any Roboflow client work.
    ws, proj_name, ver = parse_roboflow_url(dataset_url.strip())
    rf = Roboflow(api_key=api_key.strip())
    dataset = rf.workspace(ws).project(proj_name).version(ver).download("coco-segmentation")
    root = dataset.location

    out_root = tempfile.mkdtemp(prefix="yolov8_")
    try:
        json_splits = _find_coco_splits(root)
        if "train" in json_splits:
            overlays = _convert_coco_export(root, json_splits, out_root)
        else:
            overlays = _convert_mask_export(root, out_root)
    except Exception:
        # Don't leave a half-written dataset behind in the temp dir.
        shutil.rmtree(out_root, ignore_errors=True)
        raise

    before, after = _render_previews(out_root, overlays)
    detection_slug = proj_name + "-detection"
    return before, after, out_root, detection_slug
def _model_field(model, name):
    """Read *name* from the result of ``Version.train()``, whether it is a dict or an object.

    The original code indexed the result as a dict (``model['base_url']``).
    NOTE(review): some roboflow releases may return a model object instead, so
    attribute access is also supported; confirm against the installed SDK.
    """
    if isinstance(model, dict):
        return model[name]
    return getattr(model, name)


def upload_and_train_detection(
    api_key: str,
    detection_slug: str,
    dataset_path: str,
    project_license: str = "MIT",
    project_type: str = "object-detection",
):
    """Upload a converted YOLO dataset to Roboflow, generate a version and start training.

    The project *detection_slug* is reused if it exists and created otherwise. If
    version generation fails with "unsupported request" / "does not exist", the
    upload is retried in a new ``<slug>-v2`` project. ``train()`` is retried every
    5 s (up to 20 times) while the version is "still generating".

    Returns:
        The hosted model URL. It embeds *api_key* as a query parameter, so treat
        it as a secret.

    Raises:
        ValueError: if the API key, slug or dataset folder is missing (conversion
            has not been run yet).
        RuntimeError: if the version is still generating after all retries; other
            Roboflow errors propagate unchanged.
    """
    if not api_key or not api_key.strip():
        raise ValueError("A Roboflow API key is required.")
    if not detection_slug or not dataset_path or not os.path.isdir(dataset_path):
        raise ValueError("No converted dataset found - run 'Convert to BBoxes' first.")
    api_key = api_key.strip()

    rf = Roboflow(api_key=api_key)
    ws = rf.workspace()

    def create(slug):
        return ws.create_project(
            slug,
            annotation=project_type,
            project_type=project_type,
            project_license=project_license,
        )

    def upload(slug):
        ws.upload_dataset(
            dataset_path,
            slug,
            project_license=project_license,
            project_type=project_type,
        )

    def generate(project):
        # Fresh dict per call: no augmentation/preprocessing on the detection version.
        return project.generate_version(settings={"augmentation": {}, "preprocessing": {}})

    # Get or create the project; the client signals a missing project via its message.
    try:
        proj = ws.project(detection_slug)
    except Exception as e:
        if "does not exist" not in str(e).lower():
            raise
        proj = create(detection_slug)

    real_slug = proj.id.rsplit("/", 1)[-1]
    upload(real_slug)
    try:
        version_num = generate(proj)
    except RuntimeError as e:
        msg = str(e).lower()
        if "unsupported request" not in msg and "does not exist" not in msg:
            raise
        # Slug bump fallback: retry everything in a fresh project.
        new_slug = real_slug + "-v2"
        proj = create(new_slug)
        upload(new_slug)
        version_num = generate(proj)

    # Training can only start once the version has finished generating.
    for _ in range(20):
        try:
            model = proj.version(str(version_num)).train()
            break
        except RuntimeError as e:
            if "still generating" not in str(e).lower():
                raise
            time.sleep(5)
    else:
        raise RuntimeError("Version generation timed out, try again later.")

    return f"{_model_field(model, 'base_url')}{_model_field(model, 'id')}?api_key={api_key}"
# --- Gradio UI ---
with gr.Blocks() as app:
    gr.Markdown("## 🔄 Seg→BBox + Auto‑Upload/Train")
    api_input = gr.Textbox(label="Roboflow API Key", type="password")
    url_input = gr.Textbox(label="Segmentation Dataset URL")
    run_btn = gr.Button("Convert to BBoxes")
    before_g = gr.Gallery(columns=5, label="Before")
    after_g = gr.Gallery(columns=5, label="After")
    # Per-session, server-side state. A hidden Textbox is still an ordinary input of
    # the app's API, so a client could have sent any server folder as the dataset to
    # upload. gr.State values live on the server and are not taken from the request.
    ds_state = gr.State()  # dataset path returned by convert_seg_to_bbox
    slug_state = gr.State()  # detection project slug returned by convert_seg_to_bbox
    run_btn.click(
        convert_seg_to_bbox,
        inputs=[api_input, url_input],
        outputs=[before_g, after_g, ds_state, slug_state],
    )

    gr.Markdown("## 🚀 Upload & Train Detection Model")
    train_btn = gr.Button("Upload & Train")
    url_out = gr.Textbox(label="Hosted Model URL")
    train_btn.click(
        upload_and_train_detection,
        inputs=[api_input, slug_state, ds_state],
        outputs=[url_out],
    )

if __name__ == "__main__":
    app.launch()