wuhp commited on
Commit
2d88615
·
verified ·
1 Parent(s): eecca4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -97
app.py CHANGED
@@ -27,21 +27,22 @@ def parse_roboflow_url(url: str):
27
 
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
  """
30
- Download a segmentation dataset from Roboflow,
31
- convert masks → YOLOv8 bboxes,
32
- and return (before, after) galleries + local YOLO dataset path + auto slug.
 
33
  """
34
  rf = Roboflow(api_key=api_key)
35
- ws_name, seg_proj_slug, ver = parse_roboflow_url(dataset_url)
36
- version_obj = rf.workspace(ws_name).project(seg_proj_slug).version(ver)
37
  dataset = version_obj.download("coco-segmentation")
38
  root = dataset.location
39
 
40
- # find the annotation JSON
41
  ann_file = None
42
  for dp, _, files in os.walk(root):
43
  for f in files:
44
- if f.lower().endswith(".json") and "train" in f.lower():
45
  ann_file = os.path.join(dp, f)
46
  break
47
  if ann_file:
@@ -49,74 +50,84 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str):
49
  if not ann_file:
50
  for dp, _, files in os.walk(root):
51
  for f in files:
52
- if f.lower().endswith(".json"):
53
  ann_file = os.path.join(dp, f)
54
  break
55
  if ann_file:
56
  break
57
  if not ann_file:
58
- raise FileNotFoundError(f"No JSON annotations found under {root}")
59
 
60
- coco = json.load(open(ann_file, "r"))
61
- images_info = {img["id"]: img for img in coco["images"]}
62
- cat_ids = sorted(c["id"] for c in coco.get("categories", []))
 
63
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
64
 
65
- # prepare YOLOv8 folders
66
  out_root = tempfile.mkdtemp(prefix="yolov8_")
67
  img_out = os.path.join(out_root, "images")
68
  lbl_out = os.path.join(out_root, "labels")
69
  os.makedirs(img_out, exist_ok=True)
70
  os.makedirs(lbl_out, exist_ok=True)
71
 
72
- # convert seg bbox labels
73
  annos = {}
74
- for a in coco["annotations"]:
75
- pid = a["image_id"]
76
- poly = a["segmentation"][0]
77
  xs, ys = poly[0::2], poly[1::2]
78
- xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
79
- w, h = xmax - xmin, ymax - ymin
80
- cx, cy = xmin + w/2, ymin + h/2
81
- iw, ih = images_info[pid]["width"], images_info[pid]["height"]
82
- line = f"{id_to_index[a['category_id']]} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
83
- annos.setdefault(pid, []).append(line)
84
-
85
- # locate images and write labels
86
- img_dir = None
 
 
 
 
 
87
  for dp, _, files in os.walk(root):
88
- if any(f.lower().endswith((".jpg", ".png", ".jpeg")) for f in files):
89
- img_dir = dp
90
  break
91
- if not img_dir:
92
- raise FileNotFoundError(f"No images found under {root}")
93
 
94
- fname2id = {img["file_name"]: img["id"] for img in coco["images"]}
95
- for fname, pid in fname2id.items():
96
- src = os.path.join(img_dir, fname)
 
97
  if not os.path.isfile(src):
98
  continue
99
  shutil.copy(src, os.path.join(img_out, fname))
100
- with open(os.path.join(lbl_out, fname.rsplit(".", 1)[0] + ".txt"), "w") as lf:
101
- lf.write("\n".join(annos.get(pid, [])))
102
 
103
- # build preview galleries
104
  before, after = [], []
105
- sample = random.sample(list(fname2id.keys()), min(5, len(fname2id)))
106
- for fn in sample:
107
- img = cv2.cvtColor(cv2.imread(os.path.join(img_dir, fn)), cv2.COLOR_BGR2RGB)
 
108
 
 
109
  seg_vis = img.copy()
110
- for a in coco["annotations"]:
111
- if a["image_id"] != fname2id[fn]:
112
  continue
