rootglitch committed
Commit a57bf8b · 1 Parent(s): 5017c3c

Remove commented code

Files changed (1)
  1. app.py +11 -207
app.py CHANGED
@@ -32,7 +32,7 @@ import torchvision
 import gradio as gr
 import argparse
 from PIL import Image, ImageDraw, ImageFont
-from scipy import ndimage
+# from scipy import ndimage

 # Grounding DINO
 import GroundingDINO.groundingdino.datasets.transforms as T
@@ -104,17 +104,7 @@ class ModelManager:
                 _models['sam_predictor'] = SamPredictor(sam)

                 logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
-
-            # elif model_name == 'blip' and (_models['blip_processor'] is None or _models['blip_model'] is None):
-            #     logger.info("Loading BLIP model...")
-            #     start_time = time.time()
-
-            #     _models['blip_processor'] = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-            #     _models['blip_model'] = BlipForConditionalGeneration.from_pretrained(
-            #         "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
-            #     ).to(device)
-
-            #     logger.info(f"BLIP model loaded in {time.time() - start_time:.2f} seconds")
+

         except Exception as e:
             logger.error(f"Error loading {model_name} model: {e}")
@@ -137,22 +127,6 @@ class ModelManager:
         torch.cuda.empty_cache()


-# def generate_caption(raw_image: Image.Image) -> str:
-#     """Generate image caption using BLIP"""
-#     try:
-#         blip_processor = ModelManager.get_model('blip_processor')
-#         blip_model = ModelManager.get_model('blip_model')
-
-#         inputs = blip_processor(raw_image, return_tensors="pt").to(device, torch.float16)
-#         out = blip_model.generate(**inputs)
-#         caption = blip_processor.decode(out[0], skip_special_tokens=True)
-#         logger.info(f"Generated caption: {caption}")
-#         return caption
-#     except Exception as e:
-#         logger.error(f"Error generating caption: {e}")
-#         return "Failed to generate caption."
-
-
 def transform_image(image_pil: Image.Image) -> torch.Tensor:
     """Transform PIL image for GroundingDINO"""
     transform = T.Compose([
@@ -218,10 +192,7 @@ def get_grounding_output(

 def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
     """Draw mask on image"""
-    # if random_color:
-    #     color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
-    # else:
-    #     color = (30, 144, 255, 153)
+
     color = (255, 255, 255, 255)

     nonzero_coords = np.transpose(np.nonzero(mask))
@@ -245,46 +216,6 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> N
     draw.text((box[0], box[1]), str(label), fill="white")


-# def draw_point(point: np.ndarray, draw: ImageDraw.Draw, r: int = 10) -> None:
-#     """Draw points on image"""
-#     for p in point:
-#         x, y = p
-#         draw.ellipse((x-r, y-r, x+r, y+r), fill='green')
-
-
-# def process_scribble_points(scribble: np.ndarray) -> np.ndarray:
-#     """Process scribble mask to get point coordinates"""
-#     # Transpose to get the correct orientation
-#     scribble = scribble.transpose(2, 1, 0)[0]
-#     # Label connected components
-#     labeled_array, num_features = ndimage.label(scribble >= 255)
-#     if num_features == 0:
-#         logger.warning("No points detected in scribble")
-#         return np.array([])
-
-#     # Get center of mass for each component
-#     centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
-#     return np.array(centers)
-
-
-# def process_scribble_box(scribble: np.ndarray) -> torch.Tensor:
-#     """Process scribble mask to get bounding box"""
-#     # Get point coordinates first
-#     centers = process_scribble_points(scribble)
-#     if len(centers) < 2:
-#         logger.warning("Not enough points for bounding box, need at least 2")
-#         # Return a default small box in the center if not enough points
-#         return torch.tensor([[0.4, 0.4, 0.6, 0.6]])
-
-#     # Define bounding box from scribble centers: (x_min, y_min, x_max, y_max)
-#     x_min = centers[:, 0].min()
-#     x_max = centers[:, 0].max()
-#     y_min = centers[:, 1].min()
-#     y_max = centers[:, 1].max()
-#     bbox = np.array([x_min, y_min, x_max, y_max])
-#     return torch.tensor(bbox).unsqueeze(0)
-
-
 def run_grounded_sam(
     input_image
     # text_prompt: str,
@@ -318,13 +249,7 @@ def run_grounded_sam(
         if image_pil is None:
             logger.error("No input image provided")
             return [Image.new('RGB', (400, 300), color='gray')]
-
-        # # Prepare for scribble tasks
-        # if task_type == 'scribble_box' or task_type == 'scribble_point':
-        #     if scribble is None:
-        #         logger.warning(f"No scribble provided for {task_type} task")
-        #         scribble = np.zeros((image_pil.height, image_pil.width, 3), dtype=np.uint8)
-
+
         # Transform image for GroundingDINO
         transformed_image = transform_image(image_pil)

@@ -337,49 +262,7 @@ def run_grounded_sam(
         boxes_filt, scores, pred_phrases = get_grounding_output(
             transformed_image, text_prompt, box_threshold, text_threshold
         )
-
-        # # Process based on task type
-        # if task_type == 'automatic':
-        #     # Generate caption with BLIP
-        #     ModelManager.load_model('blip')
-        #     text_prompt = generate_caption(image_pil)
-        #     logger.info(f"Automatic caption: {text_prompt}")
-
-        #     # Run GroundingDINO
-        #     boxes_filt, scores, pred_phrases = get_grounding_output(
-        #         transformed_image, text_prompt, box_threshold, text_threshold
-        #     )
-
-        # elif task_type == 'text':
-        #     if not text_prompt:
-        #         logger.warning("No text prompt provided for 'text' task")
-        #         return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-        #     # Run GroundingDINO with provided text
-        #     boxes_filt, scores, pred_phrases = get_grounding_output(
-        #         transformed_image, text_prompt, box_threshold, text_threshold
-        #     )
-
-        # elif task_type == 'scribble_box':
-        #     # No need for GroundingDINO, get box from scribble
-        #     boxes_filt = process_scribble_box(scribble)
-        #     scores = torch.ones(boxes_filt.size(0))
-        #     pred_phrases = ["scribble_box"] * boxes_filt.size(0)
-
-        # elif task_type == 'scribble_point':
-        #     # Will handle differently with SAM
-        #     point_coords = process_scribble_points(scribble)
-        #     if len(point_coords) == 0:
-        #         logger.warning("No points detected in scribble")
-        #         return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-        #     boxes_filt = None  # Not needed for point-based segmentation
-
-        # else:
-        #     logger.error(f"Unknown task type: {task_type}")
-        #     return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-        # Process boxes if present (not for scribble_point)
+
         if boxes_filt is not None:
             # Scale boxes to image dimensions
             for i in range(boxes_filt.size(0)):
@@ -403,10 +286,6 @@ def run_grounded_sam(
         image = np.array(image_pil)
         sam_predictor.set_image(image)

-        # # Convert string to boolean
-        # if isinstance(hq_token_only, str):
-        #     hq_token_only = (hq_token_only.lower() == 'true')
-
         # Run SAM
         # Use boxes for these task types
         if boxes_filt.size(0) == 0:
@@ -422,55 +301,21 @@ def run_grounded_sam(
             multimask_output=False,
             hq_token_only=hq_token_only,
         )
-
-        # elif task_type == 'scribble_point':
-        #     # Use points for this task type
-        #     point_labels = np.ones(point_coords.shape[0])
-
-        #     masks, _, _ = sam_predictor.predict(
-        #         point_coords=point_coords,
-        #         point_labels=point_labels,
-        #         box=None,
-        #         multimask_output=False,
-        #         hq_token_only=hq_token_only,
-        #     )

         # Create mask image
         mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
         mask_draw = ImageDraw.Draw(mask_image)

         # Draw masks
-        if task_type == 'text':
-        # for mask in masks:
-        #     draw_mask(mask, mask_draw, random_color=True)
-        # else:
-            for mask in masks:
-                draw_mask(mask[0].cpu().numpy(), mask_draw)
+        for mask in masks:
+            draw_mask(mask[0].cpu().numpy(), mask_draw)

         # Draw boxes and points on original image
         image_draw = ImageDraw.Draw(image_pil)

         for box, label in zip(boxes_filt, pred_phrases):
             draw_box(box, image_draw, label)
-
-        # if task_type == 'scribble_box':
-        #     for box in boxes_filt:
-        #         draw_box(box, image_draw, None)
-        # elif task_type in ['text', 'automatic']:
-        #     for box, label in zip(boxes_filt, pred_phrases):
-        #         draw_box(box, image_draw, label)
-        # elif task_type == 'scribble_point':
-        #     draw_point(point_coords, image_draw)
-
-        # Add caption text for automatic mode
-        # if task_type == 'automatic':
-        #     image_draw.text((10, 10), text_prompt, fill='black')
-
-        # Combine original image with mask
-        # image_pil = image_pil.convert('RGBA')
-        # image_pil.alpha_composite(mask_image)
-
-        # return [image_pil, mask_image]
         return [mask_image]

     except Exception as e:
@@ -494,64 +339,23 @@ def create_ui():
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(type="pil", label="image")
-                # input_image = gr.ImageMask(
-                #     sources=["upload", "clipboard"],
-                #     transforms=[],
-                #     layers=False,
-                #     format="pil",
-                #     label="base image",
-                #     show_label=True
-                # )
-                # task_type = gr.Dropdown(
-                #     ["automatic", "scribble_point", "scribble_box", "text"],
-                #     value="automatic",
-                #     label="Task Type"
-                # )
-                # text_prompt = gr.Textbox(label="Text Prompt", placeholder="bench .")
-                # hq_token_only = gr.Dropdown(
-                #     [False, True], value=False, label="hq_token_only"
-                # )
+
                 run_button = gr.Button(value='Run')
-
-                # with gr.Accordion("Advanced options", open=False):
-                #     box_threshold = gr.Slider(
-                #         label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
-                #     )
-                #     text_threshold = gr.Slider(
-                #         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
-                #     )
-                #     iou_threshold = gr.Slider(
-                #         label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
-                #     )

             with gr.Column():
                 gallery = gr.Gallery(
                     label="Generated images", show_label=False, elem_id="gallery"
                 )
-
-        # # Update visibility of text prompt based on task type
-        # def update_text_prompt_visibility(task):
-        #     return gr.update(visible=(task == "text"))
-
-        # task_type.change(
-        #     fn=update_text_prompt_visibility,
-        #     inputs=[task_type],
-        #     outputs=[text_prompt]
-        # )

         # Run button
         run_button.click(
             fn=run_grounded_sam,
             inputs=[
                 input_image
-                # , text_prompt, task_type,
-                # box_threshold, text_threshold, iou_threshold, hq_token_only
             ],
             outputs=gallery
-        )
-
-
-
+        )
+
     return block

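For reference, a minimal sketch of the single-input flow that remains after this cleanup. It assumes run_grounded_sam keeps the signature shown in the diff (one PIL image in, a one-element list containing the RGBA mask out) and that the text prompt and thresholds are now fixed inside app.py; the module import and file names below are illustrative, not part of the commit.

# Hypothetical smoke test for the pared-down pipeline (illustrative only).
# Assumes app.py exposes run_grounded_sam(input_image) as in the diff above,
# with the prompt and thresholds handled internally.
from PIL import Image

from app import run_grounded_sam  # assumed module name

def main() -> None:
    image = Image.open("sample.jpg").convert("RGB")  # placeholder input path
    outputs = run_grounded_sam(image)                # returns [mask_image]
    outputs[0].save("mask.png")                      # white-on-transparent RGBA mask

if __name__ == "__main__":
    main()

In the Gradio UI this is what run_button.click wires up: the gr.Image component supplies the PIL image and the returned list is shown in the gallery.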