# Spaces: Running
import os | |
import json | |
import random | |
import shutil | |
import tempfile | |
from urllib.parse import urlparse | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
from roboflow import Roboflow # removed RoboflowError, just import Roboflow | |
def parse_roboflow_url(url: str):
    """Extract ``(workspace, project, version)`` from a Roboflow dataset URL.

    The first two non-empty path segments are taken as the workspace and the
    project.  The version is the last path segment after the project, or the
    one just before it when the last segment is not an integer (for example
    ``/ws/proj/3/images``).

    Args:
        url: Dataset URL whose path looks like ``/<workspace>/<project>/.../<version>``.

    Returns:
        Tuple ``(workspace, project, version)`` with ``version`` as ``int``.

    Raises:
        ValueError: If the path has fewer than three segments or no integer
            version can be found in the last two segments after the project.
    """
    # Filter empty segments so leading/trailing/double slashes are harmless.
    parts = [segment for segment in urlparse(url).path.split('/') if segment]
    if len(parts) < 3:
        raise ValueError(
            f"Expected a URL of the form .../<workspace>/<project>/<version>, got: {url!r}"
        )
    ws, proj = parts[0], parts[1]
    # Only look at segments *after* the project so a numeric project slug is
    # never mistaken for the version; try the last one first.
    for candidate in reversed(parts[2:][-2:]):
        try:
            return ws, proj, int(candidate)
        except ValueError:
            continue
    raise ValueError(f"No integer version number found in URL: {url!r}")
def _find_annotation_file(root):
    """Return the path of the first ``.json`` file found under *root*, or None."""
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            if filename.lower().endswith('.json'):
                return os.path.join(dirpath, filename)
    return None


