segtodetect / app.py
wuhp's picture
Update app.py
99a318c verified
raw
history blame
7.75 kB
import os
import json
import random
import shutil
import tempfile
from urllib.parse import urlparse
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from roboflow import Roboflow
def parse_roboflow_url(url: str):
    """Extract ``(workspace, project, version)`` from a Roboflow dataset URL.

    The first two path segments are taken as the workspace and project
    slugs.  The version number is read from the last path segment, or from
    the second-to-last one when the last segment is not an integer (for
    example a trailing page name after the version).

    Args:
        url: Dataset URL whose path looks like
            ``/<workspace>/<project>/.../<version>``.

    Returns:
        A ``(workspace, project, version)`` tuple where ``version`` is an int.

    Raises:
        ValueError: If the path has fewer than two segments or neither of
            the last two segments is an integer.
    """
    # Drop empty segments so leading/trailing/doubled slashes are harmless.
    parts = [segment for segment in urlparse(url).path.split('/') if segment]
    if len(parts) < 2:
        raise ValueError(
            f"Expected a URL of the form .../<workspace>/<project>/<version>, got: {url!r}"
        )
    workspace, project = parts[0], parts[1]

    for candidate in (parts[-1], parts[-2]):
        try:
            return workspace, project, int(candidate)
        except ValueError:
            continue
    raise ValueError(f"No dataset version number found in URL: {url!r}")
_IMAGE_EXTS = ('.jpg', '.png', '.jpeg')


def _usable_polygons(anno):
    """Return the annotation's polygons as flat ``[x0, y0, x1, y1, ...]`` lists."""
    seg = anno.get('segmentation')
    if not isinstance(seg, list):
        # Not a polygon list (e.g. a mask stored as a dict); nothing to trace.
        return []
    polygons = []
    for poly in seg:
        if not isinstance(poly, (list, tuple)):
            continue
        # Drop a trailing unpaired coordinate so the list reshapes to (-1, 2).
        even_len = len(poly) - len(poly) % 2
        if even_len >= 2:
            polygons.append(list(poly[:even_len]))
    return polygons


def _annotation_bbox(polygons, fallback_bbox):
    """Return ``(x_min, y_min, w, h)`` enclosing all polygons, else the fallback box, else None."""
    if polygons:
        xs = [x for poly in polygons for x in poly[0::2]]
        ys = [y for poly in polygons for y in poly[1::2]]
        x_min, y_min = min(xs), min(ys)
        return x_min, y_min, max(xs) - x_min, max(ys) - y_min
    if fallback_bbox and len(fallback_bbox) == 4:
        return tuple(fallback_bbox)
    return None


def _split_counts(n, split_ratios):
    """Return ``(n_train, n_valid)`` for ``n`` files, never negative and never exceeding ``n``."""
    n_train = min(n, max(1, int(n * split_ratios[0])))
    n_valid = max(1, int(n * split_ratios[1]))
    # Leave at least one file for test when possible, but never go below zero:
    # a negative count would make the valid/test slices overlap the train slice.
    n_valid = max(0, min(n_valid, n - n_train - 1))
    return n_train, n_valid


