wuhp commited on
Commit
a711e94
·
verified ·
1 Parent(s): ac01980

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -90
app.py CHANGED
@@ -26,14 +26,18 @@ def parse_roboflow_url(url: str):
26
 
27
 
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
- """Download a segmentation dataset from Roboflow, convert to YOLO bboxes, and return before/after galleries."""
 
 
 
 
30
  rf = Roboflow(api_key=api_key)
31
  ws, proj, ver = parse_roboflow_url(dataset_url)
32
  version_obj = rf.workspace(ws).project(proj).version(ver)
33
  dataset = version_obj.download("coco-segmentation")
34
  root = dataset.location
35
 
36
- # 1) Find annotation JSON
37
  ann_file = None
38
  for dp, _, files in os.walk(root):
39
  for f in files:
@@ -51,146 +55,136 @@ def convert_seg_to_bbox(api_key: str, dataset_url: str):
51
  if ann_file:
52
  break
53
  if not ann_file:
54
- raise FileNotFoundError("No JSON annotations found under %s" % root)
55
 
56
  coco = json.load(open(ann_file, 'r'))
57
  images_info = {img['id']: img for img in coco['images']}
58
  cat_ids = sorted(c['id'] for c in coco.get('categories', []))
59
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
60
 
61
- # 2) Prepare YOLO folders
62
  out_root = tempfile.mkdtemp(prefix="yolov8_")
63
  img_out = os.path.join(out_root, "images")
64
  lbl_out = os.path.join(out_root, "labels")
65
  os.makedirs(img_out, exist_ok=True)
66
  os.makedirs(lbl_out, exist_ok=True)
67
 
68
- # 3) Convert seg→bbox
69
  annos = {}
70
- for anno in coco['annotations']:
71
- img_id = anno['image_id']
72
- poly = anno['segmentation'][0]
73
  xs, ys = poly[0::2], poly[1::2]
74
- x_min, x_max = min(xs), max(xs)
75
- y_min, y_max = min(ys), max(ys)
76
- w, h = x_max - x_min, y_max - y_min
77
- cx, cy = x_min + w / 2, y_min + h / 2
78
-
79
- iw, ih = images_info[img_id]['width'], images_info[img_id]['height']
80
- line = f"{id_to_index[anno['category_id']]} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
81
- annos.setdefault(img_id, []).append(line)
82
-
83
- # 4) Find images and write labels
84
- train_img_dir = None
85
  for dp, _, files in os.walk(root):
86
- if any(f.lower().endswith(('.jpg', '.png', '.jpeg')) for f in files):
87
- train_img_dir = dp
88
  break
89
- if not train_img_dir:
90
- raise FileNotFoundError("No image files found under %s" % root)
91
 
92
- name_to_id = {img['file_name']: img['id'] for img in coco['images']}
93
- for fname, img_id in name_to_id.items():
94
- src = os.path.join(train_img_dir, fname)
95
  if not os.path.isfile(src):
96
  continue
97
  shutil.copy(src, os.path.join(img_out, fname))
98
- with open(os.path.join(lbl_out, fname.rsplit('.', 1)[0] + ".txt"), 'w') as lf:
99
- lf.write("\n".join(annos.get(img_id, [])))
100
 
101
- # 5) Build galleries
102
  before, after = [], []
103
- sample = random.sample(list(name_to_id.keys()), min(5, len(name_to_id)))
104
- for fname in sample:
105
- src = os.path.join(train_img_dir, fname)
106
- img = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB)
107
 
108
- # draw seg polygons
109
  seg_vis = img.copy()
110
- for anno in coco['annotations']:
111
- if anno['image_id'] != name_to_id[fname]:
112
  continue
113
- pts = np.array(anno['segmentation'][0], np.int32).reshape(-1, 2)
114
- cv2.polylines(seg_vis, [pts], True, (255, 0, 0), 2)
115
 
116
- # draw boxes
117
  box_vis = img.copy()
118
- for line in annos.get(name_to_id[fname], []):
119
  _, cxn, cyn, wnorm, hnorm = map(float, line.split())
120
- iw, ih = images_info[name_to_id[fname]]['width'], images_info[name_to_id[fname]]['height']
121
- w0, h0 = int(wnorm * iw), int(hnorm * ih)
122
- x0 = int(cxn * iw - w0 / 2)
123
- y0 = int(cyn * ih - h0 / 2)
124
- cv2.rectangle(box_vis, (x0, y0), (x0 + w0, y0 + h0), (0, 255, 0), 2)
125
 
126
  before.append(Image.fromarray(seg_vis))
127
  after.append(Image.fromarray(box_vis))
128
 
129
- return before, after
130
 
131
 
