# segtodetect / app.py
# Author: wuhp | last commit: eecca4c "Update app.py" (verified) | size: 7.18 kB
import os
import json
import random
import shutil
import tempfile
from urllib.parse import urlparse
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow
def parse_roboflow_url(url: str):
    """Extract workspace, project name, and version from a Roboflow URL.

    Accepts URLs such as ``https://app.roboflow.com/<ws>/<project>/<version>``
    or ``https://universe.roboflow.com/<ws>/<project>/dataset/<version>``.
    The version is taken from the last path segment, or from the one before
    it when the URL ends in a non-numeric segment (e.g. ``.../3/images``).
    Surrounding whitespace, repeated slashes and a missing ``https://``
    scheme are tolerated.

    Returns:
        ``(workspace, project, version)`` with *version* as an ``int``.

    Raises:
        ValueError: if the URL has fewer than three path segments, or neither
            of its last two segments is a number.
    """
    url = url.strip()
    parsed = urlparse(url)
    # Drop empty segments so "//" or a trailing "/" cannot shift positions.
    parts = [p for p in parsed.path.split("/") if p]
    # Without a scheme, urlparse leaves the host in the path
    # ("app.roboflow.com/ws/proj/3"). Workspace slugs contain no dots, so a
    # dotted first segment is the host and is skipped.
    if not parsed.netloc and parts and "." in parts[0]:
        parts = parts[1:]
    if len(parts) < 3:
        raise ValueError(
            "Expected a URL like https://app.roboflow.com/<workspace>/<project>/<version>, "
            f"got {url!r}"
        )
    workspace, project = parts[0], parts[1]
    for candidate in (parts[-1], parts[-2]):
        try:
            return workspace, project, int(candidate)
        except ValueError:
            continue
    raise ValueError(f"No numeric dataset version found in {url!r}")
# File extensions treated as images when locating the image folder.
_IMAGE_EXTS = (".jpg", ".png", ".jpeg")


def _find_annotation_file(root: str) -> str:
    """Return the COCO annotation JSON under *root*, preferring the train split.

    The split is matched against the path relative to *root* (folder names
    included), not just the file name. Roboflow exports typically keep a
    generic ``_annotations.coco.json`` inside a ``train/`` folder, which a
    file-name-only check would miss. When no path mentions "train", the
    first JSON found is returned.

    Raises:
        FileNotFoundError: if there is no ``.json`` file under *root*.
    """
    fallback = None
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # deterministic walk order, hence a deterministic fallback
        for name in sorted(filenames):
            if not name.lower().endswith(".json"):
                continue
            path = os.path.join(dirpath, name)
            if "train" in os.path.relpath(path, root).lower():
                return path
            if fallback is None:
                fallback = path
    if fallback is None:
        raise FileNotFoundError(f"No JSON annotations found under {root}")
    return fallback


def _find_image_dir(root: str, ann_file: str) -> str:
    """Return the folder holding the images described by *ann_file*.

    The annotation file's own folder is preferred, so images and labels come
    from the same split. Otherwise the first folder under *root* containing
    a .jpg/.jpeg/.png file is used.

    Raises:
        FileNotFoundError: if no image file exists under *root*.
    """
    def has_images(names):
        return any(n.lower().endswith(_IMAGE_EXTS) for n in names)

    ann_dir = os.path.dirname(ann_file)
    if has_images(os.listdir(ann_dir)):
        return ann_dir
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()
        if has_images(filenames):
            return dirpath
    raise FileNotFoundError(f"No images found under {root}")


def _segmentation_polygons(ann: dict) -> list:
    """Return a COCO annotation's polygons as flat ``[x1, y1, x2, y2, ...]`` lists.

    Every polygon part is returned, so multi-part masks are not truncated to
    their first part. RLE masks (a dict), empty segmentations and parts with
    no complete point yield nothing. A trailing unpaired coordinate is
    dropped so each list holds whole (x, y) pairs.
    """
    seg = ann.get("segmentation")
    if not isinstance(seg, list) or not seg:
        return []
    if isinstance(seg[0], (int, float)):
        seg = [seg]  # a single flat polygon instead of a list of polygons
    polys = []
    for poly in seg:
        if not isinstance(poly, list):
            continue
        poly = poly[: len(poly) - len(poly) % 2]
        if poly:
            polys.append(poly)
    return polys


