feat: enable pipeline to output fused result
- app.py +3 -110
- pipeline_objectclear.py +49 -15
- utils.py +105 -0
app.py
CHANGED
@@ -10,9 +10,8 @@ import argparse
 import numpy as np
 import torchvision.transforms.functional as TF
 from scipy.ndimage import convolve, zoom
-import cv2
-import time
 import spaces
+from utils import resize_by_short_side

 from tools.interact_tools import SamControler
 from tools.misc import get_device
@@ -33,106 +32,6 @@ def parse_augment():

     return args

-
-def pad_to_multiple(image: np.ndarray, multiple: int = 8):
-    h, w = image.shape[:2]
-    pad_h = (multiple - h % multiple) % multiple
-    pad_w = (multiple - w % multiple) % multiple
-    if image.ndim == 3:
-        padded = np.pad(image, ((0, pad_h), (0, pad_w), (0,0)), mode='reflect')
-    else:
-        padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
-    return padded, h, w
-
-def crop_to_original(image: np.ndarray, h: int, w: int):
-    return image[:h, :w]
-
-def wavelet_blur_np(image: np.ndarray, radius: int):
-    kernel = np.array([
-        [0.0625, 0.125, 0.0625],
-        [0.125, 0.25, 0.125],
-        [0.0625, 0.125, 0.0625]
-    ], dtype=np.float32)
-
-    blurred = np.empty_like(image)
-    for c in range(image.shape[0]):
-        blurred_c = convolve(image[c], kernel, mode='nearest')
-        if radius > 1:
-            blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
-        blurred[c] = blurred_c
-    return blurred
-
-def wavelet_decomposition_np(image: np.ndarray, levels=5):
-    high_freq = np.zeros_like(image)
-    for i in range(levels):
-        radius = 2 ** i
-        low_freq = wavelet_blur_np(image, radius)
-        high_freq += (image - low_freq)
-        image = low_freq
-    return high_freq, low_freq
-
-def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
-    content_high, _ = wavelet_decomposition_np(content_feat)
-    _, style_low = wavelet_decomposition_np(style_feat)
-    return content_high + style_low
-
-def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
-    fused_np = fused.astype(np.float32) / 255.0
-    mask_np = mask.astype(np.float32) / 255.0
-
-    fused_np = fused_np.transpose(2, 0, 1)
-    mask_np = mask_np.transpose(2, 0, 1)
-
-    result_np = wavelet_reconstruction_np(fused_np, mask_np)
-
-    result_np = result_np.transpose(1, 2, 0)
-    result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
-
-    return result_np
-
-def fuse_with_wavelet(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
-    H, W = ori.shape[:2]
-    attn_map = attn_map.astype(np.float32)
-    _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
-    am = attn_map.astype(np.float32)
-    am = am/255.0
-    am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
-
-    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
-    am_d = cv2.dilate(am_up, kernel, iterations=1)
-    am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
-
-    am_merged = np.maximum(am_up, am_d)
-    am_merged = np.clip(am_merged, 0, 1)
-
-    attn_up_3c = np.stack([am_merged]*3, axis=-1)
-    attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
-
-    ori_out = ori * (1 - attn_up_ori_3c)
-    rem_out = removed * (1 - attn_up_ori_3c)
-
-    ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
-    rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
-
-    wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
-    wave = crop_to_original(wave_rgb, h0, w0)
-    # fusion
-    fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
-    return fused
-
-
-def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
-    w, h = image.size
-    if w < h:
-        new_w = target_short
-        new_h = int(h * target_short / w)
-        new_h = (new_h + 15) // 16 * 16
-    else:
-        new_h = target_short
-        new_w = int(w * target_short / h)
-        new_w = (new_w + 15) // 16 * 16
-    return image.resize((new_w, new_h), resample=resample)
-
 # convert points input to prompt state
 def get_prompt(click_state, click_input):
     inputs = json.loads(click_input)
@@ -281,7 +180,7 @@ pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
     "jixin0101/ObjectClear",
     torch_dtype=torch.float16,
     variant='fp16',
-
+    apply_attention_guided_fusion=True
 )

 pipe.to(device)
@@ -325,13 +224,7 @@ def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed,
         height=h,
         width=w,
     )
-
-    inpainted_img = result[0].images[0]
-    attn_map = result[1]
-    attn_np = attn_map.mean(dim=1)[0].cpu().numpy() * 255.
-
-    fused_img = fuse_with_wavelet(np.array(image), np.array(inpainted_img), attn_np)
-    fused_img_pil = Image.fromarray(fused_img.astype(np.uint8))
+    fused_img_pil = result.images[0]

     return fused_img_pil.resize((image_or.size[:2])), (image.resize((image_or.size[:2])), fused_img_pil.resize((image_or.size[:2])))

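With the fusion moved into the pipeline, app.py no longer post-processes the attention map itself. A minimal sketch of the new app-side flow; only `from_pretrained_with_custom_modules`, `apply_attention_guided_fusion=True`, `height`/`width`, and `result.images[0]` come from this diff, while the `image`/`mask_image` keyword names and the file names are assumptions for illustration:

    import torch
    from PIL import Image
    from pipeline_objectclear import ObjectClearPipeline

    pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
        "jixin0101/ObjectClear",
        torch_dtype=torch.float16,
        variant='fp16',
        apply_attention_guided_fusion=True,  # fusion now happens inside the pipeline
    )
    pipe.to("cuda")

    image = Image.open("input.png").convert("RGB")   # hypothetical inputs
    mask = Image.open("mask.png").convert("L")

    result = pipe(
        image=image,        # assumed inpainting-style inputs; the Space's
        mask_image=mask,    # process() passes its own image and mask here
        height=512,
        width=512,
    )
    fused_img_pil = result.images[0]  # already fused; no app-side fuse_with_wavelet() needed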
pipeline_objectclear.py
CHANGED
@@ -58,6 +58,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusio
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

 from model import CLIPImageEncoder, PostfuseModule
+from utils import attention_guided_fusion
 import gc
 import torch.nn.functional as F

@@ -328,6 +329,10 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


+@dataclass
+class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
+    attns: Optional[List[PIL.Image.Image]] = None
+
 class ObjectClearPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
@@ -422,7 +427,7 @@ class ObjectClearPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
-
+        apply_attention_guided_fusion: bool = False,
     ):
         super().__init__()

@@ -441,7 +446,7 @@ class ObjectClearPipeline(
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
-        self.register_to_config(
+        self.register_to_config(apply_attention_guided_fusion=apply_attention_guided_fusion)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
@@ -455,7 +460,7 @@ class ObjectClearPipeline(
         else:
             self.watermark = None

-        if self.config.
+        if self.config.apply_attention_guided_fusion:
             self.cross_attention_scores = {}
             self.unet = self.unet_store_cross_attention_scores(
                 self.unet, self.cross_attention_scores
@@ -1367,6 +1372,7 @@ class ObjectClearPipeline(
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        return_attn_map: bool = False,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
@@ -1859,7 +1865,7 @@ class ObjectClearPipeline(
             ).to(device=device, dtype=latents.dtype)

         self._num_timesteps = len(timesteps)
-
+        attn_map = None
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -1906,16 +1912,16 @@ class ObjectClearPipeline(

                 # progressive attention mask blending
                 fuse_index = 5
-                if self.config.
+                if self.config.apply_attention_guided_fusion:
                     if i == len(timesteps) - 1:
                         attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
-
+                        attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
                         init_latents_proper = image_latents
                         if self.do_classifier_free_guidance:
-                            _, init_mask =
+                            _, init_mask = attn_map.chunk(2)
                         else:
-                            init_mask =
+                            init_mask = attn_map
-
+                        attn_map = init_mask
                     self.clear_cross_attention_scores(self.cross_attention_scores)

                 if num_channels_unet == 4:
@@ -1994,7 +2000,7 @@ class ObjectClearPipeline(
             if needs_upcasting:
                 self.vae.to(dtype=torch.float16)
         else:
-            return
+            return ObjectClearPipelineOutput(images=latents)

         # apply watermark if available
         if self.watermark is not None:
@@ -2004,11 +2010,39 @@ class ObjectClearPipeline(

         if padding_mask_crop is not None:
             image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+        attn_pils = []
+        if output_type == "pil" and attn_map is not None:
+            for i in range(len(attn_map)):
+                attn_np = attn_map[i].mean(dim=0).cpu().numpy() * 255.
+                attn_pil = PIL.Image.fromarray(attn_np.astype(np.uint8)).convert("L")
+                attn_pils.append(attn_pil)
+
+        original_pils = self.image_processor.postprocess(init_image, output_type="pil")

-
-
+        generated_pils = image
+
+        fused_images = []
+        for i in range(len(generated_pils)):
+            ori_pil = original_pils[i]
+            gen_pil = generated_pils[i]
+            attn_pil = attn_pils[i]

-
-
+            fused_np = attention_guided_fusion(np.array(ori_pil), np.array(gen_pil), np.array(attn_pil))
+            fused_pil = PIL.Image.fromarray(fused_np.astype(np.uint8)).resize(ori_pil.size)

-
+            fused_images.append(fused_pil)
+
+        image = fused_images
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if return_attn_map and len(attn_pils) > 0:
+            if not return_dict:
+                return (image, attn_pils)
+            return ObjectClearPipelineOutput(images=image, attns=attn_pils)
+        else:
+            if not return_dict:
+                return (image,)
+            return ObjectClearPipelineOutput(images=image)
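Beyond the fused images, the new `return_attn_map` flag exposes the per-image attention maps via `ObjectClearPipelineOutput.attns`. A short usage sketch continuing the app-side example above; the flag, the tuple form under `return_dict=False`, and the `images`/`attns` fields come from this diff, while `pipe`, `init_image`, `mask_image`, and the keyword names are assumed:

    result = pipe(
        image=init_image,
        mask_image=mask_image,
        return_attn_map=True,   # new flag added in this commit
    )
    fused = result.images[0]    # attention-guided fused PIL image
    attn = result.attns[0]      # grayscale ("L" mode) PIL attention map

    # Tuple form: returns (image, attn_pils) when return_dict=False
    images, attn_maps = pipe(
        image=init_image,
        mask_image=mask_image,
        return_attn_map=True,
        return_dict=False,
    )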
utils.py
ADDED
@@ -0,0 +1,105 @@
+import numpy as np
+import cv2
+from scipy.ndimage import convolve, zoom
+from PIL import Image
+
+
+
+def pad_to_multiple(image: np.ndarray, multiple: int = 8):
+    h, w = image.shape[:2]
+    pad_h = (multiple - h % multiple) % multiple
+    pad_w = (multiple - w % multiple) % multiple
+    if image.ndim == 3:
+        padded = np.pad(image, ((0, pad_h), (0, pad_w), (0,0)), mode='reflect')
+    else:
+        padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
+    return padded, h, w
+
+def crop_to_original(image: np.ndarray, h: int, w: int):
+    return image[:h, :w]
+
+def wavelet_blur_np(image: np.ndarray, radius: int):
+    kernel = np.array([
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625]
+    ], dtype=np.float32)
+
+    blurred = np.empty_like(image)
+    for c in range(image.shape[0]):
+        blurred_c = convolve(image[c], kernel, mode='nearest')
+        if radius > 1:
+            blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
+        blurred[c] = blurred_c
+    return blurred
+
+def wavelet_decomposition_np(image: np.ndarray, levels=5):
+    high_freq = np.zeros_like(image)
+    for i in range(levels):
+        radius = 2 ** i
+        low_freq = wavelet_blur_np(image, radius)
+        high_freq += (image - low_freq)
+        image = low_freq
+    return high_freq, low_freq
+
+def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
+    content_high, _ = wavelet_decomposition_np(content_feat)
+    _, style_low = wavelet_decomposition_np(style_feat)
+    return content_high + style_low
+
+def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    fused_np = fused.astype(np.float32) / 255.0
+    mask_np = mask.astype(np.float32) / 255.0
+
+    fused_np = fused_np.transpose(2, 0, 1)
+    mask_np = mask_np.transpose(2, 0, 1)
+
+    result_np = wavelet_reconstruction_np(fused_np, mask_np)
+
+    result_np = result_np.transpose(1, 2, 0)
+    result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
+
+    return result_np
+
+def attention_guided_fusion(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
+    H, W = ori.shape[:2]
+    attn_map = attn_map.astype(np.float32)
+    _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
+    am = attn_map.astype(np.float32)
+    am = am/255.0
+    am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
+
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
+    am_d = cv2.dilate(am_up, kernel, iterations=1)
+    am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
+
+    am_merged = np.maximum(am_up, am_d)
+    am_merged = np.clip(am_merged, 0, 1)
+
+    attn_up_3c = np.stack([am_merged]*3, axis=-1)
+    attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
+
+    ori_out = ori * (1 - attn_up_ori_3c)
+    rem_out = removed * (1 - attn_up_ori_3c)
+
+    ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
+    rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
+
+    wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
+    wave = crop_to_original(wave_rgb, h0, w0)
+    # fusion
+    fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
+    return fused
+
+
+def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
+    w, h = image.size
+    if w < h:
+        new_w = target_short
+        new_h = int(h * target_short / w)
+        new_h = (new_h + 15) // 16 * 16
+    else:
+        new_h = target_short
+        new_w = int(w * target_short / h)
+        new_w = (new_w + 15) // 16 * 16
+    return image.resize((new_w, new_h), resample=resample)
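The relocated helpers can also be used standalone. A small sketch of calling `attention_guided_fusion` directly (file names are hypothetical): it expects the original and object-removed images as same-sized H x W x 3 uint8 RGB arrays plus a single-channel attention map, which it thresholds at 128 and nearest-resizes to H x W before the wavelet color fix and blending.

    import numpy as np
    from PIL import Image
    from utils import attention_guided_fusion

    ori = np.array(Image.open("original.png").convert("RGB"))     # H x W x 3, uint8
    removed = np.array(Image.open("removed.png").convert("RGB"))  # same size as ori
    attn = np.array(Image.open("attn_map.png").convert("L"))      # grayscale, any resolution

    fused = attention_guided_fusion(ori, removed, attn)           # H x W x 3, uint8
    Image.fromarray(fused).save("fused.png")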