wuhp committed · verified
Commit 66be7dd · 1 Parent(s): 23838aa

Update app.py

Files changed (1)
  1. app.py +68 -74
app.py CHANGED
@@ -15,23 +15,29 @@ from roboflow import Roboflow
 def parse_roboflow_url(url: str):
     parsed = urlparse(url)
     parts = parsed.path.strip('/').split('/')
-    ws = parts[0]
-    proj = parts[1]
+    workspace = parts[0]
+    project = parts[1]
     try:
-        ver = int(parts[-1])
+        version = int(parts[-1])
     except ValueError:
-        ver = int(parts[-2])
-    return ws, proj, ver
+        version = int(parts[-2])
+    return workspace, project, version


 def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1, 0.1)):
-    rf = Roboflow(api_key=api_key)
-    workspace, proj_name, ver = parse_roboflow_url(dataset_url)
-    version_obj = rf.workspace(workspace).project(proj_name).version(ver)
-    dataset = version_obj.download("coco-segmentation")
-    root = dataset.location
-
-    # find COCO JSON
+    """
+    1) Download segmentation dataset from Roboflow
+    2) Convert each mask to its bounding box (YOLO format)
+    3) Split into train/valid/test
+    4) Return before/after visuals plus (dataset_path, detection_slug)
+    """
+    rf = Roboflow(api_key=api_key)
+    ws, proj, ver = parse_roboflow_url(dataset_url)
+    version_obj = rf.workspace(ws).project(proj).version(ver)
+    dataset = version_obj.download("coco-segmentation")
+    root = dataset.location
+
+    # find the COCO JSON
     ann_file = None
     for dp, _, files in os.walk(root):
         for f in files:
@@ -41,29 +47,22 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
         if ann_file:
             break
     if not ann_file:
-        raise FileNotFoundError(f"No JSON annotations found under {root}")
+        raise FileNotFoundError(f"No JSON found under {root}")

     coco = json.load(open(ann_file, 'r'))
     images_info = {img['id']: img for img in coco['images']}
     cat_ids = sorted(c['id'] for c in coco.get('categories', []))
     id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}

-    # flatten & convert to YOLO bboxes
-    out_root = tempfile.mkdtemp(prefix="yolov8_")
-    flat_img = os.path.join(out_root, "flat_images")
-    flat_lbl = os.path.join(out_root, "flat_labels")
-    os.makedirs(flat_img, exist_ok=True)
-    os.makedirs(flat_lbl, exist_ok=True)
-
+    # build YOLO bboxes
     annos = {}
     for anno in coco['annotations']:
         img_id = anno['image_id']
         xs, ys = anno['segmentation'][0][0::2], anno['segmentation'][0][1::2]
         xmin, xmax = min(xs), max(xs)
         ymin, ymax = min(ys), max(ys)
-        w, h = xmax - xmin, ymax - ymin
-        cx, cy = xmin + w/2, ymin + h/2
-
+        w, h = xmax - xmin, ymax - ymin
+        cx, cy = xmin + w/2, ymin + h/2
         iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
         line = (
             f"{id_to_index[anno['category_id']]} "
@@ -71,6 +70,13 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
         )
         annos.setdefault(img_id, []).append(line)

+    # copy and write out flat images + labels
+    out_root = tempfile.mkdtemp(prefix="yolov8_")
+    flat_img = os.path.join(out_root, "flat_images")
+    flat_lbl = os.path.join(out_root, "flat_labels")
+    os.makedirs(flat_img, exist_ok=True)
+    os.makedirs(flat_lbl, exist_ok=True)
+
     name_to_id = {img['file_name']: img['id'] for img in coco['images']}
     file_paths = {
         f: os.path.join(dp, f)
@@ -84,48 +90,44 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
         if not src:
             continue
         shutil.copy(src, os.path.join(flat_img, fname))
-        with open(os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
+        lbl_path = os.path.join(flat_lbl, fname.rsplit('.',1)[0] + ".txt")
+        with open(lbl_path, 'w') as lf:
             lf.write("\n".join(annos.get(img_id, [])))

-    # split into train/valid/test
-    all_files = sorted(
-        f for f in os.listdir(flat_img)
-        if f.lower().endswith(('.jpg','.png','.jpeg'))
-    )
+    # split filenames
+    all_files = [f for f in os.listdir(flat_img) if f.lower().endswith(('.jpg','.png','.jpeg'))]
     random.shuffle(all_files)
     n = len(all_files)
     n_train = max(1, int(n * split_ratios[0]))
     n_valid = max(1, int(n * split_ratios[1]))
     n_valid = min(n_valid, n - n_train - 1)
-
     splits = {
         "train": all_files[:n_train],
         "valid": all_files[n_train:n_train+n_valid],
         "test": all_files[n_train+n_valid:]
     }

+    # move into final folder structure
     for split, files in splits.items():
-        idir = os.path.join(out_root, "images", split)
-        ldir = os.path.join(out_root, "labels", split)
-        os.makedirs(idir, exist_ok=True)
-        os.makedirs(ldir, exist_ok=True)
+        img_dir = os.path.join(out_root, "images", split)
+        lbl_dir = os.path.join(out_root, "labels", split)
+        os.makedirs(img_dir, exist_ok=True)
+        os.makedirs(lbl_dir, exist_ok=True)
         for fn in files:
-            shutil.move(os.path.join(flat_img, fn), os.path.join(idir, fn))
+            shutil.move(os.path.join(flat_img, fn), os.path.join(img_dir, fn))
             lbl = fn.rsplit('.',1)[0] + ".txt"
-            shutil.move(os.path.join(flat_lbl, lbl), os.path.join(ldir, lbl))
+            shutil.move(os.path.join(flat_lbl, lbl), os.path.join(lbl_dir, lbl))

     shutil.rmtree(flat_img)
     shutil.rmtree(flat_lbl)

-    # before/after visuals
+    # prepare a few before/after images for display
     before, after = [], []
     sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
     for fname in sample:
-        src = file_paths.get(fname)
-        if not src:
-            continue
-        img = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB)
+        img = cv2.cvtColor(cv2.imread(file_paths[fname]), cv2.COLOR_BGR2RGB)

+        # original segmentation overlay
         seg_vis = img.copy()
         for anno in coco['annotations']:
             if anno['image_id'] != name_to_id[fname]:
@@ -133,6 +135,7 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
             pts = np.array(anno['segmentation'][0], np.int32).reshape(-1,2)
             cv2.polylines(seg_vis, [pts], True, (255,0,0), 2)

+        # bbox overlay
         box_vis = img.copy()
         for line in annos.get(name_to_id[fname], []):
             _, cxn, cyn, wnorm, hnorm = map(float, line.split())
@@ -145,27 +148,33 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str, split_ratios=(0.8, 0.1,
         before.append(Image.fromarray(seg_vis))
         after.append(Image.fromarray(box_vis))

-    return before, after, out_root, proj_name + "-detection", workspace
+    detection_slug = proj + "-detection"
+    return before, after, out_root, detection_slug


 def upload_and_train_detection(
     api_key: str,
-    workspace: str,
-    project_slug: str,
+    detection_slug: str,
     dataset_path: str,
     project_license: str = "MIT",
     project_type: str = "object-detection"
 ):
+    """
+    Uploads your converted dataset into *your* active Roboflow workspace,
+    creates (or finds) a project named `detection_slug`, and kicks off training.
+    Returns the hosted endpoint URL.
+    """
     rf = Roboflow(api_key=api_key)
-    ws = rf.workspace(workspace)
+    # use your active workspace (no name needed)
+    ws = rf.workspace()

-    # 1) Try to fetch existing project
+    # 1) get-or-create
     try:
-        proj = ws.project(project_slug)
+        proj = ws.project(detection_slug)
     except Exception as e:
         if "does not exist" in str(e).lower():
             proj = ws.create_project(
-                project_slug,
+                detection_slug,
                 annotation=project_type,
                 project_type=project_type,
                 project_license=project_license
@@ -173,26 +182,15 @@ def upload_and_train_detection(
         else:
             raise

-    # 2) If it exists but as the wrong annotation type, spin up <slug>-v2
-    if getattr(proj, "annotation", None) != project_type:
-        new_slug = project_slug + "-v2"
-        proj = ws.create_project(
-            new_slug,
-            annotation=project_type,
-            project_type=project_type,
-            project_license=project_license
-        )
-        project_slug = new_slug
-
-    # 3) Upload train/valid/test
+    # 2) upload everything under dataset_path
     ws.upload_dataset(
         dataset_path,
-        project_slug,
+        proj.slug,
         project_license=project_license,
         project_type=project_type
     )

-    # 4) Generate new version & train using the `settings` parameter
+    # 3) generate a new version
     try:
         version_num = proj.generate_version(settings={
             "augmentation": {},
@@ -201,18 +199,17 @@
     except RuntimeError as e:
         msg = str(e).lower()
         if "unsupported request" in msg or "does not exist" in msg:
-            suffix = "-v3" if project_slug.endswith("-v2") else "-v2"
-            new_slug = project_slug + suffix
+            # bump slug and retry
+            new_slug = proj.slug + "-v2"
             proj = ws.create_project(
                 new_slug,
                 annotation=project_type,
                 project_type=project_type,
                 project_license=project_license
             )
-            project_slug = new_slug
             ws.upload_dataset(
                 dataset_path,
-                project_slug,
+                proj.slug,
                 project_license=project_license,
                 project_type=project_type
             )
@@ -223,10 +220,8 @@
         else:
             raise

-    # 5) Kick off training
+    # 4) train & return endpoint
     model = proj.version(str(version_num)).train()
-
-    # 6) Return the hosted endpoint URL
     return f"{model['base_url']}{model['id']}?api_key={api_key}"


@@ -239,14 +234,13 @@ with gr.Blocks() as app:
     run_btn = gr.Button("Convert to BBoxes")
     before_g = gr.Gallery(columns=5, label="Before")
     after_g = gr.Gallery(columns=5, label="After")
-    ds_state = gr.Textbox(visible=False)
-    slug_state = gr.Textbox(visible=False)
-    ws_state = gr.Textbox(visible=False)
+    ds_state = gr.Textbox(visible=False, label="Converted Dataset Path")
+    slug_state = gr.Textbox(visible=False, label="Detection Project Slug")

     run_btn.click(
         convert_seg_to_bbox,
         inputs=[api_input, url_input],
-        outputs=[before_g, after_g, ds_state, slug_state, ws_state]
+        outputs=[before_g, after_g, ds_state, slug_state]
     )

     gr.Markdown("## 🚀 Upload & Train Detection Model")
@@ -255,7 +249,7 @@ with gr.Blocks() as app:

     train_btn.click(
         upload_and_train_detection,
-        inputs=[api_input, ws_state, slug_state, ds_state],
+        inputs=[api_input, slug_state, ds_state],
         outputs=[url_out]
     )
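
For reference, below is a minimal usage sketch of the two functions as they stand after this commit, driven outside the Gradio UI. It assumes `convert_seg_to_bbox` and `upload_and_train_detection` are importable from app.py (importing the module also builds the Gradio Blocks), and the API key and dataset URL shown are placeholders, not values from this commit.

# Minimal sketch, not part of the commit: runs the convert-then-train flow
# end to end with placeholder credentials.
from app import convert_seg_to_bbox, upload_and_train_detection  # importing app.py also constructs the Gradio UI

API_KEY = "YOUR_ROBOFLOW_API_KEY"  # placeholder
DATASET_URL = "https://universe.roboflow.com/some-workspace/some-project/1"  # placeholder

# Download the segmentation dataset, convert masks to YOLO bboxes, and split it;
# returns preview galleries plus the converted dataset path and a detection slug.
before_imgs, after_imgs, dataset_path, detection_slug = convert_seg_to_bbox(API_KEY, DATASET_URL)

# Upload the converted dataset into the active workspace and start training;
# the return value is the hosted inference endpoint URL.
endpoint_url = upload_and_train_detection(API_KEY, detection_slug, dataset_path)
print(endpoint_url)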