rootglitch committed
Commit 7877974 · 1 Parent(s): bd1381f

Debug mask images

Files changed (1):
  1. app.py +86 -87

app.py CHANGED
@@ -225,100 +225,99 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> N
 
 def run_grounded_sam(input_image):
     """Main function to run GroundingDINO and SAM-HQ"""
-    try:
-        # Create output directory
-        os.makedirs(OUTPUT_DIR, exist_ok=True)
-        text_prompt = 'car'
-        task_type = 'text'
-        box_threshold = 0.3
-        text_threshold = 0.25
-        iou_threshold = 0.8
-        hq_token_only = True
-
-        # Process input image
-        if isinstance(input_image, dict):
-            # Input from gradio sketch component
-            scribble = np.array(input_image["mask"])
-            image_pil = input_image["image"].convert("RGB")
-        else:
-            # Direct image input
-            image_pil = input_image.convert("RGB") if input_image else None
-            scribble = None
-
-        if image_pil is None:
-            logger.error("No input image provided")
-            return [Image.new('RGB', (400, 300), color='gray')]
-
-        # Transform image for GroundingDINO
-        transformed_image = transform_image(image_pil)
-
-        # Load models as needed
-        ModelManager.load_model('groundingdino')
-        size = image_pil.size
-        H, W = size[1], size[0]
-
-        # Run GroundingDINO with provided text
-        boxes_filt, scores, pred_phrases = get_grounding_output(
-            transformed_image, text_prompt, box_threshold, text_threshold
-        )
-
-        if boxes_filt is not None:
-            # Scale boxes to image dimensions
-            for i in range(boxes_filt.size(0)):
-                boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-                boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-                boxes_filt[i][2:] += boxes_filt[i][:2]
-
-        # Apply non-maximum suppression if we have multiple boxes
-        if boxes_filt.size(0) > 1:
-            logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
-            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-            boxes_filt = boxes_filt[nms_idx]
-            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-            logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
-
-        # Load SAM model
-        ModelManager.load_model('sam')
-        sam_predictor = ModelManager.get_model('sam_predictor')
-
-        # Set image for SAM
-        image = np.array(image_pil)
-        sam_predictor.set_image(image)
-
-        # Run SAM
-        # Use boxes for these task types
-        if boxes_filt.size(0) == 0:
-            logger.warning("No boxes detected")
-            return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-        masks, _, _ = sam_predictor.predict_torch(
-            point_coords=None,
-            point_labels=None,
-            boxes=transformed_boxes,
-            multimask_output=False,
-            hq_token_only=hq_token_only,
-        )
-
-        # Create mask image
-        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-        mask_draw = ImageDraw.Draw(mask_image)
-
-        # Draw masks
-        for mask in masks:
-            draw_mask(mask[0].cpu().numpy(), mask_draw)
-
-        # Draw boxes and points on original image
-        image_draw = ImageDraw.Draw(image_pil)
-
-        for box, label in zip(boxes_filt, pred_phrases):
-            draw_box(box, image_draw, label)
-
-        return mask_image
-
-    except Exception as e:
-        logger.error(f"Error in run_grounded_sam: {e}")
+    # Create output directory
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    text_prompt = 'car'
+    task_type = 'text'
+    box_threshold = 0.3
+    text_threshold = 0.25
+    iou_threshold = 0.8
+    hq_token_only = True
+
+    # Process input image
+    if isinstance(input_image, dict):
+        # Input from gradio sketch component
+        scribble = np.array(input_image["mask"])
+        image_pil = input_image["image"].convert("RGB")
+    else:
+        # Direct image input
+        image_pil = input_image.convert("RGB") if input_image else None
+        scribble = None
+
+    if image_pil is None:
+        logger.error("No input image provided")
+        return [Image.new('RGB', (400, 300), color='gray')]
+
+    # Transform image for GroundingDINO
+    transformed_image = transform_image(image_pil)
+
+    # Load models as needed
+    ModelManager.load_model('groundingdino')
+    size = image_pil.size
+    H, W = size[1], size[0]
+
+    # Run GroundingDINO with provided text
+    boxes_filt, scores, pred_phrases = get_grounding_output(
+        transformed_image, text_prompt, box_threshold, text_threshold
+    )
+
+    if boxes_filt is not None:
+        # Scale boxes to image dimensions
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    # Apply non-maximum suppression if we have multiple boxes
+    if boxes_filt.size(0) > 1:
+        logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
+        nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+        boxes_filt = boxes_filt[nms_idx]
+        pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+        logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
+
+    # Load SAM model
+    ModelManager.load_model('sam')
+    sam_predictor = ModelManager.get_model('sam_predictor')
+
+    # Set image for SAM
+    image = np.array(image_pil)
+    sam_predictor.set_image(image)
+
+    # Run SAM
+    # Use boxes for these task types
+    if boxes_filt.size(0) == 0:
+        logger.warning("No boxes detected")
+        return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
+
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+    masks, _, _ = sam_predictor.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes,
+        multimask_output=False,
+        hq_token_only=hq_token_only,
+    )
+
+    # Create mask image
+    mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
+    mask_draw = ImageDraw.Draw(mask_image)
+
+    # Draw masks
+    for mask in masks:
+        draw_mask(mask[0].cpu().numpy(), mask_draw)
+
+    # Draw boxes and points on original image
+    image_draw = ImageDraw.Draw(image_pil)
+
+    for box, label in zip(boxes_filt, pred_phrases):
+        draw_box(box, image_draw, label)
+
+    return mask_image
+
+    # except Exception as e:
+    #     logger.error(f"Error in run_grounded_sam: {e}")
     # # Return original image on error
     # if isinstance(input_image, dict) and "image" in input_image:
     #     return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
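
Note on the box post-processing in this function: GroundingDINO returns boxes as normalized (cx, cy, w, h), and the per-box loop rescales them to absolute (x0, y0, x1, y1) pixel coordinates before torchvision.ops.nms drops near-duplicates. A minimal vectorized sketch of those same two steps, runnable on its own; postprocess_boxes is a hypothetical helper for illustration, not part of app.py, and it assumes a torchvision version that ships torchvision.ops.box_convert:

import torch
import torchvision


def postprocess_boxes(boxes_cxcywh, scores, width, height, iou_threshold=0.8):
    """Rescale normalized (cx, cy, w, h) boxes to pixel (x0, y0, x1, y1) and apply NMS."""
    # Scale normalized coordinates to pixel units (multiplying by [W, H, W, H], as above).
    boxes = boxes_cxcywh * torch.tensor([width, height, width, height], dtype=boxes_cxcywh.dtype)
    # Center/size -> corners; equivalent to the in-place -=/+= arithmetic in the loop above.
    boxes = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
    # NMS suppresses any box whose IoU with a higher-scoring kept box exceeds the threshold.
    keep = torchvision.ops.nms(boxes, scores, iou_threshold).tolist()
    return boxes[keep], keep


# Two heavily overlapping detections of the same object: NMS keeps only the higher-scoring one.
boxes = torch.tensor([[0.50, 0.5, 0.4, 0.2],
                      [0.52, 0.5, 0.4, 0.2]])
scores = torch.tensor([0.9, 0.8])
kept_boxes, keep = postprocess_boxes(boxes, scores, width=640, height=480)
print(keep)  # [0]

One design point this makes visible: NMS must run on corner-format boxes, so converting before suppression, as app.py does, is required; calling torchvision.ops.nms on raw (cx, cy, w, h) output would compute IoU on the wrong geometry.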