wuhp committed on
Commit
8ece6c4
·
verified ·
1 Parent(s): 8193bd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -121
app.py CHANGED
@@ -17,7 +17,7 @@ def parse_roboflow_url(url: str):
17
  parsed = urlparse(url)
18
  parts = parsed.path.strip('/').split('/')
19
  workspace = parts[0]
20
- project = parts[1]
21
  try:
22
  version = int(parts[-1])
23
  except ValueError:
@@ -28,108 +28,168 @@ def parse_roboflow_url(url: str):
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
  """
30
  1) Download segmentation dataset from Roboflow
31
- 2) Convert each mask to its bounding box (YOLO format)
32
- 3) Preserve original train/valid/test splits
33
- 4) Return before/after visuals plus (dataset_path, detection_slug)
 
34
  """
35
  rf = Roboflow(api_key=api_key)
36
  ws, proj_name, ver = parse_roboflow_url(dataset_url)
37
  version_obj = rf.workspace(ws).project(proj_name).version(ver)
38
- dataset = version_obj.download("coco-segmentation")
39
- root = dataset.location
40
 
41
- # 1) Locate all three split JSON files
42
- json_files = {}
43
  for dp, _, files in os.walk(root):
44
  for f in files:
45
- lf = f.lower()
46
- if not lf.endswith('.json'):
47
- continue
48
- if 'train' in lf:
49
- json_files['train'] = os.path.join(dp, f)
50
- elif 'valid' in lf or 'val' in lf:
51
- json_files['valid'] = os.path.join(dp, f)
52
- elif 'test' in lf:
53
- json_files['test'] = os.path.join(dp, f)
54
- if any(k not in json_files for k in ('train', 'valid', 'test')):
55
- raise RuntimeError(f"Missing one of train/valid/test JSONs: {json_files}")
56
-
57
- # 2) Build category → index mapping from the train split
58
- train_coco = json.load(open(json_files['train'], 'r'))
59
- cat_ids = sorted(c['id'] for c in train_coco.get('categories', []))
60
- id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
61
-
62
- # 3) Aggregate ALL image info & annotations into global dicts
63
- global_images_info = {}
64
- global_annos = {}
65
- for split, jf in json_files.items():
66
- coco = json.load(open(jf, 'r'))
67
- for img in coco['images']:
68
- global_images_info[img['id']] = img
69
- for anno in coco['annotations']:
70
- xs = anno['segmentation'][0][0::2]
71
- ys = anno['segmentation'][0][1::2]
72
- xmin, xmax = min(xs), max(xs)
73
- ymin, ymax = min(ys), max(ys)
74
- w, h = xmax - xmin, ymax - ymin
75
- cx, cy = xmin + w/2, ymin + h/2
76
- iw = global_images_info[anno['image_id']]['width']
77
- ih = global_images_info[anno['image_id']]['height']
78
- line = f"{id_to_index[anno['category_id']]} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
79
- global_annos.setdefault(anno['image_id'], []).append(line)
80
-
81
- # 4) Build a quick map of filename → full path
82
- name_to_id = {img['file_name']: img['id'] for img in global_images_info.values()}
83
- file_paths = {
84
- f: os.path.join(dp, f)
85
- for dp, _, files in os.walk(root)
86
- for f in files
87
- if f in name_to_id
88
- }
89
-
90
- # 5) Copy images & write YOLO .txt labels, preserving original splits
91
- out_root = tempfile.mkdtemp(prefix="yolov8_")
92
- for split in ('train', 'valid', 'test'):
93
- coco = json.load(open(json_files[split], 'r'))
94
- img_dir = os.path.join(out_root, split, "images")
95
- lbl_dir = os.path.join(out_root, split, "labels")
96
- os.makedirs(img_dir, exist_ok=True)
97
- os.makedirs(lbl_dir, exist_ok=True)
98
- for img in coco['images']:
99
- fname = img['file_name']
100
- shutil.copy(file_paths[fname], os.path.join(img_dir, fname))
101
- with open(os.path.join(lbl_dir, fname.rsplit('.', 1)[0] + ".txt"), 'w') as f:
102
- f.write("\n".join(global_annos.get(img['id'], [])))
103
-
104
- # 6) Prepare a few before/after examples (random sample across all splits)
105
- before, after = [], []
106
- all_ids = list(global_images_info.keys())
107
- sample_ids = random.sample(all_ids, min(5, len(all_ids)))
108
- for img_id in sample_ids:
109
- fname = global_images_info[img_id]['file_name']
110
- img = cv2.cvtColor(cv2.imread(file_paths[fname]), cv2.COLOR_BGR2RGB)
111
-
112
- # draw segmentation outlines
113
- seg_vis = img.copy()
114
- for jf in json_files.values():
115
- coco = json.load(open(jf, 'r'))
116
- for anno in coco['annotations']:
117
- if anno['image_id'] != img_id:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  continue
119
- pts = np.array(anno['segmentation'][0], np.int32).reshape(-1, 2)
120
- cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
 
 
121
 
