segtodetect / app.py
wuhp's picture
Update app.py
3911cda verified
raw
history blame
7.84 kB
import os
import json
import random
import shutil
import tempfile
from urllib.parse import urlparse
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow
def parse_roboflow_url(url: str):
    """Split a Roboflow dataset URL into ``(workspace, project, version)``.

    The URL path is expected to look like
    ``/<workspace>/<project>/.../<version>`` where the version number is
    either the last path segment or the one just before it (for example
    ``/ws/proj/dataset/3`` or ``/ws/proj/3/images``).

    Args:
        url: Full dataset URL as copied from the browser.

    Returns:
        A ``(workspace, project, version)`` tuple; ``version`` is an ``int``.

    Raises:
        ValueError: If the path has fewer than three segments or neither of
            the last two segments after the project is an integer.
    """
    # Drop empty segments so leading/trailing/double slashes are harmless.
    parts = [segment for segment in urlparse(url).path.split('/') if segment]
    if len(parts) < 3:
        raise ValueError(
            "Expected a Roboflow URL of the form "
            f".../<workspace>/<project>/<version>, got: {url!r}"
        )

    workspace, project = parts[0], parts[1]

    # Look at the last segment first, then the one before it, but never at
    # the workspace/project segments themselves.
    for candidate in reversed(parts[2:][-2:]):
        try:
            return workspace, project, int(candidate)
        except ValueError:
            continue

    raise ValueError(f"No numeric dataset version found in URL: {url!r}")
# File extensions (lower-case) that are treated as images.
_IMAGE_EXTS = ('.jpg', '.png', '.jpeg')


def _find_coco_json(root):
    """Return the path of the first ``.json`` file found under *root*, or None."""
    for dirpath, _, files in os.walk(root):
        for fname in files:
            if fname.lower().endswith('.json'):
                return os.path.join(dirpath, fname)
    return None


def _has_images(file_names):
    """Return True if any name in *file_names* has an image extension."""
    return any(name.lower().endswith(_IMAGE_EXTS) for name in file_names)


def _find_image_dir(root, preferred=None):
    """Return a directory containing images: *preferred* if it has any, else the first one under *root*."""
    if preferred and os.path.isdir(preferred) and _has_images(os.listdir(preferred)):
        return preferred
    for dirpath, _, files in os.walk(root):
        if _has_images(files):
            return dirpath
    return None


def _annotation_polygons(anno):
    """Return the flat ``[x0, y0, x1, y1, ...]`` polygons stored on a COCO annotation."""
    seg = anno.get('segmentation')
    if not isinstance(seg, list):
        # Non-polygon encodings (e.g. an RLE dict) carry no point lists.
        return []
    return [p for p in seg if isinstance(p, (list, tuple)) and len(p) >= 2]


def _annotation_bbox(anno):
    """Return ``(x_min, y_min, width, height)`` in pixels for *anno*, or None if it has no geometry."""
    xs, ys = [], []
    # Use every polygon so multi-part masks get one box around all parts.
    for poly in _annotation_polygons(anno):
        xs.extend(poly[0::2])
        ys.extend(poly[1::2])
    if xs and ys:
        x_min, y_min = min(xs), min(ys)
        return x_min, y_min, max(xs) - x_min, max(ys) - y_min

    # No usable polygon: fall back to the annotation's own bbox if present.
    bbox = anno.get('bbox')
    if bbox and len(bbox) == 4:
        x, y, w, h = bbox
        return x, y, w, h
    return None


def _split_counts(n, split_ratios):
    """Return ``(n_train, n_valid)`` for *n* files; the remainder goes to test."""
    if n <= 0:
        return 0, 0
    n_train = min(n, max(1, int(n * split_ratios[0])))
    n_valid = max(1, int(n * split_ratios[1]))
    # Keep one file back for test when possible, but never go negative:
    # a negative count would make the train and test slices overlap.
    n_valid = max(0, min(n_valid, n - n_train - 1))
    return n_train, n_valid


def _render_previews(file_names, img_dir, name_to_id, annos_by_img, boxes_by_img):
    """Return ``(before, after)`` PIL image lists: polygons drawn vs. boxes drawn."""
    before, after = [], []
    for fname in file_names:
        bgr = cv2.imread(os.path.join(img_dir, fname))
        if bgr is None:
            # Unreadable or non-image file: skip instead of crashing in cvtColor.
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img_id = name_to_id[fname]

        seg_vis = img.copy()
        for anno in annos_by_img.get(img_id, []):
            for poly in _annotation_polygons(anno):
                # Trim a dangling coordinate so the reshape to (N, 2) is valid.
                even = poly[:len(poly) // 2 * 2]
                pts = np.array(even, np.int32).reshape(-1, 2)
                cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)

        box_vis = img.copy()
        for x, y, w, h in boxes_by_img.get(img_id, []):
            cv2.rectangle(
                box_vis,
                (int(x), int(y)),
                (int(x + w), int(y + h)),
                (0, 255, 0),
                2,
            )

        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))
    return before, after


