rootglitch committed
Commit 52b5462 · 1 Parent(s): d04cb87

Add mask and blending modules

Files changed (1): app.py (+160 -14)
app.py CHANGED
@@ -225,9 +225,7 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> N
    draw.text((box[0], box[1]), str(label), fill="white")


-def run_grounded_sam(
-    input_image
-) -> List[Image.Image]:
+def run_grounded_sam(input_image):
    """Main function to run GroundingDINO and SAM-HQ"""
    try:
        # Create output directory
@@ -319,7 +317,7 @@ def run_grounded_sam(
        for box, label in zip(boxes_filt, pred_phrases):
            draw_box(box, image_draw, label)

-        return [mask_image]
+        return mask_image

    except Exception as e:
        logger.error(f"Error in run_grounded_sam: {e}")
@@ -331,14 +329,85 @@ def run_grounded_sam(
        else:
            return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]

-def upload_file(img_file):
-    file_path = os.path.join(UPLOAD_DIR, 'input_image.png')
-    img_file.save(file_path)
-    print('file_path = ', file_path)
-    print('hf_file_path = ', f"https://huggingface.co/spaces/rootglitch/CarVizGradioDemo01/{file_path}")
-
-    # Return URL
-    return file_path
+def image_gaussian_blur(image: torch.Tensor, radius: float) -> torch.Tensor:
+    if image.ndim == 4:  # Remove batch dimension if present
+        image = image.squeeze(0)
+    pil_image = tensor2pil(image)
+    blurred_pil_image = pil_image.filter(ImageFilter.GaussianBlur(radius))
+    return pil2tensor(blurred_pil_image).squeeze(0)
+
+def load_image(image_path: str) -> torch.Tensor:
+    image = Image.open(image_path).convert("RGBA")
+    image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+    return image_tensor
+
+def split_image_with_alpha(image: torch.Tensor):
+    out_images = image[:3, :, :]
+    out_alphas = image[3, :, :] if image.shape[0] > 3 else torch.ones_like(image[0, :, :])
+    result = (out_images.unsqueeze(0), 1.0 - out_alphas.unsqueeze(0))
+    return result
+
+def pil2numpy(image: Image.Image):
+    return np.array(image).astype(np.float32) / 255.0
+
+def numpy2pil(image: np.ndarray, mode=None):
+    return Image.fromarray(np.clip(255.0 * image, 0, 255).astype(np.uint8), mode)
+
+def pil2tensor(image: Image.Image):
+    return torch.from_numpy(pil2numpy(image)).unsqueeze(0)
+
+def invert(image):
+    s = 1.0 - image
+    return s
+
+def tensor2pil(image: torch.Tensor, mode=None):
+    if image.ndim == 2:  # Grayscale image
+        image = image.unsqueeze(0)  # Add channel dimension
+
+    if image.ndim != 3 or image.shape[1:] == (0, 0):
+        raise ValueError(f"Invalid tensor dimensions: {image.shape}")
+
+    if image.shape[0] == 1:  # Single channel, replicate to 3 channels
+        image = image.repeat(3, 1, 1)
+    elif image.shape[0] != 3:
+        raise ValueError("Unexpected number of channels in the image tensor")
+
+    return numpy2pil(image.cpu().numpy().transpose(1, 2, 0), mode=mode)
+
+def extract_high_frequency(image: torch.Tensor, blur_radius: float = 5.0) -> torch.Tensor:
+    """Extract high-frequency details by subtracting the blurred image from the original."""
+    if image.ndim == 4:
+        image = image.squeeze(0)
+
+    blurred = image_gaussian_blur(image, blur_radius)
+
+    if blurred.ndim == 4:
+        blurred = blurred.squeeze(0)
+    elif blurred.ndim == 3 and blurred.shape[0] != 3:
+        blurred = blurred.permute(2, 0, 1)
+
+    high_freq = image - blurred
+    return high_freq
+
+def image_blend_mask(image_a, image_b, mask, blend_percentage):
+
+    # Convert images to PIL
+    img_a = tensor2pil(image_a)
+    img_b = tensor2pil(image_b)
+    mask = ImageOps.invert(tensor2pil(mask).convert('L'))
+
+    # Mask image
+    masked_img = Image.composite(img_a, img_b, mask.resize(img_a.size))
+
+    # Blend image
+    blend_mask = Image.new(mode="L", size=img_a.size,
+                           color=(round(blend_percentage * 255)))
+    blend_mask = ImageOps.invert(blend_mask)
+    img_result = Image.composite(img_a, masked_img, blend_mask)
+
+    del img_a, img_b, blend_mask, mask
+
+    return (pil2tensor(img_result), )

def encode_image(image):
    buffer = BytesIO()
@@ -368,10 +437,85 @@ def generate_ai_bg(input_img, prompt):

    return ic_light_img

-def generate_image(input_img, prompt):
-    ai_gen_image = generate_ai_bg(input_img, prompt)
-
-    return [ai_gen_image]
+def blend_details(input_image, relit_image, masked_image):
+    with torch.inference_mode():
+        # Load and resize images
+        # input_image = load_image(input_image_path)
+        # relit_image = load_image(relit_image_path)
+        # masked_image = load_image(masked_image_path)
+
+        # Resize input image
+        input_image = torch.nn.functional.interpolate(
+            input_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Resize relit image
+        relit_image = torch.nn.functional.interpolate(
+            relit_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Resize masked image
+        masked_image = torch.nn.functional.interpolate(
+            masked_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Split images and get RGB channels
+        input_image_rgb = split_image_with_alpha(input_image)[0].squeeze(0)
+        relit_image_rgb = split_image_with_alpha(relit_image)[0].squeeze(0)
+
+        # Use masked image RGB channels as segmentation mask (average of RGB channels)
+        segmentation_mask = masked_image[:3].mean(dim=0)  # Average RGB channels to get grayscale mask
+
+        print(f"segmentation_mask shape: {segmentation_mask.shape}")
+
+        # Extract high-frequency details from input image
+        high_freq_details = extract_high_frequency(input_image_rgb, blur_radius=3.0)
+
+        # Print shapes for debugging
+        print(f"high_freq_details shape: {high_freq_details.shape}")
+        print(f"segmentation_mask shape: {segmentation_mask.shape}")
+        print(f"relit_image_rgb shape: {relit_image_rgb.shape}")
+
+        # Apply high-frequency details only in masked areas
+        detail_strength = 0.5
+        segmentation_mask = segmentation_mask.unsqueeze(0).repeat(3, 1, 1)  # Expand mask to match RGB channels
+        masked_details = high_freq_details * segmentation_mask
+        # final_image = relit_image_rgb + (masked_details * detail_strength)
+        # final_image = image_blend_mask(relit_image_rgb, masked_details, mask, blend_percentage)
+        final_image = relit_image_rgb + masked_details
+        print('final_image shape:', final_image.shape)
+
+        # Normalize to [0, 1] range
+        final_image = torch.clamp(final_image, 0, 1)
+
+        # Save intermediate results for debugging
+        tensor2pil(segmentation_mask).save("output/segmentation_mask.png")
+        tensor2pil(high_freq_details).save("output/high_freq_details.png")
+        tensor2pil(masked_details).save("output/masked_details.png")
+
+        # Save final result
+        final_image_pil = tensor2pil(final_image)
+        # final_image_pil.save("output/output_image.png")
+        return [final_image_pil]
+
+def generate_image(input_img, ai_gen_image, prompt):
+
+    # ai_gen_image = generate_ai_bg(input_img, prompt)
+
+    mask_input_image = run_grounded_sam(input_img)
+
+    final_image = blend_details(input_img, ai_gen_image, mask_input_image)
+
+    return [final_image]

def create_ui():
    """Create Gradio UI for CarViz demo"""
@@ -383,6 +527,7 @@ def create_ui():
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="image")
+            ai_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            run_button = gr.Button(value='Run')

@@ -396,6 +541,7 @@ def create_ui():
        fn=generate_image,
        inputs=[
            input_image,
+            ai_image,
            prompt
        ],
        outputs=gallery
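
Note: the new blending step in this commit boils down to a high-pass detail transfer: blur the original photo, subtract the blur to isolate high-frequency detail, then add that detail back onto the relit image only where the segmentation mask is set, and clamp to [0, 1]. The sketch below restates the idea in plain PIL/NumPy as a minimal standalone illustration; it is not the commit's code, the function name transfer_details and the placeholder file names are assumptions, and it omits the 1024x1024 resizing and debug output that blend_details performs.

# Minimal sketch of the detail-transfer idea, assuming three PIL images
# (original photo, relit image, segmentation mask); names are placeholders.
import numpy as np
from PIL import Image, ImageFilter

def transfer_details(original, relit, mask, blur_radius=3.0):
    """Add high-frequency detail from `original` back onto `relit` where `mask` is white."""
    orig = np.asarray(original.convert("RGB"), dtype=np.float32) / 255.0
    rel = np.asarray(relit.convert("RGB").resize(original.size), dtype=np.float32) / 255.0
    m = np.asarray(mask.convert("L").resize(original.size), dtype=np.float32)[..., None] / 255.0

    # High-frequency component = image minus its Gaussian blur
    # (same idea as extract_high_frequency in the diff above).
    blurred = original.convert("RGB").filter(ImageFilter.GaussianBlur(blur_radius))
    high_freq = orig - np.asarray(blurred, dtype=np.float32) / 255.0

    # Re-apply the detail only inside the masked region, then clamp to [0, 1].
    out = np.clip(rel + high_freq * m, 0.0, 1.0)
    return Image.fromarray((out * 255.0).astype(np.uint8))

# Hypothetical usage (file names are placeholders):
# result = transfer_details(Image.open("car.png"), Image.open("car_relit.png"), Image.open("car_mask.png"))
# result.save("car_detailed.png")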