def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1, 0.1)):
    """Download a COCO segmentation export and rewrite it as a YOLO bbox dataset.

    The dataset version named by ``dataset_url`` is downloaded in
    ``"coco-segmentation"`` format.  Each annotation is turned into one
    normalized ``class cx cy w h`` line whose box encloses every polygon of
    the annotation; annotations without polygons fall back to their ``bbox``
    field and are skipped when neither is usable.  Images and labels are
    written to a new temporary folder laid out as
    ``images/{train,valid,test}`` and ``labels/{train,valid,test}``.

    Args:
        api_key: Roboflow API key.
        dataset_url: URL of the segmentation dataset version.
        split_ratios: ``(train, valid, test)`` fractions; only the first two
            are read, the test split receives the remaining files.

    Returns:
        ``(before, after, out_root, project_slug)``: two lists of PIL preview
        images (polygons drawn / boxes drawn) for up to five images, the
        output folder path, and ``"<project>-detection"``.

    Raises:
        FileNotFoundError: If no JSON file or no image folder is found under
            the download location.
    """
    # --- download segmentation export
    rf = Roboflow(api_key=api_key)
    ws, proj_name, ver = parse_roboflow_url(dataset_url)
    version_obj = rf.workspace(ws).project(proj_name).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location

    # --- find the COCO JSON
    # NOTE(review): only the first JSON found (and, below, the first folder
    # containing images) is used; if the export holds one JSON per split the
    # other splits are ignored -- confirm against the actual export layout.
    ann_file = None
    for dp, _, files in os.walk(root):
        json_names = [f for f in files if f.lower().endswith('.json')]
        if json_names:
            ann_file = os.path.join(dp, json_names[0])
            break
    if not ann_file:
        raise FileNotFoundError(f"No JSON annotations under {root}")
    with open(ann_file, 'r', encoding='utf-8') as fh:
        coco = json.load(fh)

    images_info = {img['id']: img for img in coco['images']}
    cat_ids = sorted(c['id'] for c in coco.get('categories', []))
    id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}

    # --- make a flat YOLO folder
    out_root = tempfile.mkdtemp(prefix="yolov8_")
    flat_img = os.path.join(out_root, "flat_images")
    flat_lbl = os.path.join(out_root, "flat_labels")
    os.makedirs(flat_img, exist_ok=True)
    os.makedirs(flat_lbl, exist_ok=True)

    # --- convert each segmentation to a YOLO bbox line
    annos = {}          # image id -> list of YOLO label lines
    polys_by_img = {}   # image id -> flat polygons, kept for the preview
    for anno in coco['annotations']:
        polygons = _usable_polygons(anno)
        box = _annotation_bbox(polygons, anno.get('bbox'))
        if box is None:
            continue
        img_id = anno['image_id']
        x_min, y_min, w, h = box
        cx, cy = x_min + w / 2, y_min + h / 2
        iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
        line = (
            f"{id_to_index[anno['category_id']]} "
            f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
        )
        annos.setdefault(img_id, []).append(line)
        polys_by_img.setdefault(img_id, []).extend(polygons)

    # --- locate the single images folder
    img_src = None
    for dp, _, files in os.walk(root):
        if any(f.lower().endswith(_IMAGE_EXTS) for f in files):
            img_src = dp
            break
    if not img_src:
        raise FileNotFoundError(f"No images folder in {root}")

    # --- copy images + write flat labels
    name_to_id = {img['file_name']: img['id'] for img in coco['images']}
    copied = []  # images that really exist on disk; previews sample from these
    for fname, img_id in name_to_id.items():
        src_path = os.path.join(img_src, fname)
        if not os.path.isfile(src_path):
            continue
        shutil.copy(src_path, os.path.join(flat_img, fname))
        label_name = fname.rsplit('.', 1)[0] + ".txt"
        with open(os.path.join(flat_lbl, label_name), 'w') as lf:
            lf.write("\n".join(annos.get(img_id, [])))
        copied.append(fname)

    # --- split filenames into train/valid/test lists
    all_files = sorted(f for f in os.listdir(flat_img) if f.lower().endswith(_IMAGE_EXTS))
    random.shuffle(all_files)
    n_train, n_valid = _split_counts(len(all_files), split_ratios)
    splits = {
        "train": all_files[:n_train],
        "valid": all_files[n_train:n_train + n_valid],
        "test": all_files[n_train + n_valid:],
    }

    # --- create Roboflow-friendly structure:
    #     out_root/images/{train,valid,test}
    #     out_root/labels/{train,valid,test}
    for split, files in splits.items():
        img_dir = os.path.join(out_root, "images", split)
        lbl_dir = os.path.join(out_root, "labels", split)
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(lbl_dir, exist_ok=True)
        for fn in files:
            label_name = fn.rsplit('.', 1)[0] + ".txt"
            shutil.move(os.path.join(flat_img, fn), os.path.join(img_dir, fn))
            shutil.move(os.path.join(flat_lbl, label_name), os.path.join(lbl_dir, label_name))

    # --- clean up flats
    shutil.rmtree(flat_img)
    shutil.rmtree(flat_lbl)

    # --- build a few before/after previews
    before, after = [], []
    for fname in random.sample(copied, min(5, len(copied))):
        bgr = cv2.imread(os.path.join(img_src, fname))
        if bgr is None:
            # Unreadable image: skip the preview instead of crashing in cvtColor.
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img_id = name_to_id[fname]

        seg_vis = img.copy()
        for poly in polys_by_img.get(img_id, []):
            pts = np.array(poly, np.int32).reshape(-1, 2)
            cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)

        box_vis = img.copy()
        iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
        for line in annos.get(img_id, []):
            # Decode the written label line so the preview shows exactly what was exported.
            _, cxn, cyn, wnorm, hnorm = map(float, line.split())
            w0, h0 = int(wnorm * iw), int(hnorm * ih)
            x0 = int(cxn * iw - w0 / 2)
            y0 = int(cyn * ih - h0 / 2)
            cv2.rectangle(box_vis, (x0, y0), (x0 + w0, y0 + h0), (0, 255, 0), 2)

        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))

    project_slug = f"{proj_name}-detection"
    return before, after, out_root, project_slug