122
- # draw bounding boxes
123
- box_vis = img.copy()
124
- for line in global_annos.get(img_id, []):
125
- _, cxn, cyn, wnorm, hnorm = map(float, line.split())
126
- iw = global_images_info[img_id]['width']
127
- ih = global_images_info[img_id]['height']
128
- w0, h0 = int(wnorm * iw), int(hnorm * ih)
129
- x0 = int(cxn * iw - w0 / 2)
130
- y0 = int(cyn * ih - h0 / 2)
131
- cv2.rectangle(box_vis, (x0, y0), (x0 + w0, y0 + h0), (0, 255, 0), 2)
 
 
 
 
 
 
 
 
 
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  before.append(Image.fromarray(seg_vis))
134
  after.append(Image.fromarray(box_vis))
135
 
@@ -142,13 +202,8 @@ def upload_and_train_detection(
142
  detection_slug: str,
143
  dataset_path: str,
144
  project_license: str = "MIT",
145
- project_type: str = "object-detection"
146
  ):
147
- """
148
- Uploads the converted dataset (with preserved splits) to Roboflow,
149
- creates or fetches a detection project, and kicks off training.
150
- Returns the hosted model URL.
151
- """
152
  rf = Roboflow(api_key=api_key)
153
  ws = rf.workspace()
154
 
