rootglitch committed
Commit 186da5b · 1 Parent(s): b66ab63

Added high-frequency blending code

Files changed (1): app.py (+161 −148)
app.py CHANGED
@@ -7,8 +7,16 @@ import time
 import logging
 import dotenv
 import fal_client
-import requests
 import base64
+import numpy as np
+import math
+import scipy.ndimage
+import torch
+import torchvision
+import gradio as gr
+import argparse
+import spaces
+from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont
 from io import BytesIO
 from typing import Dict, List, Tuple, Union, Optional
 
@@ -34,14 +42,6 @@ sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
 sys.path.append(os.path.join(os.getcwd(), "sam-hq"))
 warnings.filterwarnings("ignore")
 
-import numpy as np
-import torch
-import torchvision
-import gradio as gr
-import argparse
-from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont
-
-
 # Grounding DINO
 import GroundingDINO.groundingdino.datasets.transforms as T
 from GroundingDINO.groundingdino.models import build_model
@@ -56,10 +56,10 @@ CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
 SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
 OUTPUT_DIR = "outputs"
-FAL_KEY = os.getenv("FAL_KEY")
-UPLOAD_DIR = "./tmp/images"
+# FAL_KEY = os.getenv("FAL_KEY")
+# UPLOAD_DIR = "./tmp/images"
 
-os.makedirs(UPLOAD_DIR, exist_ok=True)
+# os.makedirs(UPLOAD_DIR, exist_ok=True)
 
 # Global variables for model caching
 _models = {
@@ -329,89 +329,106 @@ def run_grounded_sam(input_image):
     else:
         return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
 
-def image_gaussian_blur(image: torch.Tensor, radius: float) -> torch.Tensor:
-    if image.ndim == 4:  # Remove batch dimension if present
-        image = image.squeeze(0)
-    pil_image = tensor2pil(image)
-    blurred_pil_image = pil_image.filter(ImageFilter.GaussianBlur(radius))
-    return pil2tensor(blurred_pil_image).squeeze(0)
-
-def load_image(image_path: str) -> torch.Tensor:
-    image = Image.open(image_path).convert("RGBA")
-    image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-    return image_tensor
-
-def split_image_with_alpha(image: torch.Tensor):
-    out_images = image[:3, :, :]
-    out_alphas = image[3, :, :] if image.shape[0] > 3 else torch.ones_like(image[0, :, :])
-    result = (out_images.unsqueeze(0), 1.0 - out_alphas.unsqueeze(0))
-    return result
-
-def pil2numpy(image: Image.Image):
-    return np.array(image).astype(np.float32) / 255.0
-
-def numpy2pil(image: np.ndarray, mode=None):
-    return Image.fromarray(np.clip(255.0 * image, 0, 255).astype(np.uint8), mode)
-
-def pil2tensor(image: Image.Image):
-    return torch.from_numpy(pil2numpy(image)).unsqueeze(0)
-
-def invert(image):
-    s = 1.0 - image
-    return s
-
-def image2tensor(image) -> torch.Tensor:
-    image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-    return image_tensor
-
-def tensor2pil(image: torch.Tensor, mode=None):
-    if image.ndim == 2:  # Grayscale image
-        image = image.unsqueeze(0)  # Add channel dimension
-
-    if image.ndim != 3 or image.shape[1:] == (0, 0):
-        raise ValueError(f"Invalid tensor dimensions: {image.shape}")
-
-    if image.shape[0] == 1:  # Single channel, replicate to 3 channels
-        image = image.repeat(3, 1, 1)
-    elif image.shape[0] != 3:
-        raise ValueError("Unexpected number of channels in the image tensor")
-
-    return numpy2pil(image.cpu().numpy().transpose(1, 2, 0), mode=mode)
-
-def extract_high_frequency(image: torch.Tensor, blur_radius: float = 5.0) -> torch.Tensor:
-    """Extract high-frequency details by subtracting the blurred image from the original."""
-    if image.ndim == 4:
-        image = image.squeeze(0)
-
-    blurred = image_gaussian_blur(image, blur_radius)
-
-    if blurred.ndim == 4:
-        blurred = blurred.squeeze(0)
-    elif blurred.ndim == 3 and blurred.shape[0] != 3:
-        blurred = blurred.permute(2, 0, 1)
-
-    high_freq = image - blurred
-    return high_freq
-
-def image_blend_mask(image_a, image_b, mask, blend_percentage):
-
-    # Convert images to PIL
-    img_a = tensor2pil(image_a)
-    img_b = tensor2pil(image_b)
-    mask = ImageOps.invert(tensor2pil(mask).convert('L'))
+def split_image_with_alpha(image):
+    image = image.convert("RGB")
+    return image
+
+def gaussian_blur(image, radius=10):
+    """Apply Gaussian blur to image."""
+    blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
+    return blurred
+
+def invert_image(image):
+    img_inverted = ImageOps.invert(image)
+    return img_inverted
+
+def expand_mask(mask, expand, tapered_corners):
+    # Ensure mask is in grayscale (mode 'L')
+    mask = mask.convert("L")
+
+    # Convert to NumPy array
+    mask_np = np.array(mask)
+
+    # Define kernel
+    c = 0 if tapered_corners else 1
+    kernel = np.array([[c, 1, c],
+                       [1, 1, 1],
+                       [c, 1, c]], dtype=np.uint8)
+
+    # Perform dilation or erosion based on expand value
+    if expand > 0:
+        for _ in range(expand):
+            mask_np = scipy.ndimage.grey_dilation(mask_np, footprint=kernel)
+    elif expand < 0:
+        for _ in range(abs(expand)):
+            mask_np = scipy.ndimage.grey_erosion(mask_np, footprint=kernel)
+
+    # Convert back to PIL image
+    return Image.fromarray(mask_np, mode="L")
+
+def image_blend_by_mask(image_a, image_b, mask, blend_percentage):
+
+    mask = ImageOps.invert(mask.convert('L'))
 
     # Mask image
-    masked_img = Image.composite(img_a, img_b, mask.resize(img_a.size))
+    masked_img = Image.composite(image_a, image_b, mask)
 
     # Blend image
-    blend_mask = Image.new(mode="L", size=img_a.size,
+    blend_mask = Image.new(mode="L", size=image_a.size,
                            color=(round(blend_percentage * 255)))
     blend_mask = ImageOps.invert(blend_mask)
-    img_result = Image.composite(img_a, masked_img, blend_mask)
+    img_result = Image.composite(image_a, masked_img, blend_mask)
 
-    del img_a, img_b, blend_mask, mask
+    del image_a, image_b, blend_mask, mask
 
-    return (pil2tensor(img_result), )
+    return img_result
+
+def blend_images(image_a, image_b, blend_percentage):
+    """Blend image_b over image_a in normal mode with a blend percentage."""
+    img_a = image_a.convert("RGBA")
+    img_b = image_b.convert("RGBA")
+
+    # Blend img_b over img_a using alpha_composite (normal blend mode)
+    out_image = Image.alpha_composite(img_a, img_b)
+
+    out_image = out_image.convert("RGB")
+
+    # Create blend mask
+    blend_mask = Image.new("L", image_a.size, round(blend_percentage * 255))
+    blend_mask = ImageOps.invert(blend_mask)  # Invert the mask
+
+    # Apply composite blend
+    result = Image.composite(image_a, out_image, blend_mask)
+    return result
+
+def apply_image_levels(image, black_level, mid_level, white_level):
+    levels = AdjustLevels(black_level, mid_level, white_level)
+    adjusted_image = levels.adjust(image)
+    return adjusted_image
+
+class AdjustLevels:
+    def __init__(self, min_level, mid_level, max_level):
+        self.min_level = min_level
+        self.mid_level = mid_level
+        self.max_level = max_level
+
+    def adjust(self, im):
+        im_arr = np.array(im).astype(np.float32)
+        im_arr[im_arr < self.min_level] = self.min_level
+        im_arr = (im_arr - self.min_level) * \
+            (255 / (self.max_level - self.min_level))
+        im_arr = np.clip(im_arr, 0, 255)
+
+        # Mid-level (gamma) adjustment
+        gamma = math.log(0.5) / math.log((self.mid_level - self.min_level) / (self.max_level - self.min_level))
+        im_arr = np.power(im_arr / 255, gamma) * 255
+
+        im_arr = im_arr.astype(np.uint8)
+
+        im = Image.fromarray(im_arr)
+
+        return im
 
 def encode_image(image):
     buffer = BytesIO()
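Note on the new `expand_mask` helper: it grows or shrinks a mask by running an iterative 3x3 greyscale dilation/erosion, with `tapered_corners` switching the structuring element from a square to a cross so corners stay rounded. A minimal sketch of the behavior, assuming `expand_mask` from the hunk above is in scope (the synthetic square mask is illustrative, not from the commit):

```python
# Sketch of expand_mask behavior on a synthetic mask (illustrative only;
# assumes expand_mask from the hunk above is available in scope).
import numpy as np
from PIL import Image

mask_np = np.zeros((64, 64), dtype=np.uint8)
mask_np[22:42, 22:42] = 255                  # filled 20x20 square
mask = Image.fromarray(mask_np, mode="L")

grown = expand_mask(mask, expand=5, tapered_corners=True)    # 5 dilation passes
shrunk = expand_mask(mask, expand=-5, tapered_corners=True)  # 5 erosion passes

# Each 3x3 pass moves the mask boundary by roughly one pixel, so after
# 5 passes the square's area has grown (or shrunk) accordingly.
assert np.array(grown).sum() > mask_np.sum()
assert np.array(shrunk).sum() < mask_np.sum()
```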
@@ -442,75 +459,71 @@ def generate_ai_bg(input_img, prompt):
     return ic_light_img
 
 def blend_details(input_image, relit_image, masked_image):
-    with torch.inference_mode():
-        # Convert images to tensors
-        input_image = image2tensor(input_image)
-        relit_image = image2tensor(relit_image)
-        masked_image = image2tensor(masked_image)
-
-        # Resize input image
-        input_image = torch.nn.functional.interpolate(
-            input_image.unsqueeze(0),
-            size=(1024, 1024),
-            mode="bicubic",
-            align_corners=False
-        ).squeeze(0)
-
-        # Resize relit image
-        relit_image = torch.nn.functional.interpolate(
-            relit_image.unsqueeze(0),
-            size=(1024, 1024),
-            mode="bicubic",
-            align_corners=False
-        ).squeeze(0)
-
-        # Resize masked image
-        masked_image = torch.nn.functional.interpolate(
-            masked_image.unsqueeze(0),
-            size=(1024, 1024),
-            mode="bicubic",
-            align_corners=False
-        ).squeeze(0)
-
-        # Split images and get RGB channels
-        input_image_rgb = split_image_with_alpha(input_image)[0].squeeze(0)
-        relit_image_rgb = split_image_with_alpha(relit_image)[0].squeeze(0)
-
-        # Use masked image RGB channels as segmentation mask (average of RGB channels)
-        segmentation_mask = masked_image[:3].mean(dim=0)  # Average RGB channels to get grayscale mask
-
-        print(f"segmentation_mask shape: {segmentation_mask.shape}")
-
-        # Extract high-frequency details from input image
-        high_freq_details = extract_high_frequency(input_image_rgb, blur_radius=3.0)
-
-        # Print shapes for debugging
-        print(f"high_freq_details shape: {high_freq_details.shape}")
-        print(f"segmentation_mask shape: {segmentation_mask.shape}")
-        print(f"relit_image_rgb shape: {relit_image_rgb.shape}")
-
-        # Apply high-frequency details only in masked areas
-        detail_strength = 0.5
-        segmentation_mask = segmentation_mask.unsqueeze(0).repeat(3, 1, 1)  # Expand mask to match RGB channels
-        masked_details = high_freq_details * segmentation_mask
-        # final_image = relit_image_rgb + (masked_details * detail_strength)
-        # final_image = image_blend_mask(relit_image_rgb, masked_details, mask, blend_percentage)
-        final_image = relit_image_rgb + masked_details
-        print('final_image shape:', final_image.shape)
-
-        # Normalize to [0, 1] range
-        final_image = torch.clamp(final_image, 0, 1)
-
-        # Save intermediate results for debugging
-        # tensor2pil(segmentation_mask).save("output/segmentation_mask.png")
-        # tensor2pil(high_freq_details).save("output/high_freq_details.png")
-        # tensor2pil(masked_details).save("output/masked_details.png")
-
-        # Save final result
-        final_image_pil = tensor2pil(final_image)
-        # final_image_pil.save("output/output_image.png")
-        return final_image_pil
+    # input_image = load_image(input_image_path)
+    # relit_image = load_image(relit_image_path)
+    # masked_image = load_image(masked_image_path)
+
+    scaling_factor = 1
+    input_image = input_image.resize((int(input_image.width * scaling_factor),
+                                      int(input_image.height * scaling_factor)))
+
+    relit_image = relit_image.resize((int(relit_image.width * scaling_factor),
+                                      int(relit_image.height * scaling_factor)))
+
+    masked_image = masked_image.resize((int(masked_image.width * scaling_factor),
+                                        int(masked_image.height * scaling_factor)))
+
+    masked_image_rgb = split_image_with_alpha(masked_image)
+    masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
+    grow_mask = expand_mask(masked_image_blurred, -15, True)
+
+    # grow_mask.save("output/grow_mask.png")
+
+    # Split images and get RGB channels
+    input_image_rgb = split_image_with_alpha(input_image)
+    input_blurred = gaussian_blur(input_image_rgb, radius=10)
+    input_inverted = invert_image(input_image_rgb)
+
+    # input_blurred.save("output/input_blurred.png")
+    # input_inverted.save("output/input_inverted.png")
+
+    # Add blurred and inverted images
+    input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
+    input_blend_1_inverted = invert_image(input_blend_1)
+    input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
+
+    # input_blend_2.save("output/input_blend_2.png")
+
+    # Process relit image
+    relit_image_rgb = split_image_with_alpha(relit_image)
+    relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
+    relit_inverted = invert_image(relit_image_rgb)
+
+    # relit_blurred.save("output/relit_blurred.png")
+    # relit_inverted.save("output/relit_inverted.png")
+
+    # Add blurred and inverted relit images
+    relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
+    relit_blend_1_inverted = invert_image(relit_blend_1)
+    relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
+
+    # relit_blend_2.save("output/relit_blend_2.png")
+
+    high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
+
+    # high_freq_comp.save("output/high_freq_comp.png")
+
+    comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
+
+    # comped_image.save("output/comped_image.png")
+
+    final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
+
+    # final_image.save("output/final_image.png")
+
+    return final_image
 
+@spaces.GPU
 def generate_image(input_img, ai_gen_image, prompt):
 
     # ai_gen_image = generate_ai_bg(input_img, prompt)
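The rewritten `blend_details` no longer subtracts a blurred tensor from the original; it builds a mid-grey high-pass layer with PIL operations alone: blend the inverted image over its blur at 50%, then invert the result. Algebraically that equals 0.5 + (image - blur)/2, i.e. the high-frequency residual centered on mid-grey, which `image_blend_by_mask` then composites between the relit and input versions under the grown mask. A small NumPy check of that identity (illustrative, not part of the commit):

```python
# Sketch: why the invert/blur blend in blend_details acts as a high-pass.
# Pure-NumPy check on values in [0, 1]; illustrative only.
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8))                 # stand-in for the input image
blur = np.full_like(img, img.mean())     # stand-in for its Gaussian blur

blend_1 = ((1.0 - img) + blur) / 2.0     # inverted image blended over blur at 50%
high_pass = 1.0 - blend_1                # invert the result

# 1 - ((1 - img) + blur)/2 == 0.5 + (img - blur)/2: the high-frequency
# residual (img - blur), centered on mid-grey exactly as the pipeline needs.
assert np.allclose(high_pass, 0.5 + (img - blur) / 2.0)
```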
@@ -539,7 +552,7 @@ def create_ui():
             # gallery = gr.Gallery(
             #     label="Generated images", show_label=False, elem_id="gallery"
             # )
-            masked_image = gr.Image(label="Generated Image")
+            # masked_image = gr.Image(label="Generated Image")
             output_image = gr.Image(label="Generated Image")
 
         # Run button
@@ -550,7 +563,7 @@ def create_ui():
                 ai_image,
                 prompt
             ],
-            outputs=[masked_image, output_image]
+            outputs=[output_image]
         )
 
     return block
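Finally, the `apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)` call restretches the mid-grey high-pass composite back to full range: `AdjustLevels` maps [83, 172] linearly onto [0, 255], then applies a gamma chosen so the mid level lands at 50% grey (here gamma = log 0.5 / log(45/89) ≈ 1.016, nearly linear). A standalone check of the same arithmetic (illustrative only):

```python
# Sketch: the AdjustLevels mapping with the commit's 83/128/172 settings.
import math
import numpy as np

black, mid, white = 83, 128, 172
x = np.array([83, 128, 172], dtype=np.float32)

# Linear stretch of [black, white] onto [0, 255], as in AdjustLevels.adjust
y = np.clip((np.maximum(x, black) - black) * (255 / (white - black)), 0, 255)

# Gamma chosen so the mid level maps to 50% grey
gamma = math.log(0.5) / math.log((mid - black) / (white - black))
y = np.power(y / 255, gamma) * 255

print(np.round(y))  # -> [0. 128. 255.]: the band [83, 172] fills the full range
```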
 