Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -48,7 +48,7 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
|
|
48 |
cat_ids = sorted(c['id'] for c in coco.get('categories', []))
|
49 |
id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
|
50 |
|
51 |
-
# flatten & convert
|
52 |
out_root = tempfile.mkdtemp(prefix="yolov8_")
|
53 |
flat_img = os.path.join(out_root, "flat_images")
|
54 |
flat_lbl = os.path.join(out_root, "flat_labels")
|
@@ -58,26 +58,25 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
|
|
58 |
annos = {}
|
59 |
for anno in coco['annotations']:
|
60 |
img_id = anno['image_id']
|
61 |
-
|
62 |
-
xs, ys = poly[0::2], poly[1::2]
|
63 |
xmin, xmax = min(xs), max(xs)
|
64 |
ymin, ymax = min(ys), max(ys)
|
65 |
-
w, h
|
66 |
-
cx, cy
|
67 |
|
68 |
-
iw, ih
|
69 |
-
line
|
70 |
f"{id_to_index[anno['category_id']]} "
|
71 |
f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
|
72 |
)
|
73 |
annos.setdefault(img_id, []).append(line)
|
74 |
|
75 |
name_to_id = {img['file_name']: img['id'] for img in coco['images']}
|
76 |
-
file_paths = {
|
77 |
-
|
78 |
-
for f in files
|
79 |
-
|
80 |
-
|
81 |
|
82 |
for fname, img_id in name_to_id.items():
|
83 |
src = file_paths.get(fname)
|
@@ -87,7 +86,7 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
|
|
87 |
with open(os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
|
88 |
lf.write("\n".join(annos.get(img_id, [])))
|
89 |
|
90 |
-
# split
|
91 |
all_files = sorted(f for f in os.listdir(flat_img)
|
92 |
if f.lower().endswith(('.jpg','.png','.jpeg')))
|
93 |
random.shuffle(all_files)
|
@@ -108,16 +107,14 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
|
|
108 |
os.makedirs(idir, exist_ok=True)
|
109 |
os.makedirs(ldir, exist_ok=True)
|
110 |
for fn in files:
|
111 |
-
shutil.move(os.path.join(flat_img, fn),
|
112 |
-
os.path.join(idir, fn))
|
113 |
lbl = fn.rsplit('.',1)[0] + ".txt"
|
114 |
-
shutil.move(os.path.join(flat_lbl, lbl),
|
115 |
-
os.path.join(ldir, lbl))
|
116 |
|
117 |
shutil.rmtree(flat_img)
|
118 |
shutil.rmtree(flat_lbl)
|
119 |
|
120 |
-
#
|
121 |
before, after = [], []
|
122 |
sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
|
123 |
for fname in sample:
|
@@ -159,76 +156,73 @@ def upload_and_train_detection(
|
|
159 |
rf = Roboflow(api_key=api_key)
|
160 |
ws = rf.workspace(workspace)
|
161 |
|
162 |
-
# 1) Try
|
163 |
try:
|
164 |
proj = ws.project(project_slug)
|
165 |
except Exception as e:
|
166 |
if "does not exist" in str(e).lower():
|
167 |
proj = ws.create_project(
|
168 |
-
project_slug,
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
)
|
173 |
else:
|
174 |
raise
|
175 |
|
176 |
-
# 2) If
|
177 |
-
if getattr(proj, "
|
178 |
new_slug = project_slug + "-v2"
|
179 |
proj = ws.create_project(
|
180 |
-
new_slug,
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
)
|
185 |
project_slug = new_slug
|
186 |
|
187 |
-
# 3) Upload train/
|
188 |
ws.upload_dataset(
|
189 |
dataset_path,
|
190 |
project_slug,
|
191 |
-
|
192 |
project_type=project_type
|
193 |
)
|
194 |
|
195 |
-
# 4) Generate new version
|
196 |
try:
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
except RuntimeError as e:
|
202 |
msg = str(e).lower()
|
203 |
if "unsupported request" in msg or "does not exist" in msg:
|
|
|
204 |
suffix = "-v3" if project_slug.endswith("-v2") else "-v2"
|
205 |
new_slug = project_slug + suffix
|
206 |
proj = ws.create_project(
|
207 |
-
new_slug,
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
)
|
212 |
project_slug = new_slug
|
213 |
-
ws.upload_dataset(
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
218 |
)
|
219 |
-
version_num = proj.generate_version(settings={
|
220 |
-
"augmentation": {},
|
221 |
-
"preprocessing": {},
|
222 |
-
})
|
223 |
else:
|
224 |
raise
|
225 |
|
226 |
-
# 5) Kick off training
|
227 |
-
proj.version(
|
228 |
|
229 |
-
# 6) Return the hosted endpoint
|
230 |
-
|
231 |
-
return f"{m['base_url']}{m['id']}?api_key={api_key}"
|
232 |
|
233 |
|
234 |
# --- Gradio UI ---
|
|
|
48 |
cat_ids = sorted(c['id'] for c in coco.get('categories', []))
|
49 |
id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
|
50 |
|
51 |
+
# flatten & convert to YOLO bboxes
|
52 |
out_root = tempfile.mkdtemp(prefix="yolov8_")
|
53 |
flat_img = os.path.join(out_root, "flat_images")
|
54 |
flat_lbl = os.path.join(out_root, "flat_labels")
|
|
|
58 |
annos = {}
|
59 |
for anno in coco['annotations']:
|
60 |
img_id = anno['image_id']
|
61 |
+
xs, ys = anno['segmentation'][0][0::2], anno['segmentation'][0][1::2]
|
|
|
62 |
xmin, xmax = min(xs), max(xs)
|
63 |
ymin, ymax = min(ys), max(ys)
|
64 |
+
w, h = xmax - xmin, ymax - ymin
|
65 |
+
cx, cy = xmin + w/2, ymin + h/2
|
66 |
|
67 |
+
iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
|
68 |
+
line = (
|
69 |
f"{id_to_index[anno['category_id']]} "
|
70 |
f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
|
71 |
)
|
72 |
annos.setdefault(img_id, []).append(line)
|
73 |
|
74 |
name_to_id = {img['file_name']: img['id'] for img in coco['images']}
|
75 |
+
file_paths = {
|
76 |
+
f: os.path.join(dp, f)
|
77 |
+
for dp, _, files in os.walk(root) for f in files
|
78 |
+
if f in name_to_id
|
79 |
+
}
|
80 |
|
81 |
for fname, img_id in name_to_id.items():
|
82 |
src = file_paths.get(fname)
|
|
|
86 |
with open(os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
|
87 |
lf.write("\n".join(annos.get(img_id, [])))
|
88 |
|
89 |
+
# split into train/valid/test
|
90 |
all_files = sorted(f for f in os.listdir(flat_img)
|
91 |
if f.lower().endswith(('.jpg','.png','.jpeg')))
|
92 |
random.shuffle(all_files)
|
|
|
107 |
os.makedirs(idir, exist_ok=True)
|
108 |
os.makedirs(ldir, exist_ok=True)
|
109 |
for fn in files:
|
110 |
+
shutil.move(os.path.join(flat_img, fn), os.path.join(idir, fn))
|
|
|
111 |
lbl = fn.rsplit('.',1)[0] + ".txt"
|
112 |
+
shutil.move(os.path.join(flat_lbl, lbl), os.path.join(ldir, lbl))
|
|
|
113 |
|
114 |
shutil.rmtree(flat_img)
|
115 |
shutil.rmtree(flat_lbl)
|
116 |
|
117 |
+
# before/after visuals
|
118 |
before, after = [], []
|
119 |
sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
|
120 |
for fname in sample:
|
|
|
156 |
rf = Roboflow(api_key=api_key)
|
157 |
ws = rf.workspace(workspace)
|
158 |
|
159 |
+
# 1) Try fetch existing project
|
160 |
try:
|
161 |
proj = ws.project(project_slug)
|
162 |
except Exception as e:
|
163 |
if "does not exist" in str(e).lower():
|
164 |
proj = ws.create_project(
|
165 |
+
name=project_slug,
|
166 |
+
type=project_type,
|
167 |
+
annotation=project_slug,
|
168 |
+
license=project_license
|
169 |
)
|
170 |
else:
|
171 |
raise
|
172 |
|
173 |
+
# 2) If the type mismatches, spin up a fresh <slug>-v2 :contentReference[oaicite:0]{index=0}
|
174 |
+
if getattr(proj, "type", None) != project_type:
|
175 |
new_slug = project_slug + "-v2"
|
176 |
proj = ws.create_project(
|
177 |
+
name=new_slug,
|
178 |
+
type=project_type,
|
179 |
+
annotation=new_slug,
|
180 |
+
license=project_license
|
181 |
)
|
182 |
project_slug = new_slug
|
183 |
|
184 |
+
# 3) Upload your train/valid/test
|
185 |
ws.upload_dataset(
|
186 |
dataset_path,
|
187 |
project_slug,
|
188 |
+
license=project_license,
|
189 |
project_type=project_type
|
190 |
)
|
191 |
|
192 |
+
# 4) Generate a new version with the documented signature :contentReference[oaicite:1]{index=1}
|
193 |
try:
|
194 |
+
new_version = proj.generate_version(
|
195 |
+
preprocessing={},
|
196 |
+
augmentation={}
|
197 |
+
)
|
198 |
except RuntimeError as e:
|
199 |
msg = str(e).lower()
|
200 |
if "unsupported request" in msg or "does not exist" in msg:
|
201 |
+
# fallback to <slug>-v3
|
202 |
suffix = "-v3" if project_slug.endswith("-v2") else "-v2"
|
203 |
new_slug = project_slug + suffix
|
204 |
proj = ws.create_project(
|
205 |
+
name=new_slug,
|
206 |
+
type=project_type,
|
207 |
+
annotation=new_slug,
|
208 |
+
license=project_license
|
209 |
)
|
210 |
project_slug = new_slug
|
211 |
+
ws.upload_dataset(dataset_path, project_slug,
|
212 |
+
license=project_license,
|
213 |
+
project_type=project_type)
|
214 |
+
new_version = proj.generate_version(
|
215 |
+
preprocessing={},
|
216 |
+
augmentation={}
|
217 |
)
|
|
|
|
|
|
|
|
|
218 |
else:
|
219 |
raise
|
220 |
|
221 |
+
# 5) Kick off training (asynchronously) :contentReference[oaicite:2]{index=2}
|
222 |
+
model = proj.version(new_version).train()
|
223 |
|
224 |
+
# 6) Return the hosted endpoint URL
|
225 |
+
return f"{model['base_url']}{model['id']}?api_key={api_key}"
|
|
|
226 |
|
227 |
|
228 |
# --- Gradio UI ---
|