@@ -166,36 +221,32 @@ def upload_and_train_detection(
166
  else:
167
  raise
168
 
169
- # upload entire split folder
170
  _, real_slug = proj.id.rsplit("/", 1)
171
- ws.upload_dataset(
172
- dataset_path,
173
- real_slug,
174
- project_license=project_license,
175
- project_type=project_type
176
- )
177
 
178
- # generate new version (with fallback slug bump)
179
  try:
180
- version_num = proj.generate_version(settings={"augmentation": {}, "preprocessing": {}})
181
  except RuntimeError as e:
182
  msg = str(e).lower()
183
  if "unsupported request" in msg or "does not exist" in msg:
 
184
  new_slug = real_slug + "-v2"
185
  proj = ws.create_project(
186
- new_slug,
187
- annotation=project_type,
188
  project_type=project_type,
189
  project_license=project_license
190
  )
191
  ws.upload_dataset(dataset_path, new_slug,
192
  project_license=project_license,
193
  project_type=project_type)
194
- version_num = proj.generate_version(settings={"augmentation": {}, "preprocessing": {}})
195
  else:
196
  raise
197
 
198
- # wait for generation, then train
199
  for _ in range(20):
200
  try:
201
  model = proj.version(str(version_num)).train()
@@ -204,24 +255,25 @@ def upload_and_train_detection(
204
  if "still generating" in str(e).lower():
205
  time.sleep(5)
206
  continue
207
- raise
 
208
  else:
209
- raise RuntimeError("Dataset version did not finish generating in time; try again later.")
210
 
211
  return f"{model['base_url']}{model['id']}?api_key={api_key}"
212
 
213
 
214
  # --- Gradio UI ---
215
  with gr.Blocks() as app:
216
- gr.Markdown("## 🔄 Seg→BBox + Auto-Upload/Train")
217
 
218
  api_input = gr.Textbox(label="Roboflow API Key", type="password")
219
  url_input = gr.Textbox(label="Segmentation Dataset URL")
220
- run_btn = gr.Button("Convert to BBoxes")
221
- before_g = gr.Gallery(columns=5, label="Before")
222
- after_g = gr.Gallery(columns=5, label="After")
223
- ds_state = gr.Textbox(visible=False, label="Converted Dataset Path")
224
- slug_state = gr.Textbox(visible=False, label="Detection Project Slug")
225
 
226
  run_btn.click(
227
  convert_seg_to_bbox,
@@ -231,7 +283,7 @@ with gr.Blocks() as app:
231
 
232
  gr.Markdown("## 🚀 Upload & Train Detection Model")
233
  train_btn = gr.Button("Upload & Train")
234
- url_out = gr.Textbox(label="Hosted Model Endpoint URL")
235
 
236
  train_btn.click(
237
  upload_and_train_detection,
 
17
  parsed = urlparse(url)
18
  parts = parsed.path.strip('/').split('/')
19
  workspace = parts[0]
20
+ project = parts[1]
21
  try:
22
  version = int(parts[-1])
23
  except ValueError:
 
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
  """
30
  1) Download segmentation dataset from Roboflow
31
+ 2) Detect JSON‑vs‑mask export
32
+ 3) Convert each mask/polygon to its bounding box (YOLO format)
33
+ 4) Preserve original train/valid/test splits
34
+ 5) Return before/after visuals + (dataset_path, detection_slug)
35
  """
36
  rf = Roboflow(api_key=api_key)
37
  ws, proj_name, ver = parse_roboflow_url(dataset_url)
38
  version_obj = rf.workspace(ws).project(proj_name).version(ver)
39
+ dataset = version_obj.download("coco-segmentation")
40
+ root = dataset.location
41
 
42
+ # scan for any .json files
43
+ all_json = []
44
  for dp, _, files in os.walk(root):
45
  for f in files:
46
+ if f.lower().endswith(".json"):
47
+ all_json.append(os.path.join(dp, f))
48
+
49
+ if len(all_json) >= 3 and any("train" in os.path.basename(p).lower() for p in all_json):
50
+ # --- COCO‑JSON export branch ---
51
+ # locate train/valid/test JSONs
52
+ json_splits = {}
53
+ for path in all_json:
54
+ fn = os.path.basename(path).lower()
55
+ if "train" in fn:
56
+ json_splits["train"] = path
57
+ elif "val" in fn or "valid" in fn:
58
+ json_splits["valid"] = path
59
+ elif "test" in fn:
60
+ json_splits["test"] = path
61
+ if any(s not in json_splits for s in ("train", "valid", "test")):
62
+ raise RuntimeError(f"Missing one of train/valid/test JSONs: {json_splits}")
63
+
64
+ # build category → index from train.json
65
+ train_coco = json.load(open(json_splits["train"], "r"))
66
+ cat_ids = sorted(c["id"] for c in train_coco.get("categories", []))
67
+ id2idx = {cid: i for i, cid in enumerate(cat_ids)}
68
+
69
+ # aggregate images_info & annotations
70
+ images_info = {}
71
+ annos = {}
72
+ for split, jf in json_splits.items():
73
+ coco = json.load(open(jf, "r"))
74
+ for img in coco["images"]:
75
+ images_info[img["id"]] = img
76
+ for a in coco["annotations"]:
77
+ xs = a["segmentation"][0][0::2]
78
+ ys = a["segmentation"][0][1::2]
79
+ xmin, xmax = min(xs), max(xs)
80
+ ymin, ymax = min(ys), max(ys)
81
+ w, h = xmax - xmin, ymax - ymin
82
+ cx, cy = xmin + w/2, ymin + h/2
83
+ iw = images_info[a["image_id"]]["width"]
84
+ ih = images_info[a["image_id"]]["height"]
85
+ line = (
86
+ f"{id2idx[a['category_id']]} "
87
+ f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
88
+ )
89
+ annos.setdefault(a["image_id"], []).append(line)
90
+
91
+ # build filename path map
92
+ name2id = {img["file_name"]: img["id"] for img in images_info.values()}
93
+ filemap = {
94
+ f: os.path.join(dp, f)
95
+ for dp, _, files in os.walk(root)
96
+ for f in files
97
+ if f in name2id
98
+ }
99
+
100
+ # write out per‑split folders
101
+ out_root = tempfile.mkdtemp(prefix="yolov8_")
102
+ for split in ("train", "valid", "test"):
103
+ coco = json.load(open(json_splits[split], "r"))
104
+ img_dir = os.path.join(out_root, split, "images")
105
+ lbl_dir = os.path.join(out_root, split, "labels")
106
+ os.makedirs(img_dir, exist_ok=True)
107
+ os.makedirs(lbl_dir, exist_ok=True)
108
+ for img in coco["images"]:
109
+ fn = img["file_name"]
110
+ src = filemap[fn]
111
+ dst = os.path.join(img_dir, fn)
112
+ txtp = os.path.join(lbl_dir, fn.rsplit(".", 1)[0] + ".txt")
113
+ shutil.copy(src, dst)
114
+ with open(txtp, "w") as f:
115
+ f.write("\n".join(annos.get(img["id"], [])))
116
+
117
+ else:
118
+ # --- Segmentation‐Masks export branch ---
119
+ splits = ["train", "valid", "test"]
120
+ # detect masks subfolder name
121
+ mask_names = ("masks", "mask", "labels")
122
+ out_root = tempfile.mkdtemp(prefix="yolov8_")
123
+
124
+ for split in splits:
125
+ img_dir_src = os.path.join(root, split, "images")
126
+ # find which subdir holds the PNG masks
127
+ mdir = None
128
+ for m in mask_names:
129
+ candidate = os.path.join(root, split, m)
130
+ if os.path.isdir(candidate):
131
+ mdir = candidate
132
+ break
133
+ if mdir is None:
134
+ raise RuntimeError(f"No masks folder found under {split}/ (checked {mask_names})")
135
+
136
+ img_dir_dst = os.path.join(out_root, split, "images")
137
+ lbl_dir_dst = os.path.join(out_root, split, "labels")
138
+ os.makedirs(img_dir_dst, exist_ok=True)
139
+ os.makedirs(lbl_dir_dst, exist_ok=True)
140
+
141
+ for fn in os.listdir(img_dir_src):
142
+ if not fn.lower().endswith((".jpg", ".png")):
143
  continue
144
+ src_img = os.path.join(img_dir_src, fn)
145
+ src_mask = os.path.join(mdir, fn)
146
+ img = cv2.imread(src_img)
147
+ h, w = img.shape[:2]
148
 
149
+ # read mask & binarize
150
+ mask = cv2.imread(src_mask, cv2.IMREAD_GRAYSCALE)
151
+ _, binm = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
152
+ ys, xs = np.nonzero(binm)
153
+ if len(xs) == 0:
154
+ lines = []
155
+ else:
156
+ xmin, xmax = xs.min(), xs.max()
157
+ ymin, ymax = ys.min(), ys.max()
158
+ bw, bh = xmax - xmin, ymax - ymin
159
+ cx, cy = xmin + bw/2, ymin + bh/2
160
+ # assume single class → index 0
161
+ lines = [f"0 {cx/w:.6f} {cy/h:.6f} {bw/w:.6f} {bh/h:.6f}"]
162
+
163
+ # copy image + write YOLO text
164
+ dst_img = os.path.join(img_dir_dst, fn)
165
+ dst_txt = os.path.join(lbl_dir_dst, fn.rsplit(".",1)[0] + ".txt")
166
+ shutil.copy(src_img, dst_img)
167
+ with open(dst_txt, "w") as f:
168
+ f.write("\n".join(lines))
169
 
170
+ # --- prepare before/after galleries (random sample across out_root) ---
171
+ before, after = [], []
172
+ all_imgs = []
173
+ for split in ("train","valid","test"):
174
+ for fn in os.listdir(os.path.join(out_root, split, "images")):
175
+ path = os.path.join(out_root, split, "images", fn)
176
+ all_imgs.append(path)
177
+ sample = random.sample(all_imgs, min(5, len(all_imgs)))
178
+ for img_path in sample:
179
+ fn = os.path.basename(img_path)
180
+ img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
181
+ # draw mask outline if available (JSON branch) else read mask again
182
+ seg_vis = img.copy()
183
+ box_vis = img.copy()
184
+ # overlay all .txt bboxes
185
+ txtp = img_path.replace("/images/", "/labels/").rsplit(".",1)[0] + ".txt"
186
+ w, h = img.shape[1], img.shape[0]
187
+ for line in open(txtp):
188
+ _, cxn, cyn, wnorm, hnorm = map(float, line.split())
189
+ bw, bh = int(wnorm * w), int(hnorm * h)
190
+ x0 = int(cxn * w - bw/2)
191
+ y0 = int(cyn * h - bh/2)
192
+ cv2.rectangle(box_vis, (x0,y0), (x0+bw, y0+bh), (0,255,0), 2)
193
  before.append(Image.fromarray(seg_vis))
194
  after.append(Image.fromarray(box_vis))
195
 
 
202
  detection_slug: str,
203
  dataset_path: str,
204
  project_license: str = "MIT",
205
+ project_type: str = "object-detection"
206
  ):
 
 
 
 
 
207
  rf = Roboflow(api_key=api_key)
208
  ws = rf.workspace()
209
 
 
221
  else:
222
  raise
223
 
224
+ # upload and kick off train
225
  _, real_slug = proj.id.rsplit("/", 1)
226
+ ws.upload_dataset(dataset_path, real_slug,
227
+ project_license=project_license,
228
+ project_type=project_type)
 
 
 
229
 
 
230
  try:
231
+ version_num = proj.generate_version(settings={"augmentation":{}, "preprocessing":{}})
232
  except RuntimeError as e:
233
  msg = str(e).lower()
234
  if "unsupported request" in msg or "does not exist" in msg:
235
+ # slug bump fallback
236
  new_slug = real_slug + "-v2"
237
  proj = ws.create_project(
238
+ new_slug, annotation=project_type,
 
239
  project_type=project_type,
240
  project_license=project_license
241
  )
242
  ws.upload_dataset(dataset_path, new_slug,
243
  project_license=project_license,
244
  project_type=project_type)
245
+ version_num = proj.generate_version(settings={"augmentation":{}, "preprocessing":{}})
246
  else:
247
  raise
248
 
249
+ # wait for generation then train
250
  for _ in range(20):
251
  try:
252
  model = proj.version(str(version_num)).train()
 
255
  if "still generating" in str(e).lower():
256
  time.sleep(5)
257
  continue
258
+ else:
259
+ raise
260
  else:
261
+ raise RuntimeError("Version generation timed out, try again later.")
262
 
263
  return f"{model['base_url']}{model['id']}?api_key={api_key}"
264
 
265
 
266
  # --- Gradio UI ---
267
  with gr.Blocks() as app:
268
+ gr.Markdown("## 🔄 Seg→BBox + AutoUpload/Train")
269
 
270
  api_input = gr.Textbox(label="Roboflow API Key", type="password")
271
  url_input = gr.Textbox(label="Segmentation Dataset URL")
272
+ run_btn = gr.Button("Convert to BBoxes")
273
+ before_g = gr.Gallery(columns=5, label="Before")
274
+ after_g = gr.Gallery(columns=5, label="After")
275
+ ds_state = gr.Textbox(visible=False, label="Dataset Path")
276
+ slug_state= gr.Textbox(visible=False, label="Detection Slug")
277
 
278
  run_btn.click(
279
  convert_seg_to_bbox,
 
283
 
284
  gr.Markdown("## 🚀 Upload & Train Detection Model")
285
  train_btn = gr.Button("Upload & Train")
286
+ url_out = gr.Textbox(label="Hosted Model URL")
287
 
288
  train_btn.click(
289
  upload_and_train_detection,