113
- pts = np.array(a["segmentation"][0], np.int32).reshape(-1, 2)
114
  cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
115
 
 
116
  box_vis = img.copy()
117
- for line in annos.get(fname2id[fn], []):
118
  _, cxn, cyn, wnorm, hnorm = map(float, line.split())
119
- iw, ih = images_info[fname2id[fn]]["width"], images_info[fname2id[fn]]["height"]
120
  w0, h0 = int(wnorm * iw), int(hnorm * ih)
121
  x0 = int(cxn * iw - w0/2)
122
  y0 = int(cyn * ih - h0/2)
@@ -125,85 +136,87 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str):
125
  before.append(Image.fromarray(seg_vis))
126
  after.append(Image.fromarray(box_vis))
127
 
128
- # auto‐slug for the detection project
129
- detection_slug = f"{seg_proj_slug}-detection"
130
- return before, after, out_root, detection_slug
131
 
132
 
133
- def upload_and_train_detection(api_key: str, project_slug: str, dataset_path: str):
 
 
 
 
 
 
134
  """
135
- Given a YOLOv8 dataset folder, upload → version → train →
136
- return inference endpoint URL. Auto‐creates the project if needed.
 
 
137
  """
138
  rf = Roboflow(api_key=api_key)
139
  ws = rf.workspace()
140
 
141
- # get or create the detection project
142
  try:
143
  proj = ws.project(project_slug)
144
  except Exception:
 
145
  proj = ws.create_project(
146
- project_name=project_slug,
147
- project_type="object-detection",
148
- project_license="MIT"
 
149
  )
150
 
151
- # upload the dataset
152
  ws.upload_dataset(
153
  dataset_path,
154
- proj.id,
155
- num_workers=10,
156
- project_license="MIT",
157
- project_type="object-detection",
158
- batch_name=None,
159
- num_retries=0
160
  )
161
 
162
- # generate a new version
163
- new_v = proj.generate_version(settings={"preprocessing": {}, "augmentation": {}})
164
-
165
- # train (fast)
166
- version = proj.version(new_v)
167
- version.train(speed="fast")
168
 
169
- # return the hosted inference URL
170
- m = version.model
171
- return f"{m['base_url']}{m['id']}?api_key={api_key}"
 
172
 
173
 
 
174
  with gr.Blocks() as app:
175
- gr.Markdown("## 🔄 Segmentation → YOLOv8 + 📡 Auto‑Deploy Detector")
176
-
177
- # ─ Convert UI ─────────────────────────────────────────
178
- api = gr.Textbox(label="Roboflow API Key", type="password")
179
- segurl = gr.Textbox(label="Segmentation Dataset URL")
180
- btn_c = gr.Button("Convert to YOLOv8 BBoxes")
181
- out_b = gr.Gallery(label="Before (masks)")
182
- out_a = gr.Gallery(label="After (bboxes)")
183
- state_path = gr.State()
184
- state_slug = gr.State()
185
-
186
- btn_c.click(
187
- fn=convert_seg_to_bbox,
188
- inputs=[api, segurl],
189
- outputs=[out_b, out_a, state_path, state_slug]
190
- )
191
 
192
- gr.Markdown("---")
 
 
 
 
 
193
 
194
- # Train UI ───────────────────────────────────────────
195
- btn_t = gr.Button("Upload & Train Detection Model")
196
- endpoint = gr.Textbox(label="Hosted Detection Endpoint URL")
197
 
198
- btn_t.click(
199
- fn=upload_and_train_detection,
200
- inputs=[api, state_slug, state_path],
201
- outputs=[endpoint]
202
  )
203
 
204
- gr.Markdown(
205
- "> 1) Paste your segmentation URL and Convert. \n"
206
- "> 2) Then Upload & Train to instantly get your detector’s endpoint."
 
 
 
 
 
207
  )
208
 
209
  if __name__ == "__main__":
 