def _annotation_bbox(ann: dict, iw: float, ih: float):
    """Return the pixel box ``(xmin, ymin, xmax, ymax)`` of a COCO annotation.

    The box encloses all polygon parts of the mask. For RLE or polygon-less
    annotations, the COCO ``bbox`` field (``[x, y, width, height]``) is used
    instead. Coordinates are clipped to the ``iw`` x ``ih`` image so the
    normalised YOLO values stay within [0, 1].

    Returns ``None`` when no box can be derived or the clipped box has zero
    area.
    """
    polys = _segmentation_polygons(ann)
    if polys:
        xs = [x for poly in polys for x in poly[0::2]]
        ys = [y for poly in polys for y in poly[1::2]]
        xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
    else:
        bbox = ann.get("bbox")
        if not bbox or len(bbox) != 4:
            return None
        x, y, w, h = bbox
        xmin, ymin, xmax, ymax = x, y, x + w, y + h
    xmin, xmax = min(max(xmin, 0), iw), min(max(xmax, 0), iw)
    ymin, ymax = min(max(ymin, 0), ih), min(max(ymax, 0), ih)
    if xmax <= xmin or ymax <= ymin:
        return None
    return xmin, ymin, xmax, ymax


def convert_seg_to_bbox(api_key: str, dataset_url: str):
    """
    Download a segmentation dataset from Roboflow,
    convert masks → YOLOv8 bboxes,
    and return (before, after) galleries + local YOLO dataset path + auto slug.

    The dataset is downloaded in "coco-segmentation" format, and one COCO
    annotation file is converted (the train split is preferred). A fresh
    temp directory receives:

    - ``images/``: the copied images.
    - ``labels/``: one ``<stem>.txt`` per copied image, with lines
      ``<class> <cx> <cy> <w> <h>`` normalised to [0, 1].
    - ``data.yaml``: the class names.

    Class indices follow the sorted COCO category ids.

    Returns:
        (before, after, out_root, detection_slug):

        - ``before``: up to 5 PIL previews with the mask outlines.
        - ``after``: the same images with the derived boxes.
        - ``out_root``: the YOLO dataset folder.
        - ``detection_slug``: ``"<segmentation-project>-detection"``.

    Raises:
        FileNotFoundError: if the download has no JSON annotations or no
            images. Errors from ``parse_roboflow_url`` and the Roboflow
            client propagate unchanged.
    """
    rf = Roboflow(api_key=api_key)
    ws_name, seg_proj_slug, ver = parse_roboflow_url(dataset_url)
    version_obj = rf.workspace(ws_name).project(seg_proj_slug).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location

    # Load the COCO annotations; `with` closes the file, and JSON is UTF-8.
    ann_file = _find_annotation_file(root)
    with open(ann_file, "r", encoding="utf-8") as fh:
        coco = json.load(fh)
    images_info = {img["id"]: img for img in coco["images"]}
    categories = sorted(coco.get("categories", []), key=lambda c: c["id"])
    id_to_index = {c["id"]: idx for idx, c in enumerate(categories)}

    # Group annotations by image once; conversion and previews both reuse it.
    anns_by_image = {}
    for ann in coco.get("annotations", []):
        anns_by_image.setdefault(ann["image_id"], []).append(ann)

    # prepare YOLOv8 folders
    out_root = tempfile.mkdtemp(prefix="yolov8_")
    img_out = os.path.join(out_root, "images")
    lbl_out = os.path.join(out_root, "labels")
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(lbl_out, exist_ok=True)

    # Convert seg → bbox: YOLO label lines, plus pixel boxes for the preview.
    annos, boxes = {}, {}
    for pid, anns in anns_by_image.items():
        info = images_info.get(pid)
        if info is None:
            continue  # annotation refers to an image absent from "images"
        iw, ih = info["width"], info["height"]
        for ann in anns:
            box = _annotation_bbox(ann, iw, ih)
            if box is None:
                continue  # no usable geometry (empty mask or zero-area box)
            xmin, ymin, xmax, ymax = box
            w, h = xmax - xmin, ymax - ymin
            cx, cy = xmin + w / 2, ymin + h / 2
            cls = id_to_index[ann["category_id"]]
            annos.setdefault(pid, []).append(
                f"{cls} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
            )
            boxes.setdefault(pid, []).append(box)

    # Copy images and write one label file per image present on disk.
    img_dir = _find_image_dir(root, ann_file)
    copied = {}  # file_name -> image id, only for images actually copied
    for img in coco["images"]:
        fname, pid = img["file_name"], img["id"]
        src = os.path.join(img_dir, fname)
        if not os.path.isfile(src):
            continue
        shutil.copy(src, os.path.join(img_out, fname))
        label_path = os.path.join(lbl_out, os.path.splitext(fname)[0] + ".txt")
        with open(label_path, "w", encoding="utf-8") as lf:
            lf.write("\n".join(annos.get(pid, [])))
        copied[fname] = pid

    # data.yaml names the classes behind the numeric label ids (YOLOv8 layout).
    # A JSON list is also a valid YAML flow sequence, so no YAML library is needed.
    names = [str(c.get("name", c["id"])) for c in categories]
    with open(os.path.join(out_root, "data.yaml"), "w", encoding="utf-8") as yf:
        yf.write("train: images\nval: images\n")
        yf.write(f"nc: {len(names)}\n")
        yf.write(f"names: {json.dumps(names)}\n")

    # Build preview galleries only from images that were actually copied.
    before, after = [], []
    for fn in random.sample(list(copied), min(5, len(copied))):
        pid = copied[fn]
        bgr = cv2.imread(os.path.join(img_out, fn))
        if bgr is None:
            continue  # unreadable image: skip its preview instead of crashing
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        seg_vis = img.copy()
        for ann in anns_by_image.get(pid, []):
            for poly in _segmentation_polygons(ann):
                pts = np.array(poly, np.int32).reshape(-1, 2)
                cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
        box_vis = img.copy()
        for xmin, ymin, xmax, ymax in boxes.get(pid, []):
            cv2.rectangle(
                box_vis, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 0), 2
            )
        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))

    # auto-slug for the detection project
    detection_slug = f"{seg_proj_slug}-detection"
    return before, after, out_root, detection_slug
