segtodetect / app.py
wuhp's picture
Update app.py
a711e94 verified
raw
history blame
6.48 kB
import os
import json
import random
import shutil
import tempfile
from urllib.parse import urlparse
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow
def parse_roboflow_url(url: str):
    """Extract workspace, project name, and version from a Roboflow URL.

    The first two path segments are taken as the workspace and project slugs.
    The version is the last all-digit path segment after them. This accepts
    both ``https://universe.roboflow.com/<ws>/<project>/dataset/<n>`` and
    ``https://app.roboflow.com/<ws>/<project>/<n>/images`` style URLs.

    Args:
        url: Roboflow dataset URL; surrounding whitespace is ignored.

    Returns:
        ``(workspace, project, version)`` with ``version`` as an ``int``.

    Raises:
        ValueError: if the URL has fewer than two path segments, or no numeric
            version segment follows the project slug.
    """
    # Drop empty segments so leading/trailing/double slashes don't shift indices.
    parts = [segment for segment in urlparse(url.strip()).path.split('/') if segment]
    if len(parts) < 2:
        raise ValueError(
            "Expected a URL like https://app.roboflow.com/<workspace>/<project>/<version>, "
            f"got {url!r}"
        )
    workspace, project = parts[0], parts[1]
    # Only look after the project slug so a numeric workspace/project name is
    # never mistaken for the version.
    for segment in reversed(parts[2:]):
        if segment.isdecimal():
            return workspace, project, int(segment)
    raise ValueError(f"No dataset version number found in {url!r}")
# Lower-case image file extensions recognised when locating the image folder.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _walk_sorted(root):
    """Yield ``(dirpath, filenames)`` under *root* in a deterministic, sorted order."""
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # os.walk honours in-place edits, so traversal order becomes stable
        yield dirpath, sorted(filenames)


def _find_annotation_file(root):
    """Return the COCO JSON to convert, preferring one whose path mentions ``train``.

    The match is done on the path relative to *root*, so it covers both the
    file name and any split directory it lives in.
    Falls back to the first JSON found when none mentions ``train``.

    Raises:
        FileNotFoundError: if there is no JSON file under *root*.
    """
    json_files = [
        os.path.join(dirpath, name)
        for dirpath, names in _walk_sorted(root)
        for name in names
        if name.lower().endswith('.json')
    ]
    if not json_files:
        raise FileNotFoundError(f"No JSON annotations found under {root}")
    for path in json_files:
        if 'train' in os.path.relpath(path, root).lower():
            return path
    return json_files[0]


def _find_image_dir(root, preferred):
    """Return *preferred* if it contains images, else the first directory under *root* that does.

    Preferring the annotation file's own directory keeps the JSON and the image
    files from the same split.

    Raises:
        FileNotFoundError: if no image files exist under *root*.
    """
    if os.path.isdir(preferred) and any(
        name.lower().endswith(_IMAGE_EXTENSIONS) for name in os.listdir(preferred)
    ):
        return preferred
    for dirpath, names in _walk_sorted(root):
        if any(name.lower().endswith(_IMAGE_EXTENSIONS) for name in names):
            return dirpath
    raise FileNotFoundError(f"No image files found under {root}")


def _polygons(ann):
    """Return *ann*'s polygons as flat ``[x0, y0, x1, y1, ...]`` lists.

    RLE-encoded (non-list) or missing/empty segmentations yield ``[]``.
    """
    seg = ann.get('segmentation')
    if not isinstance(seg, list) or not seg:
        return []
    if isinstance(seg[0], (list, tuple)):
        return [list(poly) for poly in seg if len(poly) >= 2]
    return [seg]  # a single flat polygon instead of a list of polygons