27
 
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
  """
30
+ 1) Download a segmentation dataset
31
+ 2) Convert all masks → YOLO‐style bboxes
32
+ 3) Write out a temp YOLO dataset and return its path
33
+ 4) Return before/after galleries + the dataset path + an auto slug
34
  """
35
  rf = Roboflow(api_key=api_key)
36
+ ws, proj_name, ver = parse_roboflow_url(dataset_url)
37
+ version_obj = rf.workspace(ws).project(proj_name).version(ver)
38
  dataset = version_obj.download("coco-segmentation")
39
  root = dataset.location
40
 
41
+ # Find the annotation JSON
42
  ann_file = None
43
  for dp, _, files in os.walk(root):
44
  for f in files:
45
+ if 'train' in f.lower() and f.lower().endswith('.json'):
46
  ann_file = os.path.join(dp, f)
47
  break
48
  if ann_file:
 
50
  if not ann_file:
51
  for dp, _, files in os.walk(root):
52
  for f in files:
53
+ if f.lower().endswith('.json'):
54
  ann_file = os.path.join(dp, f)
55
  break
56
  if ann_file:
57
  break
58
  if not ann_file:
59
+ raise FileNotFoundError(f"No JSON annotations under {root}")
60
 
61
+ with open(ann_file, 'r') as f:
62
+ coco = json.load(f)
63
+ images_info = {img['id']: img for img in coco['images']}
64
+ cat_ids = sorted(c['id'] for c in coco.get('categories', []))
65
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
66
 
67
+ # Prepare YOLO folders
68
  out_root = tempfile.mkdtemp(prefix="yolov8_")
69
  img_out = os.path.join(out_root, "images")
70
  lbl_out = os.path.join(out_root, "labels")
71
  os.makedirs(img_out, exist_ok=True)
72
  os.makedirs(lbl_out, exist_ok=True)
73
 
74
+ # Convert seg→bbox
75
  annos = {}
76
+ for anno in coco['annotations']:
77
+ img_id = anno['image_id']
78
+ poly = anno['segmentation'][0]
79
  xs, ys = poly[0::2], poly[1::2]
80
+ x_min, x_max = min(xs), max(xs)
81
+ y_min, y_max = min(ys), max(ys)
82
+ w, h = x_max - x_min, y_max - y_min
83
+ cx, cy = x_min + w/2, y_min + h/2
84
+
85
+ iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
86
+ line = (
87
+ f"{id_to_index[anno['category_id']]} "
88
+ f"{cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
89
+ )
90
+ annos.setdefault(img_id, []).append(line)
91
+
92
+ # Find the image folder
93
+ train_img_dir = None
94
  for dp, _, files in os.walk(root):
95
+ if any(f.lower().endswith(('.jpg', '.png', '.jpeg')) for f in files):
96
+ train_img_dir = dp
97
  break
98
+ if not train_img_dir:
99
+ raise FileNotFoundError(f"No images under {root}")
100
 
101
+ # Copy images + write labels
102
+ name_to_id = {img['file_name']: img['id'] for img in coco['images']}
103
+ for fname, img_id in name_to_id.items():
104
+ src = os.path.join(train_img_dir, fname)
105
  if not os.path.isfile(src):
106
  continue
107
  shutil.copy(src, os.path.join(img_out, fname))
108
+ with open(os.path.join(lbl_out, fname.rsplit('.', 1)[0] + ".txt"), 'w') as lf:
109
+ lf.write("\n".join(annos.get(img_id, [])))
110
 
111
+ # Build before/after sample galleries
112
  before, after = [], []
113
+ sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
114
+ for fname in sample:
115
+ src = os.path.join(train_img_dir, fname)
116
+ img = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB)
117
 
118
+ # segmentation overlay
119
  seg_vis = img.copy()
120
+ for anno in coco['annotations']:
121
+ if anno['image_id'] != name_to_id[fname]:
122
  continue
123
+ pts = np.array(anno['segmentation'][0], np.int32).reshape(-1, 2)
124
  cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