def upload_and_train_detection(api_key: str, project_slug: str, dataset_path: str):
    """
    Given a YOLOv8 dataset folder, upload → version → train →
    return inference endpoint URL. Auto-creates the project if needed.

    The project is looked up in the API key's default workspace. If that
    lookup fails, it is created as an MIT-licensed object-detection project.
    A version is then generated with no preprocessing or augmentation, and
    trained with ``speed="fast"``.

    Args:
        api_key: Roboflow API key.
        project_slug: slug of the detection project, as produced by
            ``convert_seg_to_bbox``.
        dataset_path: folder holding the converted YOLOv8 dataset.

    Returns:
        The hosted inference URL
        ``https://detect.roboflow.com/<project>/<version>?api_key=<api_key>``.

    Raises:
        gr.Error: if no converted dataset is available (Convert not run yet).
    """
    # The Gradio State inputs stay None until the Convert step has run.
    if not project_slug or not dataset_path or not os.path.isdir(dataset_path):
        raise gr.Error("No converted dataset found - run 'Convert to YOLOv8 BBoxes' first.")

    rf = Roboflow(api_key=api_key)
    ws = rf.workspace()

    # get or create the detection project
    try:
        proj = ws.project(project_slug)
    except Exception:
        # Treat a failed lookup as "project does not exist yet". Presumably
        # real failures (bad key, network) resurface from create_project.
        proj = ws.create_project(
            project_name=project_slug,
            project_type="object-detection",
            project_license="MIT",
        )

    # upload the dataset
    ws.upload_dataset(
        dataset_path,
        proj.id,
        num_workers=10,
        project_license="MIT",
        project_type="object-detection",
        batch_name=None,
        num_retries=0,
    )

    # generate a new version and train it (fast)
    new_v = proj.generate_version(settings={"preprocessing": {}, "augmentation": {}})
    version = proj.version(new_v)
    version.train(speed="fast")

    # Build the hosted endpoint from known pieces. version.model is an SDK
    # model object, not a dict, so it cannot be subscripted for URL parts.
    # proj.id may be "<workspace>/<project>"; the hosted API path uses only
    # the project slug.
    model_slug = str(proj.id).rsplit("/", 1)[-1]
    return f"https://detect.roboflow.com/{model_slug}/{new_v}?api_key={api_key}"
# Two-step UI:
# 1. Convert a Roboflow segmentation dataset into YOLOv8 boxes (with previews).
# 2. Upload that dataset to a detection project, train it, and show the endpoint.
with gr.Blocks() as app:
    gr.Markdown("## 🔄 Segmentation → YOLOv8 + 📑 Auto‑Deploy Detector")

    # ─ Convert UI ─────────────────────────────────────────
    api = gr.Textbox(label="Roboflow API Key", type="password")
    segurl = gr.Textbox(label="Segmentation Dataset URL")
    btn_c = gr.Button("Convert to YOLOv8 BBoxes")
    out_b = gr.Gallery(label="Before (masks)")
    out_a = gr.Gallery(label="After (bboxes)")
    # Per-session handoff from the Convert step to the Train step.
    state_path = gr.State()  # local YOLOv8 dataset folder
    state_slug = gr.State()  # auto-generated detection project slug
    btn_c.click(
        fn=convert_seg_to_bbox,
        inputs=[api, segurl],
        outputs=[out_b, out_a, state_path, state_slug],
    )

    gr.Markdown("---")

    # ─ Train UI ───────────────────────────────────────────
    btn_t = gr.Button("Upload & Train Detection Model")
    endpoint = gr.Textbox(label="Hosted Detection Endpoint URL")
    btn_t.click(
        fn=upload_and_train_detection,
        inputs=[api, state_slug, state_path],
        outputs=[endpoint],
    )

    gr.Markdown(
        "> 1) Paste your segmentation URL and Convert. \n"
        "> 2) Then Upload & Train to instantly get your detector’s endpoint."
    )

if __name__ == "__main__":
    app.launch()