def _pixel_box(ann, iw, ih):
    """Return ``(xmin, ymin, xmax, ymax)`` in pixels for *ann*, clipped to the image.

    Uses the extent of ALL polygons of the annotation. If there is no usable
    polygon, it falls back to the COCO ``bbox`` field (``[x, y, w, h]``).
    Returns None when neither is available or the clipped box is empty.
    """
    polys = _polygons(ann)
    xs = [x for poly in polys for x in poly[0::2]]
    ys = [y for poly in polys for y in poly[1::2]]
    if xs and ys:
        xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
    elif len(ann.get('bbox') or []) == 4:
        x, y, w, h = ann['bbox']
        xmin, ymin, xmax, ymax = x, y, x + w, y + h
    else:
        return None
    # Clip so the normalised YOLO coordinates stay within [0, 1].
    xmin, xmax = max(0, xmin), min(iw, xmax)
    ymin, ymax = max(0, ymin), min(ih, ymax)
    if xmax <= xmin or ymax <= ymin:
        return None
    return xmin, ymin, xmax, ymax


def _render_previews(img_dir, fnames, fname2id, polys_by_img, boxes_by_img, k=5):
    """Draw polygons ("before") and boxes ("after") on up to *k* random images from *fnames*."""
    before, after = [], []
    for fn in random.sample(fnames, min(k, len(fnames))):
        bgr = cv2.imread(os.path.join(img_dir, fn))
        if bgr is None:  # unreadable image: skip it rather than crash the whole preview
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        pid = fname2id[fn]
        seg_vis = img.copy()
        for poly in polys_by_img.get(pid, []):
            n_coords = len(poly) // 2 * 2  # ignore a dangling x without its y
            if n_coords < 2:
                continue
            pts = np.array(poly[:n_coords], np.int32).reshape(-1, 2)
            cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
        box_vis = img.copy()
        for x0, y0, x1, y1 in boxes_by_img.get(pid, []):
            cv2.rectangle(box_vis, (int(x0), int(y0)), (int(x1), int(y1)), (0, 255, 0), 2)
        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))
    return before, after


def convert_seg_to_bbox(api_key: str, dataset_url: str):
    """
    1) Download segmentation dataset from Roboflow (COCO segmentation export)
    2) Convert masks -> YOLOv8 bboxes (one ``class cx cy w h`` line per object,
       normalised by image size) under a fresh temp dir with images/ and labels/
    Returns before_gallery, after_gallery, local_dataset_path, project_slug

    Raises:
        FileNotFoundError: if the download has no JSON annotations, no images,
            or none of the annotated images can be found on disk.
    """
    rf = Roboflow(api_key=api_key)
    ws, proj, ver = parse_roboflow_url(dataset_url)
    version_obj = rf.workspace(ws).project(proj).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location

    ann_file = _find_annotation_file(root)
    with open(ann_file, 'r', encoding='utf-8') as fh:
        coco = json.load(fh)
    images_info = {img['id']: img for img in coco['images']}
    # YOLO class indices are the sorted COCO category ids renumbered from 0.
    cat_ids = sorted(c['id'] for c in coco.get('categories', []))
    id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}

    # prepare YOLOv8 folders
    # NOTE(review): this temp dir is never removed; it is returned for the upload step.
    out_root = tempfile.mkdtemp(prefix="yolov8_")
    img_out = os.path.join(out_root, "images")
    lbl_out = os.path.join(out_root, "labels")
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(lbl_out, exist_ok=True)

    # convert seg -> bbox, grouping results per image id once
    annos, polys_by_img, boxes_by_img = {}, {}, {}
    for a in coco['annotations']:
        pid = a['image_id']
        iw, ih = images_info[pid]['width'], images_info[pid]['height']
        polys_by_img.setdefault(pid, []).extend(_polygons(a))
        box = _pixel_box(a, iw, ih)
        if box is None:
            continue  # nothing usable to convert for this annotation
        xmin, ymin, xmax, ymax = box
        w, h = xmax - xmin, ymax - ymin
        cx, cy = xmin + w / 2, ymin + h / 2
        line = f"{id_to_index[a['category_id']]} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
        annos.setdefault(pid, []).append(line)
        boxes_by_img.setdefault(pid, []).append(box)

    # locate images (same folder as the JSON when possible) and write labels
    img_dir = _find_image_dir(root, os.path.dirname(ann_file))
    fname2id = {img['file_name']: img['id'] for img in coco['images']}
    copied = []
    for fname, pid in fname2id.items():
        src = os.path.join(img_dir, fname)
        if not os.path.isfile(src):
            continue  # listed in the JSON but absent on disk
        shutil.copy(src, os.path.join(img_out, fname))
        label_path = os.path.join(lbl_out, os.path.splitext(fname)[0] + ".txt")
        with open(label_path, 'w', encoding='utf-8') as lf:
            lf.write("\n".join(annos.get(pid, [])))
        copied.append(fname)
    if not copied:
        raise FileNotFoundError(
            f"None of the images listed in {ann_file} were found in {img_dir}"
        )

    # build preview galleries from the images that were actually copied
    before, after = _render_previews(img_out, copied, fname2id, polys_by_img, boxes_by_img)
    return before, after, out_root, proj  # proj is our slug