def _annotation_polygons(anno):
    """Return the usable polygon coordinate lists of one annotation (may be empty)."""
    seg = anno.get('segmentation')
    if not isinstance(seg, list):
        # Non-list segmentations (for instance a mask stored as a dict)
        # cannot be treated as flat x/y coordinate lists.
        return []
    polygons = []
    for poly in seg:
        if not isinstance(poly, (list, tuple)):
            continue
        # Drop a dangling odd value so the list always pairs up as (x, y).
        coords = list(poly[:len(poly) // 2 * 2])
        if len(coords) >= 4:
            polygons.append(coords)
    return polygons


def _annotation_bbox(anno, polygons):
    """Return ``(xmin, ymin, width, height)`` in pixels for an annotation, or None."""
    if polygons:
        # Enclose *all* polygons: an object split into several parts must
        # still get one box around every part.
        xs = [x for poly in polygons for x in poly[0::2]]
        ys = [y for poly in polygons for y in poly[1::2]]
        xmin, ymin = min(xs), min(ys)
        return xmin, ymin, max(xs) - xmin, max(ys) - ymin
    # Fall back to a stored box; assumes the COCO ``[x, y, width, height]`` layout.
    box = anno.get('bbox')
    if isinstance(box, (list, tuple)) and len(box) == 4:
        return tuple(box)
    return None


def _split_counts(n, split_ratios):
    """Return ``(n_train, n_valid)`` for *n* files; the remainder is the test split."""
    n_train = min(n, max(1, int(n * split_ratios[0])))
    n_valid = max(1, int(n * split_ratios[1]))
    # Keep one file for test when possible, but never go negative: a negative
    # count would make the slices overlap and a file would be moved twice.
    n_valid = max(0, min(n_valid, n - n_train - 1))
    return n_train, n_valid


def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1, 0.1)):
    """Download a segmentation dataset and convert it to bounding-box labels.

    The dataset version named by *dataset_url* is downloaded in the
    ``"coco-segmentation"`` format.  Every annotation is turned into one
    label line ``<class index> <cx> <cy> <w> <h>`` (centre and size divided by
    the image width/height).  Images and label files are written to a fresh
    temporary directory under ``images/<split>`` and ``labels/<split>`` for
    the ``train``, ``valid`` and ``test`` splits, after a random shuffle.

    Args:
        api_key: Roboflow API key.
        dataset_url: URL understood by ``parse_roboflow_url``.
        split_ratios: Fractions for train and valid (indexes 0 and 1); the
            test split receives whatever is left, so index 2 is not read.

    Returns:
        Tuple ``(before, after, out_root, slug, workspace)``: up to five PIL
        images with the polygons drawn, the same images with the boxes drawn,
        the output directory, ``"<project>-detection"`` and the workspace name.

    Raises:
        FileNotFoundError: If no ``.json`` file exists under the download location.
        ValueError: If *dataset_url* cannot be parsed.
    """
    rf = Roboflow(api_key=api_key)
    workspace, proj_name, ver = parse_roboflow_url(dataset_url)
    version_obj = rf.workspace(workspace).project(proj_name).version(ver)
    dataset = version_obj.download("coco-segmentation")
    root = dataset.location

    # NOTE(review): only the first JSON file found is read. If the download
    # holds one annotation file per split, the others are ignored - confirm.
    ann_file = _find_annotation_file(root)
    if not ann_file:
        raise FileNotFoundError(f"No JSON annotations found under {root}")
    with open(ann_file, 'r', encoding='utf-8') as fh:
        coco = json.load(fh)

    images_info = {img['id']: img for img in coco['images']}
    # Category ids are remapped to consecutive indexes starting at 0.
    cat_ids = sorted(c['id'] for c in coco.get('categories', []))
    id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}

    # flatten & convert
    out_root = tempfile.mkdtemp(prefix="yolov8_")
    flat_img = os.path.join(out_root, "flat_images")
    flat_lbl = os.path.join(out_root, "flat_labels")
    os.makedirs(flat_img, exist_ok=True)
    os.makedirs(flat_lbl, exist_ok=True)

    annos = {}           # image id -> label lines
    polys_by_image = {}  # image id -> polygons, reused for the preview
    for anno in coco['annotations']:
        img_id = anno['image_id']
        polygons = _annotation_polygons(anno)
        polys_by_image.setdefault(img_id, []).extend(polygons)
        box = _annotation_bbox(anno, polygons)
        if box is None:
            continue
        xmin, ymin, w, h = box
        iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
        # Skip degenerate boxes and images without a usable size (would
        # otherwise divide by zero or emit an empty box).
        if w <= 0 or h <= 0 or not iw or not ih:
            continue
        cx, cy = xmin + w / 2, ymin + h / 2
        line = (
            f"{id_to_index[anno['category_id']]} "
            f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
        )
        annos.setdefault(img_id, []).append(line)

    # Locate every annotated image on disk by its bare file name.
    name_to_id = {img['file_name']: img['id'] for img in coco['images']}
    file_paths = {}
    for dp, _, files in os.walk(root):
        for f in files:
            if f in name_to_id:
                file_paths[f] = os.path.join(dp, f)

    for fname, img_id in name_to_id.items():
        src = file_paths.get(fname)
        if not src:
            continue
        shutil.copy(src, os.path.join(flat_img, fname))
        # An image without annotations still gets an (empty) label file.
        with open(os.path.join(flat_lbl, fname.rsplit('.', 1)[0] + ".txt"), 'w') as lf:
            lf.write("\n".join(annos.get(img_id, [])))

    # split
    all_files = sorted(f for f in os.listdir(flat_img)
                       if f.lower().endswith(('.jpg', '.png', '.jpeg')))
    random.shuffle(all_files)
    n_train, n_valid = _split_counts(len(all_files), split_ratios)
    splits = {
        "train": all_files[:n_train],
        "valid": all_files[n_train:n_train + n_valid],
        "test": all_files[n_train + n_valid:],
    }
    for split, files in splits.items():
        idir = os.path.join(out_root, "images", split)
        ldir = os.path.join(out_root, "labels", split)
        os.makedirs(idir, exist_ok=True)
        os.makedirs(ldir, exist_ok=True)
        for fn in files:
            shutil.move(os.path.join(flat_img, fn),
                        os.path.join(idir, fn))
            lbl = fn.rsplit('.', 1)[0] + ".txt"
            shutil.move(os.path.join(flat_lbl, lbl),
                        os.path.join(ldir, lbl))
    shutil.rmtree(flat_img)
    shutil.rmtree(flat_lbl)

    # prepare visuals (read from the downloaded originals, not the moved copies)
    before, after = [], []
    available = [name for name in name_to_id if name in file_paths]
    for fname in random.sample(available, min(5, len(available))):
        bgr = cv2.imread(file_paths[fname])
        if bgr is None:
            # Unreadable image: skip the preview rather than fail the conversion.
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img_id = name_to_id[fname]

        seg_vis = img.copy()
        for poly in polys_by_image.get(img_id, []):
            pts = np.array(poly, np.int32).reshape(-1, 2)
            cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)

        box_vis = img.copy()
        iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
        for line in annos.get(img_id, []):
            # Parse the written label back so the preview shows exactly
            # what was saved to disk.
            _, cxn, cyn, wnorm, hnorm = map(float, line.split())
            w0, h0 = int(wnorm * iw), int(hnorm * ih)
            x0 = int(cxn * iw - w0 / 2)
            y0 = int(cyn * ih - h0 / 2)
            cv2.rectangle(box_vis, (x0, y0), (x0 + w0, y0 + h0), (0, 255, 0), 2)

        before.append(Image.fromarray(seg_vis))
        after.append(Image.fromarray(box_vis))

    return before, after, out_root, proj_name + "-detection", workspace