def upload_and_train_detection(
    api_key: str,
    project_slug: str,
    dataset_path: str,
    project_license: str = "MIT",
    project_type: str = "object-detection"
):
    """Upload a converted dataset, generate a version, train it, and return the model URL.

    Args:
        api_key: Roboflow API key; also appended to the returned URL.
        project_slug: Slug of the detection project to reuse or create.
        dataset_path: Folder holding the split dataset to upload.
        project_license: License passed to project creation and upload.
        project_type: Project type passed to project creation and upload
            (also used as the ``annotation`` value on creation).

    Returns:
        A string built from the trained version's model ``base_url`` and
        ``id`` followed by ``?api_key=<api_key>``.
    """
    workspace = Roboflow(api_key=api_key).workspace()

    # Get-or-create: any failure while looking the project up is treated as
    # "does not exist yet" and triggers creation.
    try:
        project = workspace.project(project_slug)
    except Exception:
        project = workspace.create_project(
            project_slug,
            annotation=project_type,
            project_type=project_type,
            project_license=project_license,
        )

    # Push the already-split folder into the project.
    workspace.upload_dataset(
        dataset_path,
        project_slug,
        project_license=project_license,
        project_type=project_type,
    )

    # New version with empty augmentation and preprocessing settings.
    version_settings = {
        "augmentation": {},
        "preprocessing": {},
    }
    version_id = str(project.generate_version(settings=version_settings))

    # Start training on the freshly generated version.
    project.version(version_id).train()

    # The version is looked up again (not reused) before reading its model.
    model = project.version(version_id).model
    base_url, model_id = model['base_url'], model['id']
    # NOTE(review): the returned URL embeds the API key in plain text.
    return f"{base_url}{model_id}?api_key={api_key}"
# --- Gradio UI ---
# Components are created in display order inside the Blocks context, so the
# statement order below is also the page layout.
with gr.Blocks() as app:
    # Step 1: convert the segmentation dataset and show previews.
    gr.Markdown("## 🔄 Seg→BBox + Auto‐Upload/Train")
    api_key_box = gr.Textbox(label="Roboflow API Key", type="password")
    dataset_url_box = gr.Textbox(label="Segmentation Dataset URL")
    convert_button = gr.Button("Convert to BBoxes")
    before_gallery = gr.Gallery(columns=5, label="Before")
    after_gallery = gr.Gallery(columns=5, label="After")

    # Hidden textboxes carry the output folder and project slug from the
    # conversion step to the upload step.
    dataset_path_box = gr.Textbox(visible=False)
    project_slug_box = gr.Textbox(visible=False)

    convert_button.click(
        fn=convert_seg_to_bbox,
        inputs=[api_key_box, dataset_url_box],
        outputs=[before_gallery, after_gallery, dataset_path_box, project_slug_box],
    )

    # Step 2: upload the converted dataset and train a detection model.
    gr.Markdown("## 🚀 Upload & Train Detection Model")
    train_button = gr.Button("Upload & Train")
    endpoint_box = gr.Textbox(label="Hosted Model Endpoint URL")

    train_button.click(
        fn=upload_and_train_detection,
        inputs=[api_key_box, project_slug_box, dataset_path_box],
        outputs=[endpoint_box],
    )


if __name__ == "__main__":
    app.launch()