def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1, 0.1)):
    """Download a Roboflow segmentation dataset and rewrite it as YOLO-style boxes.

    The dataset version is downloaded in ``coco-segmentation`` format, every
    annotation is reduced to the axis-aligned box enclosing its polygon(s),
    and the images plus one ``.txt`` label file per image are written to a
    fresh temporary directory laid out as ``images/<split>`` and
    ``labels/<split>`` for the ``train``, ``valid`` and ``test`` splits.

    Args:
        api_key: Roboflow API key used for the download.
        dataset_url: URL of the dataset version (see ``parse_roboflow_url``).
        split_ratios: Fractions for train and valid; only the first two
            entries are read, and whatever is left over becomes the test
            split.

    Returns:
        ``(before, after, out_root, project_slug)`` where ``before`` and
        ``after`` are lists of up to five PIL preview images (polygons and
        boxes respectively), ``out_root`` is the output directory and
        ``project_slug`` is ``"<project>-detection"``.

    Raises:
        FileNotFoundError: If the download contains no JSON file or no images.

    Note:
        Only the first annotation JSON found (and one image folder) is
        converted, as in the original implementation.
    """
    # Parse first so a malformed URL fails before any network traffic.
    ws, proj_name, ver = parse_roboflow_url(dataset_url)

    # --- download segmentation export
    rf = Roboflow(api_key=api_key)
    version_obj = rf.workspace(ws).project(proj_name).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location

    # --- find and load the COCO JSON
    ann_file = _find_coco_json(root)
    if not ann_file:
        raise FileNotFoundError(f"No JSON annotations under {root}")
    with open(ann_file, 'r', encoding='utf-8') as fh:
        coco = json.load(fh)

    images_info = {img['id']: img for img in coco['images']}
    cat_ids = sorted(c['id'] for c in coco.get('categories', []))
    id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}

    # --- prepare flat_images + flat_labels
    out_root = tempfile.mkdtemp(prefix="yolov8_")
    flat_img = os.path.join(out_root, "flat_images")
    flat_lbl = os.path.join(out_root, "flat_labels")
    os.makedirs(flat_img, exist_ok=True)
    os.makedirs(flat_lbl, exist_ok=True)

    # --- convert each segmentation -> YOLO bbox line
    label_lines = {}    # image id -> list of "cls cx cy w h" (normalised)
    pixel_boxes = {}    # image id -> list of (x, y, w, h) in pixels, for previews
    annos_by_img = {}   # image id -> annotations, so previews need no rescans
    for anno in coco['annotations']:
        img_id = anno['image_id']
        annos_by_img.setdefault(img_id, []).append(anno)

        bbox = _annotation_bbox(anno)
        if bbox is None:
            continue
        x_min, y_min, w, h = bbox
        cx, cy = x_min + w / 2, y_min + h / 2
        iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
        label_lines.setdefault(img_id, []).append(
            f"{id_to_index[anno['category_id']]} "
            f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
        )
        pixel_boxes.setdefault(img_id, []).append(bbox)

    # --- locate the folder of raw images (prefer the one next to the JSON)
    img_src = _find_image_dir(root, preferred=os.path.dirname(ann_file))
    if not img_src:
        raise FileNotFoundError(f"No images under {root}")

    # --- copy images + write flat label files
    name_to_id = {img['file_name']: img['id'] for img in coco['images']}
    copied = []
    for fname, img_id in name_to_id.items():
        src_path = os.path.join(img_src, fname)
        if not os.path.isfile(src_path):
            continue
        shutil.copy(src_path, os.path.join(flat_img, fname))
        lbl_path = os.path.join(flat_lbl, os.path.splitext(fname)[0] + ".txt")
        with open(lbl_path, 'w') as lf:
            lf.write("\n".join(label_lines.get(img_id, [])))
        copied.append(fname)

    # --- split into train/valid/test
    all_files = sorted(
        f for f in os.listdir(flat_img)
        if f.lower().endswith(_IMAGE_EXTS)
    )
    random.shuffle(all_files)
    n_train, n_valid = _split_counts(len(all_files), split_ratios)
    splits = {
        "train": all_files[:n_train],
        "valid": all_files[n_train:n_train + n_valid],
        "test": all_files[n_train + n_valid:],
    }

    # --- build Roboflow-friendly folder structure
    for split, files in splits.items():
        out_img_dir = os.path.join(out_root, "images", split)
        out_lbl_dir = os.path.join(out_root, "labels", split)
        os.makedirs(out_img_dir, exist_ok=True)
        os.makedirs(out_lbl_dir, exist_ok=True)
        for fn in files:
            shutil.move(
                os.path.join(flat_img, fn),
                os.path.join(out_img_dir, fn)
            )
            lbl_fn = os.path.splitext(fn)[0] + ".txt"
            shutil.move(
                os.path.join(flat_lbl, lbl_fn),
                os.path.join(out_lbl_dir, lbl_fn)
            )

    # --- clean up the flat dirs
    shutil.rmtree(flat_img)
    shutil.rmtree(flat_lbl)

    # --- prepare a few before/after visuals
    # Sample only files that were actually found on disk, so every preview
    # image can be read.
    sample = random.sample(copied, min(5, len(copied)))
    before, after = _render_previews(
        sample, img_src, name_to_id, annos_by_img, pixel_boxes
    )

    project_slug = f"{proj_name}-detection"
    return before, after, out_root, project_slug
