orpatashnik committed
Commit b197ccc · 1 Parent(s): a1c0b29
Files changed (6)
  1. README.md +63 -12
  2. app.py +115 -0
  3. nested_attention_pipeline.py +248 -0
  4. nested_attention_processor.py +363 -0
  5. resampler.py +169 -0
  6. utils.py +128 -0
README.md CHANGED
@@ -1,12 +1,63 @@
1
- ---
2
- title: NestedAttentionPersonalization
3
- emoji: 🐢
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.28.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Nested Attention: Semantic-aware Attention Values for Concept Personalization (SIGGRAPH 2025)
2
+
3
+ ![](assets/teaser_site.jpg)
4
+
5
+ > **Nested Attention: Semantic-aware Attention Values for Concept Personalization**
6
+ > Or Patashnik, Rinon Gal, Daniil Ostashev, Sergey Tulyakov, Kfir Aberman, Daniel Cohen-Or
7
+ > https://arxiv.org/abs/2501.01407
8
+ >
9
+ > **Abstract:** Personalizing text-to-image models to generate images of specific subjects across diverse scenes and styles is a rapidly advancing field. Current approaches often struggle to balance identity preservation with alignment to the input text prompt. Some methods rely on a single textual token to represent a subject, limiting expressiveness, while others use richer representations but disrupt the model's prior, weakening prompt alignment.
10
+ > In this work, we introduce **Nested Attention**, a novel mechanism that injects rich and expressive image representations into the model's existing cross-attention layers. Our key idea is to generate query-dependent subject values, derived from nested attention layers that learn to select relevant subject features for each region in the generated image.
11
+ > We integrate these nested layers into an encoder-based personalization method and show that they enable strong identity preservation while maintaining adherence to input text prompts. Our approach is general and can be trained across various domains. Additionally, its prior preservation allows for combining multiple personalized subjects from different domains in a single image.
12
+
13
+ ## Description
14
+
15
+ Official implementation of **Nested Attention**, an encoder-based method for text-to-image personalization using a novel nested attention mechanism.
16
+
17
+ The implementation of the nested attention mechanism can be found in `nested_attention_processor.py`.
18
+
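A simplified, self-contained sketch of the idea (single head, one prompt, no classifier-free-guidance bookkeeping; the function and argument names below are illustrative, the real logic lives in `NestedAttnProcessor`): each image-generation query first attends over the encoder's subject tokens to obtain its own subject value, which is then rescaled to the norm of the `<person>` token's textual value and injected at that token's position in the outer cross-attention.

```python
import torch
import torch.nn.functional as F

def nested_attention_sketch(q, k, v, subj_k, subj_v, special, lam=1.0, norm_factor=2.0):
    """q: (n_queries, d) generation queries; k, v: (n_text, d) text keys/values;
    subj_k, subj_v: (n_subject, d) keys/values from the q-former subject tokens;
    special: index of the <person> token in the prompt."""
    d = q.shape[-1]

    # Inner (nested) attention: a query-dependent subject value for every generation query.
    v_star = F.softmax(q @ subj_k.T / d**0.5, dim=-1) @ subj_v        # (n_queries, d)

    # Rescale to the norm of the special token's textual value.
    v_star = F.normalize(v_star, dim=-1) * norm_factor * v[special].norm()

    # Outer cross-attention: boost the <person> logits by lam, zero its textual value,
    # and add the query-dependent subject values at that position instead.
    logits = q @ k.T / d**0.5                                         # (n_queries, n_text)
    logits[:, special] = torch.maximum(logits[:, special], logits[:, special] * lam)
    attn = F.softmax(logits, dim=-1)

    v_zeroed = v.clone()
    v_zeroed[special] = 0.0
    return attn @ v_zeroed + attn[:, special:special + 1] * v_star
```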
19
+ This repository provides:
20
+ - An inference notebook (`inference_notebook.ipynb`)
21
+ - A trained encoder for faces
22
+ - A Gradio-based application
23
+
24
+ ## Setup
25
+
26
+ Please download the following models (the demo app fetches all of them from the Hugging Face Hub, as sketched at the end of this section):
27
+ - https://github.com/ageitgey/face_recognition_models/blob/master/face_recognition_models/models/shape_predictor_68_face_landmarks.dat
28
+ - https://github.com/justadudewhohacks/face-recognition.js-models/blob/master/models/mmod_human_face_detector.dat
29
+ - image encoder (add link)
30
+ - trained encoder (add link)
31
+
32
+ Tested with:
33
+ - `torch==2.6.0`
34
+ - `diffusers==0.33.1`
35
+ - `transformers==4.51.2`
36
+
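All four files listed above are also hosted in the `orpatashnik/NestedAttentionEncoder` repository on the Hub, which is where `app.py` fetches them from; the downloads look like this:

```python
import os
from huggingface_hub import hf_hub_download, snapshot_download

repo = "orpatashnik/NestedAttentionEncoder"
shape_predictor_path = hf_hub_download(repo, "shape_predictor_68_face_landmarks.dat")
face_detector_path = hf_hub_download(repo, "mmod_human_face_detector.dat")
image_encoder_path = os.path.join(
    snapshot_download(repo, allow_patterns=["image_encoder/**"]), "image_encoder"
)
personalization_ckpt = hf_hub_download(repo, "personalization_encoder/pytorch_model.bin")
```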
37
+ ## Usage
38
+
39
+ Refer to the inference notebook for an example; a minimal usage sketch also follows these notes. Key usage notes:
40
+ - The input image should be aligned and cropped.
41
+ - The special token `<person>` represents the personalized subject and **must appear exactly once** in the input prompt.
42
+ - The parameter `special_token_weight` corresponds to $\lambda$ in the paper, controlling the tradeoff between identity preservation and prompt adherence. Increasing this parameter improves identity preservation.
43
+ - The code supports multiple input images of the same subject. To enable this, set `multiple_images=True` and provide a list of images. For single-image usage, pass an image directly instead of a list.
44
+
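A minimal end-to-end sketch following these notes (it mirrors `app.py`; the image path and prompt are placeholders, face alignment via `utils.align_face` is omitted, and `image_encoder_path` / `personalization_ckpt` are the paths obtained in the Setup section):

```python
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
add_special_token_to_tokenizer(pipe, "<person>", "person")

model = NestedAdapterInference(
    pipe,
    image_encoder_path,           # see Setup above
    personalization_ckpt,         # see Setup above
    1024,                         # resampler_num_queries
    vq_normalize_factor=2.0,
    device="cuda",
)

# An aligned and cropped face image of the subject (placeholder path).
image = Image.open("my_face.jpg").resize((512, 512))
placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])

images = model.generate(
    pil_image=image,                              # or a list plus multiple_images=True
    prompt="a photo of a <person> hiking in the alps",
    placeholder_token_ids=placeholder_token_ids,
    num_samples=4,
    special_token_weight=1.5,                     # lambda: identity vs. prompt tradeoff
    seed=42,
)
```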
45
+ ## Related Work
46
+
47
+ This repository builds upon [IP-Adapter](https://ip-adapter.github.io/).
48
+
49
+ ## BibTeX
50
+
51
+ ```bibtex
52
+ @inproceedings{patashnik2025nested,
53
+ author = {Patashnik, Or and Gal, Rinon and Ostashev, Daniil and Tulyakov, Sergey and Aberman, Kfir and Cohen-Or, Daniel},
54
+ title = {Nested Attention: Semantic-aware Attention Values for Concept Personalization},
55
+ year = {2025},
56
+ publisher = {Association for Computing Machinery},
57
+ url = {https://doi.org/10.1145/3721238.3730634},
58
+ booktitle = {Proceedings of the Special Interest Group on Computer Graphics and Interactive Techniques Conference Conference Papers},
59
+ articleno = {6},
60
+ numpages = {12},
61
+ series = {SIGGRAPH Conference Papers '25}
62
+ }
63
+ ```
app.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline
4
+ import gradio as gr
5
+ from huggingface_hub import hf_hub_download, snapshot_download
6
+ from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
7
+ from utils import align_face
8
+ import dlib
9
+
10
+
11
+ # ----------------------
12
+ # Configuration (update paths as needed)
13
+ # ----------------------
14
+ SHAPE_PREDICTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "shape_predictor_68_face_landmarks.dat")
15
+ FACE_DETECTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "mmod_human_face_detector.dat")
16
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
17
+ image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
18
+ image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
19
+ personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/pytorch_model.bin")
20
+ device = "cuda"
21
+
22
+ # Special token settings
23
+ placeholder_token = "<person>"
24
+ initializer_token = "person"
25
+
26
+ # ----------------------
27
+ # Load models
28
+ # ----------------------
29
+ pipe = StableDiffusionXLPipeline.from_pretrained(
30
+ base_model_path,
31
+ torch_dtype=torch.float16,
32
+ )
33
+ add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)
34
+ ip_model = NestedAdapterInference(
35
+ pipe,
36
+ image_encoder_path,
37
+ personalization_ckpt,
38
+ 1024,
39
+ vq_normalize_factor=2.0,
40
+ device=device
41
+ )
42
+
43
+ # Initialize face alignment predictor
44
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
45
+ detector = dlib.cnn_face_detection_model_v1(FACE_DETECTOR_PATH)
46
+
47
+ # Generation defaults
48
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
49
+ num_inference_steps = 30
50
+ guidance_scale = 5.0
51
+
52
+ # ----------------------
53
+ # Inference function with alignment
54
+ # ----------------------
55
+ def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
56
+ # Collect non-empty reference images
57
+ refs = [img for img in (img1, img2, img3) if img is not None]
58
+ if not refs:
59
+ return []
60
+
61
+ # Align directly on PIL
62
+ aligned_refs = [align_face(img, predictor, detector) for img in refs]
63
+
64
+ # Resize to model resolution
65
+ pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]
66
+ placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
67
+
68
+ # Generate personalized samples
69
+ results = ip_model.generate(
70
+ pil_image=pil_images,
71
+ prompt=prompt,
72
+ negative_prompt=negative_prompt,
73
+ num_samples=num_samples,
74
+ num_inference_steps=num_inference_steps,
75
+ placeholder_token_ids=placeholder_token_ids,
76
+ seed=seed if seed > 0 else None,
77
+ guidance_scale=guidance_scale,
78
+ multiple_images=True,
79
+ special_token_weight=w
80
+ )
81
+ return results
82
+
83
+ # ----------------------
84
+ # Gradio UI
85
+ # ----------------------
86
+ with gr.Blocks() as demo:
87
+ gr.Markdown("## Personalized Image Generation Demo")
88
+ gr.Markdown(
89
+ "Upload up to 3 reference images. "
90
+ "Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
91
+ "set token weight, and choose how many outputs you want."
92
+ )
93
+ with gr.Row():
94
+ with gr.Column(scale=1):
95
+ # Reference images
96
+ with gr.Row():
97
+ img1 = gr.Image(type="pil", label="Reference Image 1")
98
+ img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
99
+ img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")
100
+ prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
101
+ w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
102
+ num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
103
+ seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for a random seed)")
104
+ generate_button = gr.Button("Generate Images")
105
+ with gr.Column(scale=1):
106
+ output_gallery = gr.Gallery(label="Generated Images", columns=3)
107
+
108
+ generate_button.click(
109
+ fn=generate_images,
110
+ inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
111
+ outputs=output_gallery
112
+ )
113
+
114
+ if __name__ == "__main__":
115
+ demo.launch(share=True, debug=True)
nested_attention_pipeline.py ADDED
@@ -0,0 +1,248 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
7
+
8
+ from nested_attention_processor import AttnProcessor, NestedAttnProcessor
9
+ from utils import get_generator
10
+
11
+ from resampler import Resampler
12
+
13
+
14
+
15
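# Register the placeholder token (e.g. "<person>") in both SDXL tokenizers and
# initialize its embedding from an existing single-token word (e.g. "person").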
+ def add_special_token_to_tokenizer(
16
+ pipe,
17
+ placeholder_token,
18
+ initializer_token
19
+ ):
20
+ num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
21
+ num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
22
+ if num_added_tokens1 != 1 or num_added_tokens2 != 1:
23
+ raise ValueError("Failed to add placeholder token to tokenizer")
24
+
25
+ token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
26
+ token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
27
+ if len(token_ids1) > 1 or len(token_ids2) > 1:
28
+ raise ValueError("The initializer token must be a single token.")
29
+ initializer_token_id1 = token_ids1[0]
30
+ initializer_token_id2 = token_ids2[0]
31
+ placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
32
+ placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])
33
+ pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
34
+ pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
35
+ token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
36
+ token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
37
+ with torch.no_grad():
38
+ for token_id in placeholder_token_ids1:
39
+ token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
40
+ for token_id in placeholder_token_ids2:
41
+ token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
42
+
43
+
44
+ class NestedAdapterInference:
45
+ def __init__(
46
+ self,
47
+ sd_pipe,
48
+ image_encoder_path,
49
+ adapter_ckpt,
50
+ resampler_num_queries,
51
+ vq_normalize_factor,
52
+ device,
53
+ ):
54
+ self.device = device
55
+ self.image_encoder_path = image_encoder_path
56
+ self.adapter_ckpt = adapter_ckpt
57
+
58
+ self.vq_normalize_factor = vq_normalize_factor
59
+
60
+ self.pipe = sd_pipe.to(self.device)
61
+ self.set_nested_adapter()
62
+
63
+ # load image encoder
64
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
65
+ self.image_encoder_path
66
+ ).to(self.device, dtype=torch.float16)
67
+ self.clip_image_processor = CLIPImageProcessor()
68
+
69
+ # spatial features model
70
+ self.qformer = Resampler(
71
+ dim=self.pipe.unet.config.cross_attention_dim,
72
+ depth=4,
73
+ dim_head=64,
74
+ heads=12,
75
+ num_queries=resampler_num_queries,
76
+ embedding_dim=self.image_encoder.config.hidden_size,
77
+ output_dim=self.pipe.unet.config.cross_attention_dim,
78
+ ff_mult=4,
79
+ ).to(self.device, dtype=torch.float16)
80
+
81
+ if adapter_ckpt is not None:
82
+ self.load_nested_adapter()
83
+
84
+ def set_nested_adapter(self):
85
+ unet = self.pipe.unet
86
+ attn_procs = {}
87
+ for name in unet.attn_processors.keys():
88
+ cross_attention_dim = (
89
+ None
90
+ if name.endswith("attn1.processor")
91
+ else unet.config.cross_attention_dim
92
+ )
93
+ if name.startswith("mid_block"):
94
+ hidden_size = unet.config.block_out_channels[-1]
95
+ elif name.startswith("up_blocks"):
96
+ block_id = int(name[len("up_blocks.")])
97
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
98
+ elif name.startswith("down_blocks"):
99
+ block_id = int(name[len("down_blocks.")])
100
+ hidden_size = unet.config.block_out_channels[block_id]
101
+ if cross_attention_dim is None:
102
+ attn_procs[name] = AttnProcessor()
103
+ else:
104
+ attn_procs[name] = NestedAttnProcessor(
105
+ hidden_size=hidden_size,
106
+ cross_attention_dim=cross_attention_dim,
107
+ normalize_factor=self.vq_normalize_factor,
108
+ ).to(self.device, dtype=torch.float16)
109
+ unet.set_attn_processor(attn_procs)
110
+
111
+ def load_nested_adapter(self):
112
+ state_dict = {"adapter_modules": {}, "qformer": {}}
113
+ f = torch.load(self.adapter_ckpt, map_location="cpu")
114
+ for key in f.keys():
115
+ if key.startswith("adapter_modules."):
116
+ state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[
117
+ key
118
+ ]
119
+ elif key.startswith("spatial_features_model."):
120
+ state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[
121
+ key
122
+ ]
123
+ self.qformer.load_state_dict(state_dict["qformer"])
124
+ adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
125
+ adapter_layers.load_state_dict(state_dict["adapter_modules"])
126
+
127
+ @torch.inference_mode()
128
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
129
+ if isinstance(pil_image, Image.Image):
130
+ pil_image = [pil_image]
131
+ clip_image = self.clip_image_processor(
132
+ images=pil_image, return_tensors="pt"
133
+ ).pixel_values
134
+ clip_image_embeds = self.image_encoder(
135
+ clip_image.to(self.device, dtype=torch.float16)
136
+ )
137
+ spatial_clip_image_embeds = clip_image_embeds.last_hidden_state
138
+ spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:] # remove CLS token
139
+ return spatial_clip_image_embeds
140
+
141
+ def generate(
142
+ self,
143
+ pil_image=None,
144
+ clip_image_embeds=None,
145
+ prompt=None,
146
+ placeholder_token_ids=None,
147
+ negative_prompt=None,
148
+ scale=1.0,
149
+ num_samples=4,
150
+ seed=None,
151
+ guidance_scale=5.0,
152
+ num_inference_steps=30,
153
+ multiple_images=False,
154
+ special_token_weight=1.0,
155
+ **kwargs,
156
+ ):
157
+ if pil_image is not None:
158
+ num_prompts = (
159
+ 1
160
+ if isinstance(pil_image, Image.Image) or multiple_images
161
+ else len(pil_image)
162
+ )
163
+ else:
164
+ num_prompts = clip_image_embeds.size(0)
165
+
166
+ if prompt is None:
167
+ prompt = "best quality, high quality"
168
+ if negative_prompt is None:
169
+ negative_prompt = (
170
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
171
+ )
172
+
173
+ if not isinstance(prompt, List):
174
+ prompt = [prompt] * num_prompts
175
+ if not isinstance(negative_prompt, List):
176
+ negative_prompt = [negative_prompt] * num_prompts
177
+
178
+ text_input_ids = self.pipe.tokenizer(
179
+ prompt,
180
+ max_length=self.pipe.tokenizer.model_max_length,
181
+ padding="max_length",
182
+ truncation=True,
183
+ return_tensors="pt",
184
+ ).input_ids
185
+ special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[
186
+ :, 1
187
+ ]
188
+
189
+ spatial_clip_image_embeds = self.get_image_embeds(
190
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds
191
+ ) # (bs, 256, 1280)
192
+
193
+ with torch.no_grad():
194
+ (
195
+ prompt_embeds,
196
+ negative_prompt_embeds,
197
+ pooled_prompt_embeds,
198
+ negative_pooled_prompt_embeds,
199
+ ) = self.pipe.encode_prompt(
200
+ prompt,
201
+ num_images_per_prompt=num_samples,
202
+ do_classifier_free_guidance=True,
203
+ negative_prompt=negative_prompt,
204
+ )
205
+
210
+ with torch.no_grad():
211
+ qformer_tokens_out = self.qformer(spatial_clip_image_embeds)
212
+
213
+ if multiple_images:
214
+ b, num_tokens, d = qformer_tokens_out.shape
215
+ qformer_tokens_out = qformer_tokens_out.reshape(
216
+ 1, num_tokens * b, d
217
+ )
218
+
219
+ bs_embed, num_tokens, _ = qformer_tokens_out.shape
220
+
221
+ qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
222
+ qformer_tokens_out = qformer_tokens_out.view(
223
+ bs_embed * num_samples, num_tokens, -1
224
+ )
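# Duplicate the subject tokens so they cover the doubled (unconditional + conditional)
# classifier-free-guidance batch used by the pipeline call below.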
225
+ qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)
226
+
227
+ cross_attention_kwargs = {
228
+ "qformer_tokens_out": qformer_tokens_out,
229
+ "special_token_indices": special_token_indices,
230
+ "special_token_weight": special_token_weight,
231
+ "inference_mode": True,
232
+ }
233
+
234
+ generator = get_generator(seed, self.device)
235
+
236
+ images = self.pipe(
237
+ prompt_embeds=prompt_embeds,
238
+ negative_prompt_embeds=negative_prompt_embeds,
239
+ pooled_prompt_embeds=pooled_prompt_embeds,
240
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
241
+ guidance_scale=guidance_scale,
242
+ num_inference_steps=num_inference_steps,
243
+ generator=generator,
244
+ cross_attention_kwargs=cross_attention_kwargs,
245
+ **kwargs,
246
+ ).images
247
+
248
+ return images
nested_attention_processor.py ADDED
@@ -0,0 +1,363 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def my_scaled_dot_product_attention(
9
+ query,
10
+ key,
11
+ value,
12
+ attn_mask=None,
13
+ dropout_p=0.0,
14
+ is_causal=False,
15
+ scale=None,
16
+ special_token_weight=1.0,
17
+ special_token_indices=None,
18
+ ) -> torch.Tensor:
19
+ """
20
+ Computes the scaled dot-product attention with additional control over specific tokens.
21
+
22
+ This function is a re-implementation of the scaled dot-product attention mechanism,
23
+ designed to return both the attention map and the output of the attention operation.
24
+ It also provides additional control via a scalar that modifies the attention map
25
+ for specific tokens.
26
+ """
27
+ L, S = query.size(-2), key.size(-2)
28
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
29
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
30
+ if is_causal:
31
+ assert attn_mask is None
32
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
33
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
34
+ attn_bias = attn_bias.to(query.dtype)
35
+
36
+ if attn_mask is not None:
37
+ if attn_mask.dtype == torch.bool:
38
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
39
+ else:
40
+ attn_bias += attn_mask
41
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
42
+ attn_weight += attn_bias
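# Boost the pre-softmax logits at the placeholder-token columns by special_token_weight
# (lambda in the paper); the elementwise max leaves negative logits unchanged.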
43
+ if special_token_indices is not None and special_token_weight != 1.0:
44
+ bs = attn_weight.shape[0]
45
+ attn_weight[torch.arange(bs), :, :, special_token_indices] = torch.max(
46
+ attn_weight[torch.arange(bs), :, :, special_token_indices],
47
+ attn_weight[torch.arange(bs), :, :, special_token_indices]
48
+ * special_token_weight,
49
+ )
50
+
51
+ attn_weight = torch.softmax(attn_weight, dim=-1)
52
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
53
+ return attn_weight @ value, attn_weight
54
+
55
+
56
+ class AttnProcessor(torch.nn.Module):
57
+ r"""
58
+ Processor for implementing scaled dot-product attention.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size=None,
64
+ cross_attention_dim=None,
65
+ ):
66
+ super().__init__()
67
+ if not hasattr(F, "scaled_dot_product_attention"):
68
+ raise ImportError(
69
+ "AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
70
+ )
71
+
72
+ def __call__(
73
+ self,
74
+ attn,
75
+ hidden_states,
76
+ qformer_tokens_out=None,
77
+ special_token_indices=None,
78
+ inference_mode=None,
79
+ encoder_hidden_states=None,
80
+ attention_mask=None,
81
+ temb=None,
82
+ special_token_weight=None,
83
+ ):
84
+ residual = hidden_states
85
+
86
+ if attn.spatial_norm is not None:
87
+ hidden_states = attn.spatial_norm(hidden_states, temb)
88
+
89
+ input_ndim = hidden_states.ndim
90
+
91
+ if input_ndim == 4:
92
+ batch_size, channel, height, width = hidden_states.shape
93
+ hidden_states = hidden_states.view(
94
+ batch_size, channel, height * width
95
+ ).transpose(1, 2)
96
+
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+
103
+ if attention_mask is not None:
104
+ attention_mask = attn.prepare_attention_mask(
105
+ attention_mask, sequence_length, batch_size
106
+ )
107
+ # scaled_dot_product_attention expects attention_mask shape to be
108
+ # (batch, heads, source_length, target_length)
109
+ attention_mask = attention_mask.view(
110
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
111
+ )
112
+
113
+ if attn.group_norm is not None:
114
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
115
+ 1, 2
116
+ )
117
+
118
+ query = attn.to_q(hidden_states)
119
+
120
+ if encoder_hidden_states is None:
121
+ encoder_hidden_states = hidden_states
122
+ elif attn.norm_cross:
123
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
124
+ encoder_hidden_states
125
+ )
126
+
127
+ key = attn.to_k(encoder_hidden_states)
128
+ value = attn.to_v(encoder_hidden_states)
129
+
130
+ inner_dim = key.shape[-1]
131
+ head_dim = inner_dim // attn.heads
132
+
133
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
134
+
135
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
136
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
137
+
138
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
139
+ hidden_states = F.scaled_dot_product_attention(
140
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
141
+ )
142
+
143
+ hidden_states = hidden_states.transpose(1, 2).reshape(
144
+ batch_size, -1, attn.heads * head_dim
145
+ )
146
+ hidden_states = hidden_states.to(query.dtype)
147
+
148
+ # linear proj
149
+ hidden_states = attn.to_out[0](hidden_states)
150
+ # dropout
151
+ hidden_states = attn.to_out[1](hidden_states)
152
+
153
+ if input_ndim == 4:
154
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
155
+ batch_size, channel, height, width
156
+ )
157
+
158
+ if attn.residual_connection:
159
+ hidden_states = hidden_states + residual
160
+
161
+ hidden_states = hidden_states / attn.rescale_output_factor
162
+
163
+ return hidden_states
164
+
165
+
166
+ class NestedAttnProcessor(torch.nn.Module):
167
+ r"""
168
+ Processor implementing nested attention (adapted from the IP-Adapter attention processor, PyTorch 2.0).
169
+ """
170
+
171
+ def __init__(self, hidden_size, cross_attention_dim=None, normalize_factor=1.0):
172
+ super().__init__()
173
+
174
+ if not hasattr(F, "scaled_dot_product_attention"):
175
+ raise ImportError(
176
+ "NestedAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
177
+ )
178
+
179
+ self.hidden_size = hidden_size
180
+ self.cross_attention_dim = cross_attention_dim
181
+
182
+ self.normalize_factor = normalize_factor
183
+
184
+ self.nested_to_k = nn.Linear(
185
+ cross_attention_dim or hidden_size, hidden_size, bias=False
186
+ )
187
+ self.nested_to_v = nn.Linear(
188
+ cross_attention_dim or hidden_size, hidden_size, bias=False
189
+ )
190
+
191
+ def __call__(
192
+ self,
193
+ attn,
194
+ hidden_states,
195
+ qformer_tokens_out,
196
+ special_token_indices,
197
+ inference_mode=False,
198
+ encoder_hidden_states=None,
199
+ attention_mask=None,
200
+ temb=None,
201
+ special_token_weight=1.0,
202
+ ):
203
+ assert (
204
+ special_token_indices.shape[0] > 0
205
+ ), "special_token_indices should not be empty"
206
+
207
+ # if inference mode is set to True, the code assumes that CFG is used and the first half
208
+ # of the batch is used for the null prompt and the second half is used for the prompt
209
+
210
+ residual = hidden_states
211
+
212
+ if attn.spatial_norm is not None:
213
+ hidden_states = attn.spatial_norm(hidden_states, temb)
214
+
215
+ input_ndim = hidden_states.ndim
216
+ bs = hidden_states.shape[0]
217
+
218
+ if input_ndim == 4:
219
+ bs, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(bs, channel, height * width).transpose(
221
+ 1, 2
222
+ )
223
+
224
+ bs, sequence_length, _ = (
225
+ hidden_states.shape
226
+ if encoder_hidden_states is None
227
+ else encoder_hidden_states.shape
228
+ )
229
+
230
+ if attention_mask is not None:
231
+ attention_mask = attn.prepare_attention_mask(
232
+ attention_mask, sequence_length, bs
233
+ )
234
+ # scaled_dot_product_attention expects attention_mask shape to be
235
+ # (batch, heads, source_length, target_length)
236
+ attention_mask = attention_mask.view(
237
+ bs, attn.heads, -1, attention_mask.shape[-1]
238
+ )
239
+
240
+ if attn.group_norm is not None:
241
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
242
+ 1, 2
243
+ )
244
+
245
+ query = attn.to_q(hidden_states)
246
+
247
+ if encoder_hidden_states is None:
248
+ encoder_hidden_states = hidden_states
249
+ else:
250
+ if attn.norm_cross:
251
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
252
+ encoder_hidden_states
253
+ )
254
+
255
+ key = attn.to_k(encoder_hidden_states)
256
+ value = attn.to_v(encoder_hidden_states)
257
+
258
+ inner_dim = key.shape[-1]
259
+ head_dim = inner_dim // attn.heads
260
+
261
+ query = query.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
262
+
263
+ key = key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
264
+ value = value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
265
+
266
+ # nested attention
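# Inner attention: each generation query attends over the q-former subject tokens,
# producing a query-dependent subject value (the V_q normalized below).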
267
+ nested_key = self.nested_to_k(qformer_tokens_out)
268
+ nested_value = self.nested_to_v(qformer_tokens_out)
269
+
270
+ nested_key = nested_key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
271
+ nested_value = nested_value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
272
+
273
+ nested_hidden_states = F.scaled_dot_product_attention(
274
+ query,
275
+ nested_key,
276
+ nested_value,
277
+ attn_mask=None,
278
+ dropout_p=0.0,
279
+ is_causal=False,
280
+ )
281
+
282
+ # normalize V_q
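# Rescale each query-dependent value to normalize_factor times the norm of the special
# token's textual value, keeping the injected values on a scale comparable to the value they replace.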
283
+ textual_values_norms = torch.norm(
284
+ value[torch.arange(bs), :, special_token_indices], dim=-1
285
+ )
286
+ nested_hidden_states = (
287
+ torch.nn.functional.normalize(nested_hidden_states, p=2, dim=-1)
288
+ * self.normalize_factor
289
+ )
290
+ nested_hidden_states = (
291
+ textual_values_norms.view(bs, -1, 1, 1) * nested_hidden_states
292
+ )
293
+
294
+ # outer attention
295
+ value_without_special_tokens = value.clone()
296
+ if inference_mode:
297
+ value_without_special_tokens[bs // 2 : bs, :, special_token_indices, :] = (
298
+ 0.0
299
+ )
300
+ else:
301
+ value_without_special_tokens[
302
+ torch.arange(bs), :, special_token_indices, :
303
+ ] = 0.0
304
+ hidden_states_without_special_tokens, attn_weight = (
305
+ my_scaled_dot_product_attention(
306
+ query,
307
+ key,
308
+ value_without_special_tokens,
309
+ attn_mask=None,
310
+ dropout_p=0.0,
311
+ is_causal=False,
312
+ special_token_weight=special_token_weight,
313
+ special_token_indices=special_token_indices,
314
+ )
315
+ )
316
+
317
+ # add the special token values
318
+ if inference_mode:
319
+ special_token_attn_weight = attn_weight[
320
+ bs // 2 : bs, :, :, special_token_indices
321
+ ]
322
+ else:
323
+ special_token_attn_weight = attn_weight[
324
+ torch.arange(bs), :, :, special_token_indices
325
+ ]
326
+ if inference_mode:
327
+ special_token_weighted_values = (
328
+ special_token_attn_weight * nested_hidden_states[bs // 2 : bs]
329
+ )
330
+ else:
331
+ special_token_weighted_values = (
332
+ special_token_attn_weight.unsqueeze(-1) * nested_hidden_states
333
+ )
334
+ if inference_mode:
335
+ hidden_states = hidden_states_without_special_tokens
336
+ hidden_states[bs // 2 : bs] += special_token_weighted_values
337
+ else:
338
+ hidden_states = (
339
+ hidden_states_without_special_tokens + special_token_weighted_values
340
+ )
341
+
342
+ # arrange hidden states
343
+ hidden_states = hidden_states.transpose(1, 2).reshape(
344
+ bs, -1, attn.heads * head_dim
345
+ )
346
+ hidden_states = hidden_states.to(query.dtype)
347
+
348
+ # linear proj
349
+ hidden_states = attn.to_out[0](hidden_states)
350
+ # dropout
351
+ hidden_states = attn.to_out[1](hidden_states)
352
+
353
+ if input_ndim == 4:
354
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
355
+ bs, channel, height, width
356
+ )
357
+
358
+ if attn.residual_connection:
359
+ hidden_states = hidden_states + residual
360
+
361
+ hidden_states = hidden_states / attn.rescale_output_factor
362
+
363
+ return hidden_states
resampler.py ADDED
@@ -0,0 +1,169 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
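# In this repository the Resampler acts as the q-former: it maps CLIP patch features of the
# subject image(s) to the per-subject tokens consumed by the nested attention layers.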
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(
73
+ -2, -1
74
+ ) # More stable with f16 than dividing afterwards
75
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
76
+ out = weight @ v
77
+
78
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
79
+
80
+ return self.to_out(out)
81
+
82
+
83
+ class Resampler(nn.Module):
84
+ def __init__(
85
+ self,
86
+ dim=1024,
87
+ depth=8,
88
+ dim_head=64,
89
+ heads=16,
90
+ num_queries=8,
91
+ embedding_dim=768,
92
+ output_dim=1024,
93
+ ff_mult=4,
94
+ max_seq_len: int = 257, # CLIP tokens + CLS token
95
+ apply_pos_emb: bool = False,
96
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
97
+ ):
98
+ super().__init__()
99
+ self.num_queries = num_queries
100
+
101
+ self.pos_emb = (
102
+ nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
103
+ )
104
+
105
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
106
+
107
+ self.proj_in = nn.Linear(embedding_dim, dim)
108
+
109
+ self.proj_out = nn.Linear(dim, output_dim)
110
+ self.norm_out = nn.LayerNorm(output_dim)
111
+
112
+ self.to_latents_from_mean_pooled_seq = (
113
+ nn.Sequential(
114
+ nn.LayerNorm(dim),
115
+ nn.Linear(dim, dim * num_latents_mean_pooled),
116
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
117
+ )
118
+ if num_latents_mean_pooled > 0
119
+ else None
120
+ )
121
+
122
+ self.layers = nn.ModuleList([])
123
+ for _ in range(depth):
124
+ self.layers.append(
125
+ nn.ModuleList(
126
+ [
127
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
128
+ FeedForward(dim=dim, mult=ff_mult),
129
+ ]
130
+ )
131
+ )
132
+
133
+ def forward(self, x):
134
+
135
+ if self.pos_emb is not None:
136
+ n, device = x.shape[1], x.device
137
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
138
+ x = x + pos_emb
139
+
140
+ latents = self.latents.repeat(x.size(0), 1, 1)
141
+
142
+ x = self.proj_in(x)
143
+
144
+ if self.to_latents_from_mean_pooled_seq:
145
+ meanpooled_seq = masked_mean(
146
+ x,
147
+ dim=1,
148
+ mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool),
149
+ )
150
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
151
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
152
+
153
+ for attn, ff in self.layers:
154
+ latents = attn(x, latents) + latents
155
+ latents = ff(latents) + latents
156
+
157
+ latents = self.proj_out(latents)
158
+ return self.norm_out(latents)
159
+
160
+
161
+ def masked_mean(t, *, dim, mask=None):
162
+ if mask is None:
163
+ return t.mean(dim=dim)
164
+
165
+ denom = mask.sum(dim=dim, keepdim=True)
166
+ mask = rearrange(mask, "b n -> b n 1")
167
+ masked_t = t.masked_fill(~mask, 0.0)
168
+
169
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
utils.py ADDED
@@ -0,0 +1,128 @@
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ import dlib
5
+ import scipy
6
+
7
+ def image_grid(imgs, rows, cols):
8
+ assert len(imgs) == rows*cols
9
+
10
+ w, h = imgs[0].size
11
+ grid = Image.new('RGB', size=(cols*w, rows*h))
12
+ grid_w, grid_h = grid.size
13
+
14
+ for i, img in enumerate(imgs):
15
+ grid.paste(img, box=(i%cols*w, i//cols*h))
16
+ return grid
17
+
18
+
19
+ def get_generator(seed, device):
20
+
21
+ if seed is not None:
22
+ if isinstance(seed, list):
23
+ generator = [
24
+ torch.Generator(device).manual_seed(seed_item) for seed_item in seed
25
+ ]
26
+ else:
27
+ generator = torch.Generator(device).manual_seed(seed)
28
+ else:
29
+ generator = None
30
+
31
+ return generator
32
+
33
+ def get_landmark_pil(pil_image, predictor, detector):
34
+ """Get 68 facial landmarks as a NumPy array of shape (68, 2)."""
35
+ img_np = np.array(pil_image.convert("RGB"))
36
+ dets = detector(img_np, 1)
37
+ if not dets:
38
+ return None
39
+ # Handle mmod or frontal detector output
40
+ det = dets[0].rect if hasattr(dets[0], 'rect') else dets[0]
41
+ shape = predictor(img_np, det)
42
+ coords = [(pt.x, pt.y) for pt in shape.parts()]
43
+ return np.array(coords)
44
+
45
+
46
+ def align_face(pil_image, predictor, detector):
47
+ """Align a face from a PIL.Image, returning an aligned PIL.Image of size 512x512."""
48
+ lm = get_landmark_pil(pil_image, predictor, detector)
49
+ if lm is None:
50
+ return pil_image
51
+ # Define landmark regions
52
+ lm_chin = lm[0: 17] # left-right
53
+ lm_eyebrow_left = lm[17: 22] # left-right
54
+ lm_eyebrow_right = lm[22: 27] # left-right
55
+ lm_nose = lm[27: 31] # top-down
56
+ lm_nostrils = lm[31: 36] # top-down
57
+ lm_eye_left = lm[36: 42] # left-clockwise
58
+ lm_eye_right = lm[42: 48] # left-clockwise
59
+ lm_mouth_outer = lm[48: 60] # left-clockwise
60
+ lm_mouth_inner = lm[60: 68] # left-clockwise
61
+
62
+ eye_left = np.mean(lm_eye_left, axis=0)
63
+ eye_right = np.mean(lm_eye_right, axis=0)
64
+ eye_avg = (eye_left + eye_right) * 0.5
65
+ eye_to_eye = eye_right - eye_left
66
+ mouth_left = lm_mouth_outer[0]
67
+ mouth_right = lm_mouth_outer[6]
68
+ mouth_avg = (mouth_left + mouth_right) * 0.5
69
+ eye_to_mouth = mouth_avg - eye_avg
70
+
71
+ # Compute oriented crop
72
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
73
+ x /= np.hypot(*x)
74
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
75
+ y = np.flipud(x) * [-1, 1]
76
+ c = eye_avg + eye_to_mouth * 0.1
77
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
78
+ qsize = np.hypot(*x) * 2
79
+
80
+ # Prepare image
81
+ img = pil_image.convert("RGB")
82
+ transform_size = 512
83
+ output_size = 512
84
+ enable_padding = True
85
+
86
+ # Shrink image for speed
87
+ shrink = int(np.floor(qsize / output_size * 0.5))
88
+ if shrink > 1:
89
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
90
+ img = img.resize(rsize, Image.Resampling.LANCZOS)
91
+ quad /= shrink
92
+ qsize /= shrink
93
+
94
+ # Crop around face
95
+ border = max(int(np.rint(qsize * 0.1)), 3)
96
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
97
+ int(np.ceil(max(quad[:, 1]))))
98
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
99
+ min(crop[3] + border, img.size[1]))
100
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
101
+ img = img.crop(crop)
102
+ quad -= crop[0:2]
103
+
104
+ # Pad
105
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
106
+ int(np.ceil(max(quad[:, 1]))))
107
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
108
+ max(pad[3] - img.size[1] + border, 0))
109
+ if enable_padding and max(pad) > border - 4:
110
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
111
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
112
+ h, w, _ = img.shape
113
+ y, x, _ = np.ogrid[:h, :w, :1]
114
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
115
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
116
+ blur = qsize * 0.02
117
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
118
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
119
+ img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
120
+ quad += pad[:2]
121
+
122
+ # Transform image
123
+ img = img.transform((transform_size, transform_size), Image.Transform.QUAD, (quad + 0.5).flatten(), Image.Resampling.BILINEAR)
124
+ if output_size < transform_size:
125
+ img = img.resize((output_size, output_size), Image.Resampling.LANCZOS)
126
+
127
+ # Resize to final output
128
+ return img