125
 
126
+ # bbox overlay
127
  box_vis = img.copy()
128
+ for line in annos.get(name_to_id[fname], []):
129
  _, cxn, cyn, wnorm, hnorm = map(float, line.split())
130
+ iw, ih = images_info[name_to_id[fname]]['width'], images_info[name_to_id[fname]]['height']
131
  w0, h0 = int(wnorm * iw), int(hnorm * ih)
132
  x0 = int(cxn * iw - w0/2)
133
  y0 = int(cyn * ih - h0/2)
 
136
  before.append(Image.fromarray(seg_vis))
137
  after.append(Image.fromarray(box_vis))
138
 
139
+ # auto-generated detection project slug
140
+ project_slug = f"{proj_name}-detection"
141
+ return before, after, out_root, project_slug
142
 
143
 
144
+ def upload_and_train_detection(
145
+ api_key: str,
146
+ project_slug: str,
147
+ dataset_path: str,
148
+ project_license: str = "MIT",
149
+ project_type: str = "object-detection"
150
+ ):
151
  """
152
+ 1) (re)create a Detection project
153
+ 2) upload the YOLO dataset
154
+ 3) generate & train a new version
155
+ 4) return the hosted endpoint URL
156
  """
157
  rf = Roboflow(api_key=api_key)
158
  ws = rf.workspace()
159
 
160
+ # 1) get-or-create project (need annotation arg)
161
  try:
162
  proj = ws.project(project_slug)
163
  except Exception:
164
+ # annotation must be provided as the 2nd positional arg
165
  proj = ws.create_project(
166
+ project_slug,
167
+ annotation=project_type,
168
+ project_type=project_type,
169
+ project_license=project_license
170
  )
171
 
172
+ # 2) upload dataset
173
  ws.upload_dataset(
174
  dataset_path,
175
+ project_slug,
176
+ project_license=project_license,
177
+ project_type=project_type
 
 
 
178
  )
179
 
180
+ # 3) generate a new version (no args = default preprocessing/augmentation)
181
+ version_num = proj.generate_version()
182
+ # 4) train it
183
+ proj.version(str(version_num)).train()
 
 
184
 
185
+ # 5) grab its hosted endpoint
186
+ m = proj.version(str(version_num)).model
187
+ endpoint = f"{m['base_url']}{m['id']}?api_key={api_key}"
188
+ return endpoint
189
 
190
 
191
+ # --- Gradio app ---
192
  with gr.Blocks() as app:
193
+ gr.Markdown("## 🔄 Segmentation → YOLOv8 Converter + Auto‐Upload")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ # single API key input
196
+ api_input = gr.Textbox(label="Roboflow API Key", type="password")
197
+ url_input = gr.Textbox(label="Segmentation Dataset URL")
198
+ run_btn = gr.Button("Convert to BBoxes")
199
+ before_g = gr.Gallery(label="Before (Segmentation)", columns=5, height="auto")
200
+ after_g = gr.Gallery(label="After (BBoxes)", columns=5, height="auto")
201
 
202
+ # hidden states for the YOLO dataset path & auto‐slug
203
+ dataset_path_state = gr.Textbox(visible=False)
204
+ project_slug_state = gr.Textbox(visible=False)
205
 
206
+ run_btn.click(
207
+ fn=convert_seg_to_bbox,
208
+ inputs=[api_input, url_input],
209
+ outputs=[before_g, after_g, dataset_path_state, project_slug_state]
210
  )
211
 
212
+ gr.Markdown("## 🚀 Upload & Train Detection Model")
213
+ train_btn = gr.Button("Upload & Train Detection Model")
214
+ url_output = gr.Textbox(label="Model Endpoint URL")
215
+
216
+ train_btn.click(
217
+ fn=upload_and_train_detection,
218
+ inputs=[api_input, project_slug_state, dataset_path_state],
219
+ outputs=[url_output]
220
  )
221
 
222
  if __name__ == "__main__":