132
- def upload_and_train_detection(
133
- api_key: str,
134
- project_id: str,
135
- dataset_path: str,
136
- project_license: str = "MIT",
137
- project_type: str = "object-detection",
138
- preprocessing: dict = None,
139
- augmentation: dict = None,
140
- speed: str = "fast"
141
- ):
142
  """
143
- Upload a local detection dataset to Roboflow, generate+train a new version,
144
- and return the hosted inference endpoint URL.
 
145
  """
146
  rf = Roboflow(api_key=api_key)
147
  ws = rf.workspace()
148
 
149
- # 1) upload
150
  ws.upload_dataset(
151
  dataset_path,
152
- project_id,
153
- project_license=project_license,
154
- project_type=project_type
155
  )
156
 
157
- # 2) generate version
158
- proj = ws.project(project_id)
159
- version_number = proj.generate_version(
160
- preprocessing=preprocessing or {},
161
- augmentation=augmentation or {}
162
- )
163
 
164
- # 3) train
165
- proj.version(version_number).train(speed=speed)
166
 
167
- # 4) fetch model endpoint info
168
  m = proj.version(str(version_number)).model
169
- endpoint = f"{m['base_url']}{m['id']}?api_key={api_key}"
170
- return endpoint
171
 
172
 
173
- # --- Gradio app ---
174
  with gr.Blocks() as app:
175
- gr.Markdown("## 🔄 Segmentation → YOLOv8 Converter")
176
- api_input1 = gr.Textbox(label="Roboflow API Key", type="password")
177
- url_input = gr.Textbox(label="Segmentation Dataset URL")
178
- run_btn = gr.Button("Convert to BBoxes")
179
- before_g = gr.Gallery(label="Before (Segmentation)", columns=5)
180
- after_g = gr.Gallery(label="After (BBoxes)", columns=5)
181
- run_btn.click(fn=convert_seg_to_bbox, inputs=[api_input1, url_input], outputs=[before_g, after_g])
182
-
183
- gr.Markdown("## 🚀 Upload & Train Detection Model")
184
- api_input2 = gr.Textbox(label="Roboflow API Key", type="password")
185
- project_input = gr.Textbox(label="Project ID (slug)")
186
- path_input = gr.Textbox(label="Local Dataset Path")
187
- train_btn = gr.Button("Upload & Train")
188
- url_output = gr.Textbox(label="Hosted Model Endpoint URL")
 
 
 
 
 
 
 
189
  train_btn.click(
190
  fn=upload_and_train_detection,
191
- inputs=[api_input2, project_input, path_input],
192
- outputs=[url_output],
193
  )
194
 
 
 
195
  if __name__ == "__main__":
196
  app.launch()
 
26
 
27
 
28
  def convert_seg_to_bbox(api_key: str, dataset_url: str):
29
+ """
30
+ 1) Download segmentation dataset from Roboflow
31
+ 2) Convert masks → YOLOv8 bboxes
32
+ Returns before_gallery, after_gallery, local_dataset_path, project_slug
33
+ """
34
  rf = Roboflow(api_key=api_key)
35
  ws, proj, ver = parse_roboflow_url(dataset_url)
36
  version_obj = rf.workspace(ws).project(proj).version(ver)
37
  dataset = version_obj.download("coco-segmentation")
38
  root = dataset.location
39
 
40
+ # find annotation JSON
41
  ann_file = None
42
  for dp, _, files in os.walk(root):
43
  for f in files:
 
55
  if ann_file:
56
  break
57
  if not ann_file:
58
+ raise FileNotFoundError(f"No JSON annotations found under {root}")
59
 
60
  coco = json.load(open(ann_file, 'r'))
61
  images_info = {img['id']: img for img in coco['images']}
62
  cat_ids = sorted(c['id'] for c in coco.get('categories', []))
63
  id_to_index = {cid: idx for idx, cid in enumerate(cat_ids)}
64
 
65
+ # prepare YOLOv8 folders
66
  out_root = tempfile.mkdtemp(prefix="yolov8_")
67
  img_out = os.path.join(out_root, "images")
68
  lbl_out = os.path.join(out_root, "labels")
69
  os.makedirs(img_out, exist_ok=True)
70
  os.makedirs(lbl_out, exist_ok=True)
71
 
72
+ # convert seg bbox
73
  annos = {}
74
+ for a in coco['annotations']:
75
+ pid = a['image_id']
76
+ poly = a['segmentation'][0]
77
  xs, ys = poly[0::2], poly[1::2]
78
+ xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
79
+ w, h = xmax - xmin, ymax - ymin
80
+ cx, cy = xmin + w/2, ymin + h/2
81
+ iw, ih = images_info[pid]['width'], images_info[pid]['height']
82
+ line = f"{id_to_index[a['category_id']]} {cx/iw:.6f} {cy/ih:.6f} {w/iw:.6f} {h/ih:.6f}"
83
+ annos.setdefault(pid, []).append(line)
84
+
85
+ # locate images and write labels
86
+ img_dir = None
 
 
87
  for dp, _, files in os.walk(root):
