wuhp commited on
Commit
2f9f7a7
·
verified ·
1 Parent(s): c985904

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -54
app.py CHANGED
@@ -48,7 +48,7 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
48
  cat_ids = sorted(c['id'] for c in coco.get('categories', []))
49
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
50
 
51
- # flatten & convert
52
  out_root = tempfile.mkdtemp(prefix="yolov8_")
53
  flat_img = os.path.join(out_root, "flat_images")
54
  flat_lbl = os.path.join(out_root, "flat_labels")
@@ -58,26 +58,25 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
58
  annos = {}
59
  for anno in coco['annotations']:
60
  img_id = anno['image_id']
61
- poly = anno['segmentation'][0]
62
- xs, ys = poly[0::2], poly[1::2]
63
  xmin, xmax = min(xs), max(xs)
64
  ymin, ymax = min(ys), max(ys)
65
- w, h = xmax - xmin, ymax - ymin
66
- cx, cy = xmin + w/2, ymin + h/2
67
 
68
- iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
69
- line = (
70
  f"{id_to_index[anno['category_id']]} "
71
  f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
72
  )
73
  annos.setdefault(img_id, []).append(line)
74
 
75
  name_to_id = {img['file_name']: img['id'] for img in coco['images']}
76
- file_paths = {}
77
- for dp, _, files in os.walk(root):
78
- for f in files:
79
- if f in name_to_id:
80
- file_paths[f] = os.path.join(dp, f)
81
 
82
  for fname, img_id in name_to_id.items():
83
  src = file_paths.get(fname)