def upload_and_train_detection(
    api_key: str,
    project_slug: str,
    dataset_path: str,
    project_license: str = "MIT",
    project_type: str = "object-detection"
):
    """Upload a converted dataset to Roboflow, start training, and return the model URL.

    The project named *project_slug* is looked up in the API key's default
    workspace and created if the lookup fails. The folder at *dataset_path*
    is uploaded, a new version is generated with empty augmentation and
    preprocessing settings, and training is started on that version.

    Args:
        api_key: Roboflow API key.
        project_slug: Name/slug of the detection project to use or create.
        dataset_path: Folder holding the converted dataset.
        project_license: License passed on project creation and upload.
        project_type: Roboflow project type passed on creation and upload.

    Returns:
        The model's ``base_url`` followed by its ``id`` and an ``api_key``
        query parameter.

    Raises:
        ValueError: If *project_slug* or *dataset_path* is empty, which
            happens when this step runs before the conversion step.
        RuntimeError: If the trained version exposes no model.
    """
    if not project_slug or not dataset_path:
        raise ValueError(
            "Project slug and dataset path are required; "
            "run the conversion step before uploading."
        )

    rf = Roboflow(api_key=api_key)
    ws = rf.workspace()

    # get-or-create detection project
    try:
        proj = ws.project(project_slug)
    except Exception:
        # NOTE(review): any lookup failure is treated as "project missing".
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit propagate.
        proj = ws.create_project(
            project_slug,
            annotation=project_type,
            project_type=project_type,
            project_license=project_license
        )

    # upload the folder that now has train/valid/test
    ws.upload_dataset(
        dataset_path,
        project_slug,
        project_license=project_license,
        project_type=project_type
    )

    # create a new version & queue training
    version_num = proj.generate_version(settings={
        "augmentation": {},
        "preprocessing": {},
    })
    version_id = str(version_num)
    proj.version(version_id).train()

    # The version is fetched again after training, as the original code did.
    model = proj.version(version_id).model
    if model is None:
        raise RuntimeError(
            f"Version {version_id} of project {project_slug!r} has no model yet."
        )

    def _field(name):
        # NOTE(review): the SDK may expose the model as an object rather
        # than a mapping - accept both; confirm against the SDK version.
        try:
            return model[name]
        except TypeError:
            return getattr(model, name)

    # NOTE(review): the returned URL embeds the API key and is displayed in
    # the UI; treat it as a secret.
    return f"{_field('base_url')}{_field('id')}?api_key={api_key}"
# --- Gradio UI ---
# Two-step page: the first button converts the segmentation dataset and
# fills the galleries plus two hidden fields; the second button feeds those
# hidden fields into the upload/train step.
with gr.Blocks() as app:
    # Step 1: conversion
    gr.Markdown("## 🔄 Seg→BBox + Auto‐Upload/Train")

    api_input = gr.Textbox(
        label="Roboflow API Key",
        type="password",
    )
    url_input = gr.Textbox(
        label="Segmentation Dataset URL",
    )
    run_btn = gr.Button("Convert to BBoxes")

    before_g = gr.Gallery(
        columns=5,
        label="Before",
    )
    after_g = gr.Gallery(
        columns=5,
        label="After",
    )

    # Hidden fields that carry the output folder and the project slug from
    # the conversion step over to the training step.
    ds_state = gr.Textbox(visible=False)
    slug_state = gr.Textbox(visible=False)

    run_btn.click(
        fn=convert_seg_to_bbox,
        inputs=[api_input, url_input],
        outputs=[before_g, after_g, ds_state, slug_state],
    )

    # Step 2: upload + training
    gr.Markdown("## 🚀 Upload & Train Detection Model")

    train_btn = gr.Button("Upload & Train")
    url_out = gr.Textbox(label="Hosted Model Endpoint URL")

    train_btn.click(
        fn=upload_and_train_detection,
        inputs=[api_input, slug_state, ds_state],
        outputs=[url_out],
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()