def upload_and_train_detection(api_key: str, project_slug: str, dataset_path: str):
    """
    1) Upload local YOLOv8 dataset to Roboflow
    2) Generate & train a new detection version
    Returns the hosted inference endpoint URL (it embeds *api_key* as a query
    parameter, so treat the result as a secret).

    Raises:
        ValueError: if no converted dataset is available yet (e.g. the train
            button was pressed before the conversion step filled the state).
    """
    # Unset gr.State values arrive as None; fail with a clear message instead of
    # an obscure SDK error.
    if not project_slug or not dataset_path or not os.path.isdir(dataset_path):
        raise ValueError("No converted dataset found - run 'Convert to BBoxes' first.")

    rf = Roboflow(api_key=api_key)
    ws = rf.workspace()
    # upload dataset
    # NOTE(review): this reuses the source segmentation project's slug; if that
    # project already exists in this workspace the boxes may be uploaded into it
    # rather than into a new object-detection project - confirm intended target.
    ws.upload_dataset(
        dataset_path,
        project_slug,
        project_license="MIT",
        project_type="object-detection"
    )
    # generate a new version
    proj = ws.project(project_slug)
    version_number = proj.generate_version(preprocessing={}, augmentation={})
    # train model (fast)
    proj.version(version_number).train(speed="fast")
    # fetch hosted endpoint (re-fetch the version so post-training model info is read)
    model = proj.version(version_number).model
    if isinstance(model, dict) and 'base_url' in model and 'id' in model:
        return f"{model['base_url']}{model['id']}?api_key={api_key}"
    # The SDK may expose the model as an object rather than a mapping; fall back
    # to the documented hosted object-detection URL pattern.
    return f"https://detect.roboflow.com/{project_slug}/{version_number}?api_key={api_key}"
# Two-step UI: convert a Roboflow segmentation dataset to YOLOv8 boxes, then
# upload it and train a hosted detection model.
with gr.Blocks() as app:
    gr.Markdown("## 🔄 Segmentation → YOLOv8 Converter + Auto Trainer")

    # --- Step 1: download + convert -------------------------------------
    api_input = gr.Textbox(label="Roboflow API Key", type="password")
    url_input = gr.Textbox(label="Segmentation Dataset URL")
    convert_btn = gr.Button("Convert to BBoxes")
    before_gal = gr.Gallery(label="Before (Segmentation)", columns=5)
    after_gal = gr.Gallery(label="After (BBoxes)", columns=5)
    # Hand-off from the convert step to the train step.
    state_path = gr.State()  # local YOLOv8 dataset directory
    state_slug = gr.State()  # Roboflow project slug
    convert_btn.click(
        fn=convert_seg_to_bbox,
        inputs=[api_input, url_input],
        outputs=[before_gal, after_gal, state_path, state_slug]
    )

    # --- Step 2: upload + train -----------------------------------------
    train_btn = gr.Button("Upload & Train Detection Model")
    # Output only; note the returned URL contains the API key in plain text.
    endpoint_text = gr.Textbox(label="Hosted Detection Endpoint URL", interactive=False)
    train_btn.click(
        fn=upload_and_train_detection,
        inputs=[api_input, state_slug, state_path],
        outputs=[endpoint_text]
    )
    gr.Markdown("> First convert your seg data, then click **Upload & Train** to deploy your detection model.")

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    app.launch()