jixin0101 committed
Commit: 6fa3be1
Parent: 5721737

feat: enable pipeline to output fused result

Files changed (3)
  1. app.py +3 -110
  2. pipeline_objectclear.py +49 -15
  3. utils.py +105 -0
app.py CHANGED
@@ -10,9 +10,8 @@ import argparse
 import numpy as np
 import torchvision.transforms.functional as TF
 from scipy.ndimage import convolve, zoom
- import cv2
- import time
 import spaces
+ from utils import resize_by_short_side
 
 from tools.interact_tools import SamControler
 from tools.misc import get_device
@@ -33,106 +32,6 @@ def parse_augment():
 
     return args
 
-
- def pad_to_multiple(image: np.ndarray, multiple: int = 8):
-     h, w = image.shape[:2]
-     pad_h = (multiple - h % multiple) % multiple
-     pad_w = (multiple - w % multiple) % multiple
-     if image.ndim == 3:
-         padded = np.pad(image, ((0, pad_h), (0, pad_w), (0,0)), mode='reflect')
-     else:
-         padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
-     return padded, h, w
-
- def crop_to_original(image: np.ndarray, h: int, w: int):
-     return image[:h, :w]
-
- def wavelet_blur_np(image: np.ndarray, radius: int):
-     kernel = np.array([
-         [0.0625, 0.125, 0.0625],
-         [0.125, 0.25, 0.125],
-         [0.0625, 0.125, 0.0625]
-     ], dtype=np.float32)
-
-     blurred = np.empty_like(image)
-     for c in range(image.shape[0]):
-         blurred_c = convolve(image[c], kernel, mode='nearest')
-         if radius > 1:
-             blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
-         blurred[c] = blurred_c
-     return blurred
-
- def wavelet_decomposition_np(image: np.ndarray, levels=5):
-     high_freq = np.zeros_like(image)
-     for i in range(levels):
-         radius = 2 ** i
-         low_freq = wavelet_blur_np(image, radius)
-         high_freq += (image - low_freq)
-         image = low_freq
-     return high_freq, low_freq
-
- def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
-     content_high, _ = wavelet_decomposition_np(content_feat)
-     _, style_low = wavelet_decomposition_np(style_feat)
-     return content_high + style_low
-
- def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
-     fused_np = fused.astype(np.float32) / 255.0
-     mask_np = mask.astype(np.float32) / 255.0
-
-     fused_np = fused_np.transpose(2, 0, 1)
-     mask_np = mask_np.transpose(2, 0, 1)
-
-     result_np = wavelet_reconstruction_np(fused_np, mask_np)
-
-     result_np = result_np.transpose(1, 2, 0)
-     result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
-
-     return result_np
-
- def fuse_with_wavelet(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
-     H, W = ori.shape[:2]
-     attn_map = attn_map.astype(np.float32)
-     _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
-     am = attn_map.astype(np.float32)
-     am = am/255.0
-     am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
-
-     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
-     am_d = cv2.dilate(am_up, kernel, iterations=1)
-     am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
-
-     am_merged = np.maximum(am_up, am_d)
-     am_merged = np.clip(am_merged, 0, 1)
-
-     attn_up_3c = np.stack([am_merged]*3, axis=-1)
-     attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
-
-     ori_out = ori * (1 - attn_up_ori_3c)
-     rem_out = removed * (1 - attn_up_ori_3c)
-
-     ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
-     rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
-
-     wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
-     wave = crop_to_original(wave_rgb, h0, w0)
-     # fusion
-     fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
-     return fused
-
-
- def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
-     w, h = image.size
-     if w < h:
-         new_w = target_short
-         new_h = int(h * target_short / w)
-         new_h = (new_h + 15) // 16 * 16
-     else:
-         new_h = target_short
-         new_w = int(w * target_short / h)
-         new_w = (new_w + 15) // 16 * 16
-     return image.resize((new_w, new_h), resample=resample)
-
 # convert points input to prompt state
 def get_prompt(click_state, click_input):
     inputs = json.loads(click_input)
@@ -281,7 +180,7 @@ pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
     "jixin0101/ObjectClear",
     torch_dtype=torch.float16,
     variant='fp16',
-     save_cross_attn=True
+     apply_attention_guided_fusion=True
 )
 
 pipe.to(device)
@@ -325,13 +224,7 @@ def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed,
         height=h,
         width=w,
     )
-
-     inpainted_img = result[0].images[0]
-     attn_map = result[1]
-     attn_np = attn_map.mean(dim=1)[0].cpu().numpy() * 255.
-
-     fused_img = fuse_with_wavelet(np.array(image), np.array(inpainted_img), attn_np)
-     fused_img_pil = Image.fromarray(fused_img.astype(np.uint8))
+     fused_img_pil = result.images[0]
 
     return fused_img_pil.resize((image_or.size[:2])), (image.resize((image_or.size[:2])), fused_img_pil.resize((image_or.size[:2])))
 
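Note on the import added above: resize_by_short_side (previously defined inline in app.py, now shared via utils.py) scales the short side of a PIL image to target_short and rounds the computed long side up to a multiple of 16, presumably to keep the dimensions friendly to the pipeline's latent downsampling. A quick sanity check, with values that follow directly from the function's arithmetic:

    from PIL import Image
    from utils import resize_by_short_side

    img = Image.new("RGB", (1000, 750))           # landscape: the height is the short side
    resized = resize_by_short_side(img, target_short=512)
    print(resized.size)                           # (688, 512): int(1000 * 512 / 750) = 682, rounded up to 688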
pipeline_objectclear.py CHANGED
@@ -58,6 +58,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 
 from model import CLIPImageEncoder, PostfuseModule
+ from utils import attention_guided_fusion
 import gc
 import torch.nn.functional as F
 
@@ -328,6 +329,10 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+ @dataclass
+ class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
+     attns: Optional[List[PIL.Image.Image]] = None
+
 class ObjectClearPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
@@ -422,7 +427,7 @@ class ObjectClearPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
-         save_cross_attn: bool = False,
+         apply_attention_guided_fusion: bool = False,
     ):
         super().__init__()
 
@@ -441,7 +446,7 @@ class ObjectClearPipeline(
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
-         self.register_to_config(save_cross_attn=save_cross_attn)
+         self.register_to_config(apply_attention_guided_fusion=apply_attention_guided_fusion)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
@@ -455,7 +460,7 @@ class ObjectClearPipeline(
         else:
             self.watermark = None
 
-         if self.config.save_cross_attn:
+         if self.config.apply_attention_guided_fusion:
             self.cross_attention_scores = {}
             self.unet = self.unet_store_cross_attention_scores(
                 self.unet, self.cross_attention_scores
@@ -1367,6 +1372,7 @@ class ObjectClearPipeline(
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+         return_attn_map: bool = False,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
@@ -1859,7 +1865,7 @@ class ObjectClearPipeline(
         ).to(device=device, dtype=latents.dtype)
 
         self._num_timesteps = len(timesteps)
-         self.attn_map = None
+         attn_map = None
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -1906,16 +1912,16 @@ class ObjectClearPipeline(
 
                 # progressive attention mask blending
                 fuse_index = 5
-                 if self.config.save_cross_attn:
+                 if self.config.apply_attention_guided_fusion:
                     if i == len(timesteps) - 1:
                         attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
-                         self.attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
+                         attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
                         init_latents_proper = image_latents
                         if self.do_classifier_free_guidance:
-                             _, init_mask = self.attn_map.chunk(2)
+                             _, init_mask = attn_map.chunk(2)
                         else:
-                             init_mask = self.attn_map
-                         self.attn_map = init_mask
+                             init_mask = attn_map
+                         attn_map = init_mask
                     self.clear_cross_attention_scores(self.cross_attention_scores)
 
                 if num_channels_unet == 4:
@@ -1994,7 +2000,7 @@ class ObjectClearPipeline(
         if needs_upcasting:
             self.vae.to(dtype=torch.float16)
         else:
-             return StableDiffusionXLPipelineOutput(images=latents)
+             return ObjectClearPipelineOutput(images=latents)
 
         # apply watermark if available
         if self.watermark is not None:
@@ -2004,11 +2010,39 @@ class ObjectClearPipeline(
 
         if padding_mask_crop is not None:
             image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+         attn_pils = []
+         if output_type == "pil" and attn_map is not None:
+             for i in range(len(attn_map)):
+                 attn_np = attn_map[i].mean(dim=0).cpu().numpy() * 255.
+                 attn_pil = PIL.Image.fromarray(attn_np.astype(np.uint8)).convert("L")
+                 attn_pils.append(attn_pil)
+
+         original_pils = self.image_processor.postprocess(init_image, output_type="pil")
 
-         # Offload all models
-         self.maybe_free_model_hooks()
+         generated_pils = image
+
+         fused_images = []
+         for i in range(len(generated_pils)):
+             ori_pil = original_pils[i]
+             gen_pil = generated_pils[i]
+             attn_pil = attn_pils[i]
 
-         if not return_dict:
-             return (image,)
+             fused_np = attention_guided_fusion(np.array(ori_pil), np.array(gen_pil), np.array(attn_pil))
+             fused_pil = PIL.Image.fromarray(fused_np.astype(np.uint8)).resize(ori_pil.size)
 
-         return StableDiffusionXLPipelineOutput(images=image), self.attn_map
+             fused_images.append(fused_pil)
+
+         image = fused_images
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if return_attn_map and len(attn_pils) > 0:
+             if not return_dict:
+                 return (image, attn_pils)
+             return ObjectClearPipelineOutput(images=image, attns=attn_pils)
+         else:
+             if not return_dict:
+                 return (image,)
+             return ObjectClearPipelineOutput(images=image)
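With these changes, the attention-guided fused result is what the output's images field contains, and the raw attention maps are exposed through the new attns field when return_attn_map=True. A minimal usage sketch, assuming the checkpoint and loader used in app.py; the call arguments are illustrative (the pipeline follows an SDXL-inpainting-style interface), so check the __call__ signature for the full set of required inputs:

    import torch
    from PIL import Image
    from pipeline_objectclear import ObjectClearPipeline

    pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
        "jixin0101/ObjectClear",
        torch_dtype=torch.float16,
        variant="fp16",
        apply_attention_guided_fusion=True,   # installs the cross-attention hooks used for fusion
    )
    pipe.to("cuda")

    image = Image.open("input.png").convert("RGB")   # placeholder paths
    mask = Image.open("mask.png").convert("L")

    result = pipe(
        image=image,
        mask_image=mask,
        height=image.height,
        width=image.width,
        return_attn_map=True,                 # also return the per-image attention maps
    )
    fused = result.images[0]                  # attention-guided fused output (PIL)
    attn = result.attns[0]                    # mode-"L" attention map, present only with return_attn_map=True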
utils.py ADDED
@@ -0,0 +1,105 @@
+ import numpy as np
+ import cv2
+ from scipy.ndimage import convolve, zoom
+ from PIL import Image
+
+
+
+ def pad_to_multiple(image: np.ndarray, multiple: int = 8):
+     h, w = image.shape[:2]
+     pad_h = (multiple - h % multiple) % multiple
+     pad_w = (multiple - w % multiple) % multiple
+     if image.ndim == 3:
+         padded = np.pad(image, ((0, pad_h), (0, pad_w), (0,0)), mode='reflect')
+     else:
+         padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
+     return padded, h, w
+
+ def crop_to_original(image: np.ndarray, h: int, w: int):
+     return image[:h, :w]
+
+ def wavelet_blur_np(image: np.ndarray, radius: int):
+     kernel = np.array([
+         [0.0625, 0.125, 0.0625],
+         [0.125, 0.25, 0.125],
+         [0.0625, 0.125, 0.0625]
+     ], dtype=np.float32)
+
+     blurred = np.empty_like(image)
+     for c in range(image.shape[0]):
+         blurred_c = convolve(image[c], kernel, mode='nearest')
+         if radius > 1:
+             blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
+         blurred[c] = blurred_c
+     return blurred
+
+ def wavelet_decomposition_np(image: np.ndarray, levels=5):
+     high_freq = np.zeros_like(image)
+     for i in range(levels):
+         radius = 2 ** i
+         low_freq = wavelet_blur_np(image, radius)
+         high_freq += (image - low_freq)
+         image = low_freq
+     return high_freq, low_freq
+
+ def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
+     content_high, _ = wavelet_decomposition_np(content_feat)
+     _, style_low = wavelet_decomposition_np(style_feat)
+     return content_high + style_low
+
+ def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
+     fused_np = fused.astype(np.float32) / 255.0
+     mask_np = mask.astype(np.float32) / 255.0
+
+     fused_np = fused_np.transpose(2, 0, 1)
+     mask_np = mask_np.transpose(2, 0, 1)
+
+     result_np = wavelet_reconstruction_np(fused_np, mask_np)
+
+     result_np = result_np.transpose(1, 2, 0)
+     result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
+
+     return result_np
+
+ def attention_guided_fusion(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
+     H, W = ori.shape[:2]
+     attn_map = attn_map.astype(np.float32)
+     _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
+     am = attn_map.astype(np.float32)
+     am = am/255.0
+     am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
+
+     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
+     am_d = cv2.dilate(am_up, kernel, iterations=1)
+     am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
+
+     am_merged = np.maximum(am_up, am_d)
+     am_merged = np.clip(am_merged, 0, 1)
+
+     attn_up_3c = np.stack([am_merged]*3, axis=-1)
+     attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
+
+     ori_out = ori * (1 - attn_up_ori_3c)
+     rem_out = removed * (1 - attn_up_ori_3c)
+
+     ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
+     rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
+
+     wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
+     wave = crop_to_original(wave_rgb, h0, w0)
+     # fusion
+     fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
+     return fused
+
+
+ def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
+     w, h = image.size
+     if w < h:
+         new_w = target_short
+         new_h = int(h * target_short / w)
+         new_h = (new_h + 15) // 16 * 16
+     else:
+         new_h = target_short
+         new_w = int(w * target_short / h)
+         new_w = (new_w + 15) // 16 * 16
+     return image.resize((new_w, new_h), resample=resample)
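For reference, attention_guided_fusion (the renamed fuse_with_wavelet removed from app.py) operates purely on uint8 numpy arrays: the attention map is binarized at 128 and upsampled to the image size, then dilated and Gaussian-blurred into a soft blend mask; wavelet_color_fix_np keeps the original's high-frequency detail while adopting the removal result's low-frequency color outside the object region, and the object region itself is taken from the removal result. A small self-contained smoke test (shapes and values are illustrative):

    import numpy as np
    from utils import attention_guided_fusion

    # Synthetic inputs: a 256x256 RGB "original", a "removal" result, and a
    # low-resolution attention map (e.g. latent resolution) in the 0..255 range.
    ori = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    removed = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    attn = np.zeros((32, 32), dtype=np.uint8)
    attn[10:20, 10:20] = 255                  # region flagged by cross-attention

    fused = attention_guided_fusion(ori, removed, attn)
    print(fused.shape, fused.dtype)           # (256, 256, 3) uint8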