def upload_and_train_detection(
    api_key: str,
    workspace: str,
    project_slug: str,
    dataset_path: str,
    project_license: str = "MIT",
    project_type: str = "object-detection"
):
    """Upload a dataset to a Roboflow project, generate a version and train it.

    The project is looked up by slug; when the lookup raises an exception
    whose text contains ``"does not exist"`` the project is created instead.
    Any other exception from the lookup propagates unchanged.

    Args:
        api_key: Roboflow API key; it is also embedded in the returned URL.
        workspace: Workspace identifier passed to ``Roboflow.workspace``.
        project_slug: Project to fetch or create, and the upload target.
        dataset_path: Local dataset directory handed to ``upload_dataset``.
        project_license: License passed on project creation and upload.
        project_type: Project type passed on project creation and upload.

    Returns:
        String built from the trained version's model ``base_url`` and ``id``
        followed by ``?api_key=<api_key>``.
    """
    client = Roboflow(api_key=api_key)
    target_ws = client.workspace(workspace)

    # Get-or-create: the only signal for "missing project" used here is the
    # exception message text.
    try:
        project = target_ws.project(project_slug)
    except Exception as err:
        if "does not exist" not in str(err):
            raise
        project = target_ws.create_project(
            project_slug,
            annotation=project_type,
            project_type=project_type,
            project_license=project_license,
        )

    # Upload the converted dataset.
    target_ws.upload_dataset(
        dataset_path,
        project_slug,
        project_license=project_license,
        project_type=project_type,
    )

    # Generate a version with empty augmentation/preprocessing settings.
    generation_settings = {
        "augmentation": {},
        "preprocessing": {},
    }
    version_id = str(project.generate_version(settings=generation_settings))

    # Train, then fetch the version again to read its model.
    project.version(version_id).train()
    model = project.version(version_id).model
    return f"{model['base_url']}{model['id']}?api_key={api_key}"
# --- Gradio UI ---
# Two-step interface: convert a segmentation dataset to boxes, then upload the
# result and train a detection model on it.
with gr.Blocks() as app:
    # Step 1: conversion inputs and the before/after preview galleries.
    gr.Markdown("## 🔄 Seg→BBox + Auto‑Upload/Train")
    api_input = gr.Textbox(label="Roboflow API Key", type="password")
    url_input = gr.Textbox(label="Segmentation Dataset URL")
    run_btn = gr.Button("Convert to BBoxes")
    before_g = gr.Gallery(columns=5, label="Before")
    after_g = gr.Gallery(columns=5, label="After")

    # Hidden text boxes carry the conversion results (output directory,
    # project slug, workspace) over to the upload step.
    ds_state = gr.Textbox(visible=False)
    slug_state = gr.Textbox(visible=False)
    ws_state = gr.Textbox(visible=False)

    conversion_inputs = [api_input, url_input]
    # Order matches the tuple returned by convert_seg_to_bbox.
    conversion_outputs = [before_g, after_g, ds_state, slug_state, ws_state]
    run_btn.click(
        fn=convert_seg_to_bbox,
        inputs=conversion_inputs,
        outputs=conversion_outputs,
    )

    # Step 2: upload the converted dataset and start training.
    gr.Markdown("## 🚀 Upload & Train Detection Model")
    train_btn = gr.Button("Upload & Train")
    url_out = gr.Textbox(label="Hosted Model Endpoint URL")

    # Order matches the leading positional parameters of
    # upload_and_train_detection: api_key, workspace, project_slug, dataset_path.
    training_inputs = [api_input, ws_state, slug_state, ds_state]
    train_btn.click(
        fn=upload_and_train_detection,
        inputs=training_inputs,
        outputs=[url_out],
    )

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    app.launch()