@@ -87,7 +86,7 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
87
  with open(os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
88
  lf.write("\n".join(annos.get(img_id, [])))
89
 
90
- # split
91
  all_files = sorted(f for f in os.listdir(flat_img)
92
  if f.lower().endswith(('.jpg','.png','.jpeg')))
93
  random.shuffle(all_files)
@@ -108,16 +107,14 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
108
  os.makedirs(idir, exist_ok=True)
109
  os.makedirs(ldir, exist_ok=True)
110
  for fn in files:
111
- shutil.move(os.path.join(flat_img, fn),
112
- os.path.join(idir, fn))
113
  lbl = fn.rsplit('.',1)[0] + ".txt"
114
- shutil.move(os.path.join(flat_lbl, lbl),
115
- os.path.join(ldir, lbl))
116
 
117
  shutil.rmtree(flat_img)
118
  shutil.rmtree(flat_lbl)
119
 
120
- # prepare visuals
121
  before, after = [], []
122
  sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
123
  for fname in sample:
@@ -159,76 +156,73 @@ def upload_and_train_detection(
159
  rf = Roboflow(api_key=api_key)
160
  ws = rf.workspace(workspace)
161
 
162
- # 1) Try to fetch existing project
163
  try:
164
  proj = ws.project(project_slug)
165
  except Exception as e:
166
  if "does not exist" in str(e).lower():
167
  proj = ws.create_project(
168
- project_slug,
169
- annotation=project_type,
170
- project_type=project_type,
171
- project_license=project_license
172
  )
173
  else:
174
  raise
175
 
176
- # 2) If it exists but is NOT object-detection, make a fresh slug
177
- if getattr(proj, "annotation", None) != project_type:
178
  new_slug = project_slug + "-v2"
179
  proj = ws.create_project(
180
- new_slug,
181
- annotation=project_type,
182
- project_type=project_type,
183
- project_license=project_license
184
  )
185
  project_slug = new_slug
186
 
187
- # 3) Upload train/val/test
188
  ws.upload_dataset(
189
  dataset_path,
190
  project_slug,
191
- project_license=project_license,
192
  project_type=project_type
193
  )
194
 
195
- # 4) Generate new version & train, with fallback on unsupported-request errors
196
  try:
197
- version_num = proj.generate_version(settings={
198
- "augmentation": {},
199
- "preprocessing": {},
200
- })
201
  except RuntimeError as e:
202
  msg = str(e).lower()
203
  if "unsupported request" in msg or "does not exist" in msg:
 
204
  suffix = "-v3" if project_slug.endswith("-v2") else "-v2"
205
  new_slug = project_slug + suffix
206
  proj = ws.create_project(
207
- new_slug,
208
- annotation=project_type,
209
- project_type=project_type,
210
- project_license=project_license
211
  )
212
  project_slug = new_slug
213
- ws.upload_dataset(
214
- dataset_path,
215
- project_slug,
216
- project_license=project_license,
217
- project_type=project_type
 
218
  )
219
- version_num = proj.generate_version(settings={
220
- "augmentation": {},
221
- "preprocessing": {},
222
- })
223
  else:
224
  raise
225
 
226
- # 5) Kick off training
227
- proj.version(str(version_num)).train()
228
 
229
- # 6) Return the hosted endpoint
230
- m = proj.version(str(version_num)).model
231
- return f"{m['base_url']}{m['id']}?api_key={api_key}"
232
 
233
 
234
  # --- Gradio UI ---
 
48
  cat_ids = sorted(c['id'] for c in coco.get('categories', []))
49
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
50
 
51
+ # flatten & convert to YOLO bboxes
52
  out_root = tempfile.mkdtemp(prefix="yolov8_")
53
  flat_img = os.path.join(out_root, "flat_images")
54
  flat_lbl = os.path.join(out_root, "flat_labels")
 
58
  annos = {}
59
  for anno in coco['annotations']:
60
  img_id = anno['image_id']
61
+ xs, ys = anno['segmentation'][0][0::2], anno['segmentation'][0][1::2]
 
62
  xmin, xmax = min(xs), max(xs)
63
  ymin, ymax = min(ys), max(ys)
64
+ w, h = xmax - xmin, ymax - ymin
65
+ cx, cy = xmin + w/2, ymin + h/2
66
 
67
+ iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
68
+ line = (
69
  f"{id_to_index[anno['category_id']]} "
70
  f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
71
  )
72
  annos.setdefault(img_id, []).append(line)
73
 
74
  name_to_id = {img['file_name']: img['id'] for img in coco['images']}
75
+ file_paths = {
76
+ f: os.path.join(dp, f)
77
+ for dp, _, files in os.walk(root) for f in files
78
+ if f in name_to_id
79
+ }
80
 
81
  for fname, img_id in name_to_id.items():
82
  src = file_paths.get(fname)
 
86
  with open(os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
87
  lf.write("\n".join(annos.get(img_id, [])))
88
 
89
+ # split into train/valid/test
90
  all_files = sorted(f for f in os.listdir(flat_img)
91
  if f.lower().endswith(('.jpg','.png','.jpeg')))
92
  random.shuffle(all_files)
 
107
  os.makedirs(idir, exist_ok=True)
108
  os.makedirs(ldir, exist_ok=True)
109
  for fn in files:
110
+ shutil.move(os.path.join(flat_img, fn), os.path.join(idir, fn))
 
111
  lbl = fn.rsplit('.',1)[0] + ".txt"
112
+ shutil.move(os.path.join(flat_lbl, lbl), os.path.join(ldir, lbl))
 
113
 
114
  shutil.rmtree(flat_img)
115
  shutil.rmtree(flat_lbl)
116
 
117
+ # before/after visuals
118
  before, after = [], []
119
  sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
120
  for fname in sample:
 
156
  rf = Roboflow(api_key=api_key)
157
  ws = rf.workspace(workspace)
158
 
159
+ # 1) Try fetch existing project
160
  try:
161
  proj = ws.project(project_slug)
162
  except Exception as e:
163
  if "does not exist" in str(e).lower():
164
  proj = ws.create_project(
165
+ name=project_slug,
166
+ type=project_type,
167
+ annotation=project_slug,
168
+ license=project_license
169
  )
170
  else:
171
  raise
172
 
173
+ # 2) If the type mismatches, spin up a fresh <slug>-v2 :contentReference[oaicite:0]{index=0}
174
+ if getattr(proj, "type", None) != project_type:
175
  new_slug = project_slug + "-v2"
176
  proj = ws.create_project(
177
+ name=new_slug,
178
+ type=project_type,
179
+ annotation=new_slug,
180
+ license=project_license
181
  )
182
  project_slug = new_slug
183
 
184
+ # 3) Upload your train/valid/test
185
  ws.upload_dataset(
186
  dataset_path,
187
  project_slug,
188
+ license=project_license,
189
  project_type=project_type
190
  )
191
 
192
+ # 4) Generate a new version with the documented signature :contentReference[oaicite:1]{index=1}
193
  try:
194
+ new_version = proj.generate_version(
195
+ preprocessing={},
196
+ augmentation={}
197
+ )
198
  except RuntimeError as e:
199
  msg = str(e).lower()
200
  if "unsupported request" in msg or "does not exist" in msg:
201
+ # fallback to <slug>-v3
202
  suffix = "-v3" if project_slug.endswith("-v2") else "-v2"
203
  new_slug = project_slug + suffix
204
  proj = ws.create_project(
205
+ name=new_slug,
206
+ type=project_type,
207
+ annotation=new_slug,
208
+ license=project_license
209
  )
210
  project_slug = new_slug
211
+ ws.upload_dataset(dataset_path, project_slug,
212
+ license=project_license,
213
+ project_type=project_type)
214
+ new_version = proj.generate_version(
215
+ preprocessing={},
216
+ augmentation={}
217
  )
 
 
 
 
218
  else:
219
  raise
220
 
221
+ # 5) Kick off training (asynchronously) :contentReference[oaicite:2]{index=2}
222
+ model = proj.version(new_version).train()
223
 
224
+ # 6) Return the hosted endpoint URL
225
+ return f"{model['base_url']}{model['id']}?api_key={api_key}"
 
226
 
227
 
228
  # --- Gradio UI ---