88
+ if any(f.lower().endswith(('.jpg','.png','jpeg')) for f in files):
89
+ img_dir = dp
90
  break
91
+ if not img_dir:
92
+ raise FileNotFoundError(f"No image files found under {root}")
93
 
94
+ fname2id = {img['file_name']: img['id'] for img in coco['images']}
95
+ for fname, pid in fname2id.items():
96
+ src = os.path.join(img_dir, fname)
97
  if not os.path.isfile(src):
98
  continue
99
  shutil.copy(src, os.path.join(img_out, fname))
100
+ with open(os.path.join(lbl_out, fname.rsplit('.',1)[0] + ".txt"), 'w') as lf:
101
+ lf.write("\n".join(annos.get(pid, [])))
102
 
103
+ # build preview galleries
104
  before, after = [], []
105
+ sample = random.sample(list(fname2id.keys()), min(5, len(fname2id)))
106
+ for fn in sample:
107
+ img = cv2.cvtColor(cv2.imread(os.path.join(img_dir, fn)), cv2.COLOR_BGR2RGB)
 
108
 
 
109
  seg_vis = img.copy()
110
+ for a in coco['annotations']:
111
+ if a['image_id'] != fname2id[fn]:
112
  continue
113
+ pts = np.array(a['segmentation'][0], np.int32).reshape(-1,2)
114
+ cv2.polylines(seg_vis, [pts], True, (255,0,0), 2)
115
 
 
116
  box_vis = img.copy()
117
+ for line in annos.get(fname2id[fn], []):
118
  _, cxn, cyn, wnorm, hnorm = map(float, line.split())
119
+ iw, ih = images_info[fname2id[fn]]['width'], images_info[fname2id[fn]]['height']
120
+ w0, h0 = int(wnorm*iw), int(hnorm*ih)
121
+ x0, y0 = int(cxn*iw - w0/2), int(cyn*ih - h0/2)
122
+ cv2.rectangle(box_vis, (x0,y0), (x0+w0,y0+h0), (0,255,0), 2)
 
123
 
124
  before.append(Image.fromarray(seg_vis))
125
  after.append(Image.fromarray(box_vis))
126
 
127
+ return before, after, out_root, proj # proj is our slug
128
 
129
 
130
+ def upload_and_train_detection(api_key: str, project_slug: str, dataset_path: str):
 
 
 
 
 
 
 
 
 
131
  """
132
+ 1) Upload local YOLOv8 dataset to Roboflow
133
+ 2) Generate & train a new detection version
134
+ Returns the hosted inference endpoint URL.
135
  """
136
  rf = Roboflow(api_key=api_key)
137
  ws = rf.workspace()
138
 
139
+ # upload dataset
140
  ws.upload_dataset(
141
  dataset_path,
142
+ project_slug,
143
+ project_license="MIT",
144
+ project_type="object-detection"
145
  )
146
 
147
+ # generate a new version
148
+ proj = ws.project(project_slug)
149
+ version_number = proj.generate_version(preprocessing={}, augmentation={})
 
 
 
150
 
151
+ # train model (fast)
152
+ proj.version(version_number).train(speed="fast")
153
 
154
+ # fetch hosted endpoint
155
  m = proj.version(str(version_number)).model
156
+ return f"{m['base_url']}{m['id']}?api_key={api_key}"
 
157
 
158
 
 
159
  with gr.Blocks() as app:
160
+ gr.Markdown("## 🔄 Segmentation → YOLOv8 Converter + Auto Trainer")
161
+
162
+ # Converter UI
163
+ api_input = gr.Textbox(label="Roboflow API Key", type="password")
164
+ url_input = gr.Textbox(label="Segmentation Dataset URL")
165
+ convert_btn = gr.Button("Convert to BBoxes")
166
+ before_gal = gr.Gallery(label="Before (Segmentation)", columns=5)
167
+ after_gal = gr.Gallery(label="After (BBoxes)", columns=5)
168
+ state_path = gr.State()
169
+ state_slug = gr.State()
170
+
171
+ convert_btn.click(
172
+ fn=convert_seg_to_bbox,
173
+ inputs=[api_input, url_input],
174
+ outputs=[before_gal, after_gal, state_path, state_slug]
175
+ )
176
+
177
+ # Train UI
178
+ train_btn = gr.Button("Upload & Train Detection Model")
179
+ endpoint_text = gr.Textbox(label="Hosted Detection Endpoint URL")
180
+
181
  train_btn.click(
182
  fn=upload_and_train_detection,
183
+ inputs=[api_input, state_slug, state_path],
184
+ outputs=[endpoint_text]
185
  )
186
 
187
+ gr.Markdown("> First convert your seg data, then click **Upload & Train** to deploy your detection model.")
188
+
189
  if __name__ == "__main__":
190
  app.launch()