Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
b197ccc
1
Parent(s):
a1c0b29
add code
Browse files- README.md +63 -12
- app.py +115 -0
- nested_attention_pipeline.py +248 -0
- nested_attention_processor.py +363 -0
- resampler.py +169 -0
- utils.py +128 -0
README.md
CHANGED
@@ -1,12 +1,63 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Nested Attention: Semantic-aware Attention Values for Concept Personalization (SIGGRAPH 2025)
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
> **Nested Attention: Semantic-aware Attention Values for Concept Personalization**
|
6 |
+
> Or Patashnik, Rinon Gal, Daniil Ostashev, Sergey Tulyakov, Kfir Aberman, Daniel Cohen-Or
|
7 |
+
> https://arxiv.org/abs/2501.01407
|
8 |
+
>
|
9 |
+
> **Abstract:** Personalizing text-to-image models to generate images of specific subjects across diverse scenes and styles is a rapidly advancing field. Current approaches often struggle to balance identity preservation with alignment to the input text prompt. Some methods rely on a single textual token to represent a subject, limiting expressiveness, while others use richer representations but disrupt the model's prior, weakening prompt alignment.
|
10 |
+
> In this work, we introduce **Nested Attention**, a novel mechanism that injects rich and expressive image representations into the model's existing cross-attention layers. Our key idea is to generate query-dependent subject values, derived from nested attention layers that learn to select relevant subject features for each region in the generated image.
|
11 |
+
> We integrate these nested layers into an encoder-based personalization method and show that they enable strong identity preservation while maintaining adherence to input text prompts. Our approach is general and can be trained across various domains. Additionally, its prior preservation allows for combining multiple personalized subjects from different domains in a single image.
|
12 |
+
|
13 |
+
## Description
|
14 |
+
|
15 |
+
Official implementation of **Nested Attention**, an encoder-based method for text-to-image personalization using a novel nested attention mechanism.
|
16 |
+
|
17 |
+
The implementation of the nested attention mechanism can be found in `nested_attention_processor.py`.
|
18 |
+
|
19 |
+
This repository provides:
|
20 |
+
- An inference notebook (`inference_notebook.ipynb`)
|
21 |
+
- A trained encoder for faces
|
22 |
+
- A Gradio-based application
|
23 |
+
|
24 |
+
## Setup
|
25 |
+
|
26 |
+
Please download the following models:
|
27 |
+
- https://github.com/ageitgey/face_recognition_models/blob/master/face_recognition_models/models/shape_predictor_68_face_landmarks.dat
|
28 |
+
- https://github.com/justadudewhohacks/face-recognition.js-models/blob/master/models/mmod_human_face_detector.dat
|
29 |
+
- image encoder (add link)
|
30 |
+
- trained encoder (add link)
|
31 |
+
|
32 |
+
Tested with:
|
33 |
+
- `torch==2.6.0`
|
34 |
+
- `diffusers==0.33.1`
|
35 |
+
- `transformers==4.51.2`
|
36 |
+
|
37 |
+
## Usage
|
38 |
+
|
39 |
+
Refer to the inference notebook for an example. Key usage notes:
|
40 |
+
- The input image should be aligned and cropped.
|
41 |
+
- The special token `<person>` represents the personalized subject and **must appear exactly once** in the input prompt.
|
42 |
+
- The parameter `special_token_weight` corresponds to $\lambda$ in the paper, controlling the tradeoff between identity preservation and prompt adherence. Increasing this parameter improves identity preservation.
|
43 |
+
- The code supports multiple input images of the same subject. To enable this, set `multiple_images=True` and provide a list of images. For single-image usage, pass an image directly instead of a list.
|
44 |
+
|
45 |
+
## Related Work
|
46 |
+
|
47 |
+
This repository builds upon [IP-Adapter](https://ip-adapter.github.io/).
|
48 |
+
|
49 |
+
## BibTeX
|
50 |
+
|
51 |
+
```bibtex
|
52 |
+
@inproceedings{patashnik2025nested,
|
53 |
+
author = {Patashnik, Or and Gal, Rinon and Ostashev, Daniil and Tulyakov, Sergey and Aberman, Kfir and Cohen-Or, Daniel},
|
54 |
+
title = {Nested Attention: Semantic-aware Attention Values for Concept Personalization},
|
55 |
+
year = {2025},
|
56 |
+
publisher = {Association for Computing Machinery},
|
57 |
+
url = {https://doi.org/10.1145/3721238.3730634},
|
58 |
+
booktitle = {Proceedings of the Special Interest Group on Computer Graphics and Interactive Techniques Conference Conference Papers},
|
59 |
+
articleno = {6},
|
60 |
+
numpages = {12},
|
61 |
+
series = {SIGGRAPH Conference Papers '25}
|
62 |
+
}
|
63 |
+
```
|
app.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from diffusers import StableDiffusionXLPipeline
|
4 |
+
import gradio as gr
|
5 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
6 |
+
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
|
7 |
+
from utils import align_face
|
8 |
+
import dlib
|
9 |
+
|
10 |
+
|
11 |
+
# ----------------------
|
12 |
+
# Configuration (update paths as needed)
|
13 |
+
# ----------------------
|
14 |
+
SHAPE_PREDICTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "shape_predictor_68_face_landmarks.dat")
|
15 |
+
FACE_DETECTOR_PATH = hf_hub_download("orpatashnik/NestedAttentionEncoder", "mmod_human_face_detector.dat")
|
16 |
+
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
17 |
+
image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
|
18 |
+
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
|
19 |
+
personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/pytorch_model.bin")
|
20 |
+
device = "cuda"
|
21 |
+
|
22 |
+
# Special token settings
|
23 |
+
placeholder_token = "<person>"
|
24 |
+
initializer_token = "person"
|
25 |
+
|
26 |
+
# ----------------------
|
27 |
+
# Load models
|
28 |
+
# ----------------------
|
29 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
30 |
+
base_model_path,
|
31 |
+
torch_dtype=torch.float16,
|
32 |
+
)
|
33 |
+
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)
|
34 |
+
ip_model = NestedAdapterInference(
|
35 |
+
pipe,
|
36 |
+
image_encoder_path,
|
37 |
+
personalization_ckpt,
|
38 |
+
1024,
|
39 |
+
vq_normalize_factor=2.0,
|
40 |
+
device=device
|
41 |
+
)
|
42 |
+
|
43 |
+
# Initialize face alignment predictor
|
44 |
+
predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
|
45 |
+
detector = dlib.cnn_face_detection_model_v1(FACE_DETECTOR_PATH)
|
46 |
+
|
47 |
+
# Generation defaults
|
48 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
49 |
+
num_inference_steps = 30
|
50 |
+
guidance_scale = 5.0
|
51 |
+
|
52 |
+
# ----------------------
|
53 |
+
# Inference function with alignment
|
54 |
+
# ----------------------
|
55 |
+
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
|
56 |
+
# Collect non-empty reference images
|
57 |
+
refs = [img for img in (img1, img2, img3) if img is not None]
|
58 |
+
if not refs:
|
59 |
+
return []
|
60 |
+
|
61 |
+
# Align directly on PIL
|
62 |
+
aligned_refs = [align_face(img, predictor, detector) for img in refs]
|
63 |
+
|
64 |
+
# Resize to model resolution
|
65 |
+
pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]
|
66 |
+
placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
|
67 |
+
|
68 |
+
# Generate personalized samples
|
69 |
+
results = ip_model.generate(
|
70 |
+
pil_image=pil_images,
|
71 |
+
prompt=prompt,
|
72 |
+
negative_prompt=negative_prompt,
|
73 |
+
num_samples=num_samples,
|
74 |
+
num_inference_steps=num_inference_steps,
|
75 |
+
placeholder_token_ids=placeholder_token_ids,
|
76 |
+
seed=seed if seed > 0 else None,
|
77 |
+
guidance_scale=guidance_scale,
|
78 |
+
multiple_images=True,
|
79 |
+
special_token_weight=w
|
80 |
+
)
|
81 |
+
return results
|
82 |
+
|
83 |
+
# ----------------------
|
84 |
+
# Gradio UI
|
85 |
+
# ----------------------
|
86 |
+
with gr.Blocks() as demo:
|
87 |
+
gr.Markdown("## Personalized Image Generation Demo")
|
88 |
+
gr.Markdown(
|
89 |
+
"Upload up to 3 reference images. "
|
90 |
+
"Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
|
91 |
+
"set token weight, and choose how many outputs you want."
|
92 |
+
)
|
93 |
+
with gr.Row():
|
94 |
+
with gr.Column(scale=1):
|
95 |
+
# Reference images
|
96 |
+
with gr.Row():
|
97 |
+
img1 = gr.Image(type="pil", label="Reference Image 1")
|
98 |
+
img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
|
99 |
+
img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")
|
100 |
+
prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
|
101 |
+
w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
|
102 |
+
num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
|
103 |
+
seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (use -1 for random and up to 100000)")
|
104 |
+
generate_button = gr.Button("Generate Images")
|
105 |
+
with gr.Column(scale=1):
|
106 |
+
output_gallery = gr.Gallery(label="Generated Images", columns=3)
|
107 |
+
|
108 |
+
generate_button.click(
|
109 |
+
fn=generate_images,
|
110 |
+
inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
|
111 |
+
outputs=output_gallery
|
112 |
+
)
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
demo.launch(share=True, debug=True)
|
nested_attention_pipeline.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
7 |
+
|
8 |
+
from nested_attention_processor import AttnProcessor, NestedAttnProcessor
|
9 |
+
from utils import get_generator
|
10 |
+
|
11 |
+
from resampler import Resampler
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def add_special_token_to_tokenizer(
|
16 |
+
pipe,
|
17 |
+
placeholder_token,
|
18 |
+
initializer_token
|
19 |
+
):
|
20 |
+
num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
|
21 |
+
num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
|
22 |
+
if num_added_tokens1 != 1 or num_added_tokens2 != 1:
|
23 |
+
raise ValueError("Failed to add placeholder token to tokenizer")
|
24 |
+
|
25 |
+
token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
|
26 |
+
token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
|
27 |
+
if len(token_ids1) > 1 or len(token_ids2) > 1:
|
28 |
+
raise ValueError("The initializer token must be a single token.")
|
29 |
+
initializer_token_id1 = token_ids1[0]
|
30 |
+
initializer_token_id2 = token_ids2[0]
|
31 |
+
placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
|
32 |
+
placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])
|
33 |
+
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
|
34 |
+
pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
|
35 |
+
token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
|
36 |
+
token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
|
37 |
+
with torch.no_grad():
|
38 |
+
for token_id in placeholder_token_ids1:
|
39 |
+
token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
|
40 |
+
for token_id in placeholder_token_ids2:
|
41 |
+
token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
|
42 |
+
|
43 |
+
|
44 |
+
class NestedAdapterInference:
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
sd_pipe,
|
48 |
+
image_encoder_path,
|
49 |
+
adapter_ckpt,
|
50 |
+
resampler_num_queries,
|
51 |
+
vq_normalize_factor,
|
52 |
+
device,
|
53 |
+
):
|
54 |
+
self.device = device
|
55 |
+
self.image_encoder_path = image_encoder_path
|
56 |
+
self.adapter_ckpt = adapter_ckpt
|
57 |
+
|
58 |
+
self.vq_normalize_factor = vq_normalize_factor
|
59 |
+
|
60 |
+
self.pipe = sd_pipe.to(self.device)
|
61 |
+
self.set_nested_adapter()
|
62 |
+
|
63 |
+
# load image encoder
|
64 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
65 |
+
self.image_encoder_path
|
66 |
+
).to(self.device, dtype=torch.float16)
|
67 |
+
self.clip_image_processor = CLIPImageProcessor()
|
68 |
+
|
69 |
+
# spatial features model
|
70 |
+
self.qformer = Resampler(
|
71 |
+
dim=self.pipe.unet.config.cross_attention_dim,
|
72 |
+
depth=4,
|
73 |
+
dim_head=64,
|
74 |
+
heads=12,
|
75 |
+
num_queries=resampler_num_queries,
|
76 |
+
embedding_dim=self.image_encoder.config.hidden_size,
|
77 |
+
output_dim=self.pipe.unet.config.cross_attention_dim,
|
78 |
+
ff_mult=4,
|
79 |
+
).to(self.device, dtype=torch.float16)
|
80 |
+
|
81 |
+
if adapter_ckpt is not None:
|
82 |
+
self.load_nested_adapter()
|
83 |
+
|
84 |
+
def set_nested_adapter(self):
|
85 |
+
unet = self.pipe.unet
|
86 |
+
attn_procs = {}
|
87 |
+
for name in unet.attn_processors.keys():
|
88 |
+
cross_attention_dim = (
|
89 |
+
None
|
90 |
+
if name.endswith("attn1.processor")
|
91 |
+
else unet.config.cross_attention_dim
|
92 |
+
)
|
93 |
+
if name.startswith("mid_block"):
|
94 |
+
hidden_size = unet.config.block_out_channels[-1]
|
95 |
+
elif name.startswith("up_blocks"):
|
96 |
+
block_id = int(name[len("up_blocks.")])
|
97 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
98 |
+
elif name.startswith("down_blocks"):
|
99 |
+
block_id = int(name[len("down_blocks.")])
|
100 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
101 |
+
if cross_attention_dim is None:
|
102 |
+
attn_procs[name] = AttnProcessor()
|
103 |
+
else:
|
104 |
+
attn_procs[name] = NestedAttnProcessor(
|
105 |
+
hidden_size=hidden_size,
|
106 |
+
cross_attention_dim=cross_attention_dim,
|
107 |
+
normalize_factor=self.vq_normalize_factor,
|
108 |
+
).to(self.device, dtype=torch.float16)
|
109 |
+
unet.set_attn_processor(attn_procs)
|
110 |
+
|
111 |
+
def load_nested_adapter(self):
|
112 |
+
state_dict = {"adapter_modules": {}, "qformer": {}}
|
113 |
+
f = torch.load(self.adapter_ckpt, map_location="cpu")
|
114 |
+
for key in f.keys():
|
115 |
+
if key.startswith("adapter_modules."):
|
116 |
+
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[
|
117 |
+
key
|
118 |
+
]
|
119 |
+
elif key.startswith("spatial_features_model."):
|
120 |
+
state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[
|
121 |
+
key
|
122 |
+
]
|
123 |
+
self.qformer.load_state_dict(state_dict["qformer"])
|
124 |
+
adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
125 |
+
adapter_layers.load_state_dict(state_dict["adapter_modules"])
|
126 |
+
|
127 |
+
@torch.inference_mode()
|
128 |
+
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
|
129 |
+
if isinstance(pil_image, Image.Image):
|
130 |
+
pil_image = [pil_image]
|
131 |
+
clip_image = self.clip_image_processor(
|
132 |
+
images=pil_image, return_tensors="pt"
|
133 |
+
).pixel_values
|
134 |
+
clip_image_embeds = self.image_encoder(
|
135 |
+
clip_image.to(self.device, dtype=torch.float16)
|
136 |
+
)
|
137 |
+
spatial_clip_image_embeds = clip_image_embeds.last_hidden_state
|
138 |
+
spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:] # remove CLS token
|
139 |
+
return spatial_clip_image_embeds
|
140 |
+
|
141 |
+
def generate(
|
142 |
+
self,
|
143 |
+
pil_image=None,
|
144 |
+
clip_image_embeds=None,
|
145 |
+
prompt=None,
|
146 |
+
placeholder_token_ids=None,
|
147 |
+
negative_prompt=None,
|
148 |
+
scale=1.0,
|
149 |
+
num_samples=4,
|
150 |
+
seed=None,
|
151 |
+
guidance_scale=5.0,
|
152 |
+
num_inference_steps=30,
|
153 |
+
multiple_images=False,
|
154 |
+
special_token_weight=1.0,
|
155 |
+
**kwargs,
|
156 |
+
):
|
157 |
+
if pil_image is not None:
|
158 |
+
num_prompts = (
|
159 |
+
1
|
160 |
+
if isinstance(pil_image, Image.Image) or multiple_images
|
161 |
+
else len(pil_image)
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
num_prompts = clip_image_embeds.size(0)
|
165 |
+
|
166 |
+
if prompt is None:
|
167 |
+
prompt = "best quality, high quality"
|
168 |
+
if negative_prompt is None:
|
169 |
+
negative_prompt = (
|
170 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
171 |
+
)
|
172 |
+
|
173 |
+
if not isinstance(prompt, List):
|
174 |
+
prompt = [prompt] * num_prompts
|
175 |
+
if not isinstance(negative_prompt, List):
|
176 |
+
negative_prompt = [negative_prompt] * num_prompts
|
177 |
+
|
178 |
+
text_input_ids = self.pipe.tokenizer(
|
179 |
+
prompt,
|
180 |
+
max_length=self.pipe.tokenizer.model_max_length,
|
181 |
+
padding="max_length",
|
182 |
+
truncation=True,
|
183 |
+
return_tensors="pt",
|
184 |
+
).input_ids
|
185 |
+
special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[
|
186 |
+
:, 1
|
187 |
+
]
|
188 |
+
|
189 |
+
spatial_clip_image_embeds = self.get_image_embeds(
|
190 |
+
pil_image=pil_image, clip_image_embeds=clip_image_embeds
|
191 |
+
) # (bs, 256, 1280)
|
192 |
+
|
193 |
+
with torch.no_grad():
|
194 |
+
(
|
195 |
+
prompt_embeds,
|
196 |
+
negative_prompt_embeds,
|
197 |
+
pooled_prompt_embeds,
|
198 |
+
negative_pooled_prompt_embeds,
|
199 |
+
) = self.pipe.encode_prompt(
|
200 |
+
prompt,
|
201 |
+
num_images_per_prompt=num_samples,
|
202 |
+
do_classifier_free_guidance=True,
|
203 |
+
negative_prompt=negative_prompt,
|
204 |
+
)
|
205 |
+
|
206 |
+
special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[
|
207 |
+
:, 1
|
208 |
+
]
|
209 |
+
|
210 |
+
with torch.no_grad():
|
211 |
+
qformer_tokens_out = self.qformer(spatial_clip_image_embeds)
|
212 |
+
|
213 |
+
if multiple_images:
|
214 |
+
b, num_tokens, d = qformer_tokens_out.shape
|
215 |
+
qformer_tokens_out = qformer_tokens_out.reshape(
|
216 |
+
1, num_tokens * b, d
|
217 |
+
)
|
218 |
+
|
219 |
+
bs_embed, num_tokens, _ = qformer_tokens_out.shape
|
220 |
+
|
221 |
+
qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
|
222 |
+
qformer_tokens_out = qformer_tokens_out.view(
|
223 |
+
bs_embed * num_samples, num_tokens, -1
|
224 |
+
)
|
225 |
+
qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)
|
226 |
+
|
227 |
+
cross_attention_kwargs = {
|
228 |
+
"qformer_tokens_out": qformer_tokens_out,
|
229 |
+
"special_token_indices": special_token_indices,
|
230 |
+
"special_token_weight": special_token_weight,
|
231 |
+
"inference_mode": True,
|
232 |
+
}
|
233 |
+
|
234 |
+
generator = get_generator(seed, self.device)
|
235 |
+
|
236 |
+
images = self.pipe(
|
237 |
+
prompt_embeds=prompt_embeds,
|
238 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
239 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
240 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
241 |
+
guidance_scale=guidance_scale,
|
242 |
+
num_inference_steps=num_inference_steps,
|
243 |
+
generator=generator,
|
244 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
245 |
+
**kwargs,
|
246 |
+
).images
|
247 |
+
|
248 |
+
return images
|
nested_attention_processor.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def my_scaled_dot_product_attention(
|
9 |
+
query,
|
10 |
+
key,
|
11 |
+
value,
|
12 |
+
attn_mask=None,
|
13 |
+
dropout_p=0.0,
|
14 |
+
is_causal=False,
|
15 |
+
scale=None,
|
16 |
+
special_token_weight=1.0,
|
17 |
+
special_token_indices=None,
|
18 |
+
) -> torch.Tensor:
|
19 |
+
"""
|
20 |
+
Computes the scaled dot-product attention with additional control over specific tokens.
|
21 |
+
|
22 |
+
This function is a re-implementation of the scaled dot-product attention mechanism,
|
23 |
+
designed to return both the attention map and the output of the attention operation.
|
24 |
+
It also provides additional control via a scalar that modifies the attention map
|
25 |
+
for specific tokens.
|
26 |
+
"""
|
27 |
+
L, S = query.size(-2), key.size(-2)
|
28 |
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
29 |
+
attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda()
|
30 |
+
if is_causal:
|
31 |
+
assert attn_mask is None
|
32 |
+
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
33 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
34 |
+
attn_bias.to(query.dtype)
|
35 |
+
|
36 |
+
if attn_mask is not None:
|
37 |
+
if attn_mask.dtype == torch.bool:
|
38 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
39 |
+
else:
|
40 |
+
attn_bias += attn_mask
|
41 |
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
42 |
+
attn_weight += attn_bias
|
43 |
+
if special_token_indices is not None and special_token_weight != 1.0:
|
44 |
+
bs = attn_weight.shape[0]
|
45 |
+
attn_weight[torch.arange(bs), :, :, special_token_indices] = torch.max(
|
46 |
+
attn_weight[torch.arange(bs), :, :, special_token_indices],
|
47 |
+
attn_weight[torch.arange(bs), :, :, special_token_indices]
|
48 |
+
* special_token_weight,
|
49 |
+
)
|
50 |
+
|
51 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
52 |
+
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
53 |
+
return attn_weight @ value, attn_weight
|
54 |
+
|
55 |
+
|
56 |
+
class AttnProcessor(torch.nn.Module):
|
57 |
+
r"""
|
58 |
+
Processor for implementing scaled dot-product attention.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
hidden_size=None,
|
64 |
+
cross_attention_dim=None,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
68 |
+
raise ImportError(
|
69 |
+
"AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
70 |
+
)
|
71 |
+
|
72 |
+
def __call__(
|
73 |
+
self,
|
74 |
+
attn,
|
75 |
+
hidden_states,
|
76 |
+
qformer_tokens_out=None,
|
77 |
+
special_token_indices=None,
|
78 |
+
inference_mode=None,
|
79 |
+
encoder_hidden_states=None,
|
80 |
+
attention_mask=None,
|
81 |
+
temb=None,
|
82 |
+
special_token_weight=None,
|
83 |
+
):
|
84 |
+
residual = hidden_states
|
85 |
+
|
86 |
+
if attn.spatial_norm is not None:
|
87 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
88 |
+
|
89 |
+
input_ndim = hidden_states.ndim
|
90 |
+
|
91 |
+
if input_ndim == 4:
|
92 |
+
batch_size, channel, height, width = hidden_states.shape
|
93 |
+
hidden_states = hidden_states.view(
|
94 |
+
batch_size, channel, height * width
|
95 |
+
).transpose(1, 2)
|
96 |
+
|
97 |
+
batch_size, sequence_length, _ = (
|
98 |
+
hidden_states.shape
|
99 |
+
if encoder_hidden_states is None
|
100 |
+
else encoder_hidden_states.shape
|
101 |
+
)
|
102 |
+
|
103 |
+
if attention_mask is not None:
|
104 |
+
attention_mask = attn.prepare_attention_mask(
|
105 |
+
attention_mask, sequence_length, batch_size
|
106 |
+
)
|
107 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
108 |
+
# (batch, heads, source_length, target_length)
|
109 |
+
attention_mask = attention_mask.view(
|
110 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
111 |
+
)
|
112 |
+
|
113 |
+
if attn.group_norm is not None:
|
114 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
115 |
+
1, 2
|
116 |
+
)
|
117 |
+
|
118 |
+
query = attn.to_q(hidden_states)
|
119 |
+
|
120 |
+
if encoder_hidden_states is None:
|
121 |
+
encoder_hidden_states = hidden_states
|
122 |
+
elif attn.norm_cross:
|
123 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
124 |
+
encoder_hidden_states
|
125 |
+
)
|
126 |
+
|
127 |
+
key = attn.to_k(encoder_hidden_states)
|
128 |
+
value = attn.to_v(encoder_hidden_states)
|
129 |
+
|
130 |
+
inner_dim = key.shape[-1]
|
131 |
+
head_dim = inner_dim // attn.heads
|
132 |
+
|
133 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
134 |
+
|
135 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
136 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
137 |
+
|
138 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
139 |
+
hidden_states = F.scaled_dot_product_attention(
|
140 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
141 |
+
)
|
142 |
+
|
143 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
144 |
+
batch_size, -1, attn.heads * head_dim
|
145 |
+
)
|
146 |
+
hidden_states = hidden_states.to(query.dtype)
|
147 |
+
|
148 |
+
# linear proj
|
149 |
+
hidden_states = attn.to_out[0](hidden_states)
|
150 |
+
# dropout
|
151 |
+
hidden_states = attn.to_out[1](hidden_states)
|
152 |
+
|
153 |
+
if input_ndim == 4:
|
154 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
155 |
+
batch_size, channel, height, width
|
156 |
+
)
|
157 |
+
|
158 |
+
if attn.residual_connection:
|
159 |
+
hidden_states = hidden_states + residual
|
160 |
+
|
161 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
162 |
+
|
163 |
+
return hidden_states
|
164 |
+
|
165 |
+
|
166 |
+
class NestedAttnProcessor(torch.nn.Module):
|
167 |
+
r"""
|
168 |
+
Nested Attention processor for IP-Adapater for PyTorch 2.0.
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, hidden_size, cross_attention_dim=None, normalize_factor=1.0):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
175 |
+
raise ImportError(
|
176 |
+
"NestedAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
177 |
+
)
|
178 |
+
|
179 |
+
self.hidden_size = hidden_size
|
180 |
+
self.cross_attention_dim = cross_attention_dim
|
181 |
+
|
182 |
+
self.normalize_factor = normalize_factor
|
183 |
+
|
184 |
+
self.nested_to_k = nn.Linear(
|
185 |
+
cross_attention_dim or hidden_size, hidden_size, bias=False
|
186 |
+
)
|
187 |
+
self.nested_to_v = nn.Linear(
|
188 |
+
cross_attention_dim or hidden_size, hidden_size, bias=False
|
189 |
+
)
|
190 |
+
|
191 |
+
def __call__(
|
192 |
+
self,
|
193 |
+
attn,
|
194 |
+
hidden_states,
|
195 |
+
qformer_tokens_out,
|
196 |
+
special_token_indices,
|
197 |
+
inference_mode=False,
|
198 |
+
encoder_hidden_states=None,
|
199 |
+
attention_mask=None,
|
200 |
+
temb=None,
|
201 |
+
special_token_weight=1.0,
|
202 |
+
):
|
203 |
+
assert (
|
204 |
+
special_token_indices.shape[0] > 0
|
205 |
+
), "special_token_indices should not be empty"
|
206 |
+
|
207 |
+
# if inference mode is set to True, the code assumes that CFG is used and the first half
|
208 |
+
# of the batch is used for the null prompt and the second half is used for the prompt
|
209 |
+
|
210 |
+
residual = hidden_states
|
211 |
+
|
212 |
+
if attn.spatial_norm is not None:
|
213 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
214 |
+
|
215 |
+
input_ndim = hidden_states.ndim
|
216 |
+
bs = hidden_states.shape[0]
|
217 |
+
|
218 |
+
if input_ndim == 4:
|
219 |
+
bs, channel, height, width = hidden_states.shape
|
220 |
+
hidden_states = hidden_states.view(bs, channel, height * width).transpose(
|
221 |
+
1, 2
|
222 |
+
)
|
223 |
+
|
224 |
+
bs, sequence_length, _ = (
|
225 |
+
hidden_states.shape
|
226 |
+
if encoder_hidden_states is None
|
227 |
+
else encoder_hidden_states.shape
|
228 |
+
)
|
229 |
+
|
230 |
+
if attention_mask is not None:
|
231 |
+
attention_mask = attn.prepare_attention_mask(
|
232 |
+
attention_mask, sequence_length, bs
|
233 |
+
)
|
234 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
235 |
+
# (batch, heads, source_length, target_length)
|
236 |
+
attention_mask = attention_mask.view(
|
237 |
+
bs, attn.heads, -1, attention_mask.shape[-1]
|
238 |
+
)
|
239 |
+
|
240 |
+
if attn.group_norm is not None:
|
241 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
242 |
+
1, 2
|
243 |
+
)
|
244 |
+
|
245 |
+
query = attn.to_q(hidden_states)
|
246 |
+
|
247 |
+
if encoder_hidden_states is None:
|
248 |
+
encoder_hidden_states = hidden_states
|
249 |
+
else:
|
250 |
+
if attn.norm_cross:
|
251 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
252 |
+
encoder_hidden_states
|
253 |
+
)
|
254 |
+
|
255 |
+
key = attn.to_k(encoder_hidden_states)
|
256 |
+
value = attn.to_v(encoder_hidden_states)
|
257 |
+
|
258 |
+
inner_dim = key.shape[-1]
|
259 |
+
head_dim = inner_dim // attn.heads
|
260 |
+
|
261 |
+
query = query.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
262 |
+
|
263 |
+
key = key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
264 |
+
value = value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
265 |
+
|
266 |
+
# nested attention
|
267 |
+
nested_key = self.nested_to_k(qformer_tokens_out)
|
268 |
+
nested_value = self.nested_to_v(qformer_tokens_out)
|
269 |
+
|
270 |
+
nested_key = nested_key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
271 |
+
nested_value = nested_value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
272 |
+
|
273 |
+
nested_hidden_states = F.scaled_dot_product_attention(
|
274 |
+
query,
|
275 |
+
nested_key,
|
276 |
+
nested_value,
|
277 |
+
attn_mask=None,
|
278 |
+
dropout_p=0.0,
|
279 |
+
is_causal=False,
|
280 |
+
)
|
281 |
+
|
282 |
+
# normalize V_q
|
283 |
+
textual_values_norms = torch.norm(
|
284 |
+
value[torch.arange(bs), :, special_token_indices], dim=-1
|
285 |
+
)
|
286 |
+
nested_hidden_states = (
|
287 |
+
torch.nn.functional.normalize(nested_hidden_states, p=2, dim=-1)
|
288 |
+
* self.normalize_factor
|
289 |
+
)
|
290 |
+
nested_hidden_states = (
|
291 |
+
textual_values_norms.view(bs, -1, 1, 1) * nested_hidden_states
|
292 |
+
)
|
293 |
+
|
294 |
+
# outer attention
|
295 |
+
value_without_special_tokens = value.clone()
|
296 |
+
if inference_mode:
|
297 |
+
value_without_special_tokens[bs // 2 : bs, :, special_token_indices, :] = (
|
298 |
+
0.0
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
value_without_special_tokens[
|
302 |
+
torch.arange(bs), :, special_token_indices, :
|
303 |
+
] = 0.0
|
304 |
+
hidden_states_without_special_tokens, attn_weight = (
|
305 |
+
my_scaled_dot_product_attention(
|
306 |
+
query,
|
307 |
+
key,
|
308 |
+
value_without_special_tokens,
|
309 |
+
attn_mask=None,
|
310 |
+
dropout_p=0.0,
|
311 |
+
is_causal=False,
|
312 |
+
special_token_weight=special_token_weight,
|
313 |
+
special_token_indices=special_token_indices,
|
314 |
+
)
|
315 |
+
)
|
316 |
+
|
317 |
+
# add the special token values
|
318 |
+
if inference_mode:
|
319 |
+
special_token_attn_weight = attn_weight[
|
320 |
+
bs // 2 : bs, :, :, special_token_indices
|
321 |
+
]
|
322 |
+
else:
|
323 |
+
special_token_attn_weight = attn_weight[
|
324 |
+
torch.arange(bs), :, :, special_token_indices
|
325 |
+
]
|
326 |
+
if inference_mode:
|
327 |
+
special_token_weighted_values = (
|
328 |
+
special_token_attn_weight * nested_hidden_states[bs // 2 : bs]
|
329 |
+
)
|
330 |
+
else:
|
331 |
+
special_token_weighted_values = (
|
332 |
+
special_token_attn_weight.unsqueeze(-1) * nested_hidden_states
|
333 |
+
)
|
334 |
+
if inference_mode:
|
335 |
+
hidden_states = hidden_states_without_special_tokens
|
336 |
+
hidden_states[bs // 2 : bs] += special_token_weighted_values
|
337 |
+
else:
|
338 |
+
hidden_states = (
|
339 |
+
hidden_states_without_special_tokens + special_token_weighted_values
|
340 |
+
)
|
341 |
+
|
342 |
+
# arrange hidden states
|
343 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
344 |
+
bs, -1, attn.heads * head_dim
|
345 |
+
)
|
346 |
+
hidden_states = hidden_states.to(query.dtype)
|
347 |
+
|
348 |
+
# linear proj
|
349 |
+
hidden_states = attn.to_out[0](hidden_states)
|
350 |
+
# dropout
|
351 |
+
hidden_states = attn.to_out[1](hidden_states)
|
352 |
+
|
353 |
+
if input_ndim == 4:
|
354 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
355 |
+
bs, channel, height, width
|
356 |
+
)
|
357 |
+
|
358 |
+
if attn.residual_connection:
|
359 |
+
hidden_states = hidden_states + residual
|
360 |
+
|
361 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
362 |
+
|
363 |
+
return hidden_states
|
resampler.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
|
11 |
+
|
12 |
+
# FFN
|
13 |
+
def FeedForward(dim, mult=4):
|
14 |
+
inner_dim = int(dim * mult)
|
15 |
+
return nn.Sequential(
|
16 |
+
nn.LayerNorm(dim),
|
17 |
+
nn.Linear(dim, inner_dim, bias=False),
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(inner_dim, dim, bias=False),
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def reshape_tensor(x, heads):
|
24 |
+
bs, length, width = x.shape
|
25 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
26 |
+
x = x.view(bs, length, heads, -1)
|
27 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
28 |
+
x = x.transpose(1, 2)
|
29 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
30 |
+
x = x.reshape(bs, heads, length, -1)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class PerceiverAttention(nn.Module):
|
35 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
36 |
+
super().__init__()
|
37 |
+
self.scale = dim_head**-0.5
|
38 |
+
self.dim_head = dim_head
|
39 |
+
self.heads = heads
|
40 |
+
inner_dim = dim_head * heads
|
41 |
+
|
42 |
+
self.norm1 = nn.LayerNorm(dim)
|
43 |
+
self.norm2 = nn.LayerNorm(dim)
|
44 |
+
|
45 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
46 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
47 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
48 |
+
|
49 |
+
def forward(self, x, latents):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): image features
|
53 |
+
shape (b, n1, D)
|
54 |
+
latent (torch.Tensor): latent features
|
55 |
+
shape (b, n2, D)
|
56 |
+
"""
|
57 |
+
x = self.norm1(x)
|
58 |
+
latents = self.norm2(latents)
|
59 |
+
|
60 |
+
b, l, _ = latents.shape
|
61 |
+
|
62 |
+
q = self.to_q(latents)
|
63 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
64 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
65 |
+
|
66 |
+
q = reshape_tensor(q, self.heads)
|
67 |
+
k = reshape_tensor(k, self.heads)
|
68 |
+
v = reshape_tensor(v, self.heads)
|
69 |
+
|
70 |
+
# attention
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
72 |
+
weight = (q * scale) @ (k * scale).transpose(
|
73 |
+
-2, -1
|
74 |
+
) # More stable with f16 than dividing afterwards
|
75 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
76 |
+
out = weight @ v
|
77 |
+
|
78 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
79 |
+
|
80 |
+
return self.to_out(out)
|
81 |
+
|
82 |
+
|
83 |
+
class Resampler(nn.Module):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
dim=1024,
|
87 |
+
depth=8,
|
88 |
+
dim_head=64,
|
89 |
+
heads=16,
|
90 |
+
num_queries=8,
|
91 |
+
embedding_dim=768,
|
92 |
+
output_dim=1024,
|
93 |
+
ff_mult=4,
|
94 |
+
max_seq_len: int = 257, # CLIP tokens + CLS token
|
95 |
+
apply_pos_emb: bool = False,
|
96 |
+
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
self.num_queries = num_queries
|
100 |
+
|
101 |
+
self.pos_emb = (
|
102 |
+
nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
103 |
+
)
|
104 |
+
|
105 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
106 |
+
|
107 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
108 |
+
|
109 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
110 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
111 |
+
|
112 |
+
self.to_latents_from_mean_pooled_seq = (
|
113 |
+
nn.Sequential(
|
114 |
+
nn.LayerNorm(dim),
|
115 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
116 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
117 |
+
)
|
118 |
+
if num_latents_mean_pooled > 0
|
119 |
+
else None
|
120 |
+
)
|
121 |
+
|
122 |
+
self.layers = nn.ModuleList([])
|
123 |
+
for _ in range(depth):
|
124 |
+
self.layers.append(
|
125 |
+
nn.ModuleList(
|
126 |
+
[
|
127 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
128 |
+
FeedForward(dim=dim, mult=ff_mult),
|
129 |
+
]
|
130 |
+
)
|
131 |
+
)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
|
135 |
+
if self.pos_emb is not None:
|
136 |
+
n, device = x.shape[1], x.device
|
137 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
138 |
+
x = x + pos_emb
|
139 |
+
|
140 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
141 |
+
|
142 |
+
x = self.proj_in(x)
|
143 |
+
|
144 |
+
if self.to_latents_from_mean_pooled_seq:
|
145 |
+
meanpooled_seq = masked_mean(
|
146 |
+
x,
|
147 |
+
dim=1,
|
148 |
+
mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool),
|
149 |
+
)
|
150 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
151 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
152 |
+
|
153 |
+
for attn, ff in self.layers:
|
154 |
+
latents = attn(x, latents) + latents
|
155 |
+
latents = ff(latents) + latents
|
156 |
+
|
157 |
+
latents = self.proj_out(latents)
|
158 |
+
return self.norm_out(latents)
|
159 |
+
|
160 |
+
|
161 |
+
def masked_mean(t, *, dim, mask=None):
|
162 |
+
if mask is None:
|
163 |
+
return t.mean(dim=dim)
|
164 |
+
|
165 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
166 |
+
mask = rearrange(mask, "b n -> b n 1")
|
167 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
168 |
+
|
169 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
utils.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import dlib
|
5 |
+
import scipy
|
6 |
+
|
7 |
+
def image_grid(imgs, rows, cols):
|
8 |
+
assert len(imgs) == rows*cols
|
9 |
+
|
10 |
+
w, h = imgs[0].size
|
11 |
+
grid = Image.new('RGB', size=(cols*w, rows*h))
|
12 |
+
grid_w, grid_h = grid.size
|
13 |
+
|
14 |
+
for i, img in enumerate(imgs):
|
15 |
+
grid.paste(img, box=(i%cols*w, i//cols*h))
|
16 |
+
return grid
|
17 |
+
|
18 |
+
|
19 |
+
def get_generator(seed, device):
|
20 |
+
|
21 |
+
if seed is not None:
|
22 |
+
if isinstance(seed, list):
|
23 |
+
generator = [
|
24 |
+
torch.Generator(device).manual_seed(seed_item) for seed_item in seed
|
25 |
+
]
|
26 |
+
else:
|
27 |
+
generator = torch.Generator(device).manual_seed(seed)
|
28 |
+
else:
|
29 |
+
generator = None
|
30 |
+
|
31 |
+
return generator
|
32 |
+
|
33 |
+
def get_landmark_pil(pil_image, predictor, detector):
|
34 |
+
"""Get 68 facial landmarks as a NumPy array of shape (68, 2)."""
|
35 |
+
img_np = np.array(pil_image.convert("RGB"))
|
36 |
+
dets = detector(img_np, 1)
|
37 |
+
if not dets:
|
38 |
+
return None
|
39 |
+
# Handle mmod or frontal detector output
|
40 |
+
det = dets[0].rect if hasattr(dets[0], 'rect') else dets[0]
|
41 |
+
shape = predictor(img_np, det)
|
42 |
+
coords = [(pt.x, pt.y) for pt in shape.parts()]
|
43 |
+
return np.array(coords)
|
44 |
+
|
45 |
+
|
46 |
+
def align_face(pil_image, predictor, detector):
|
47 |
+
"""Align a face from a PIL.Image, returning an aligned PIL.Image of size 512x512."""
|
48 |
+
lm = get_landmark_pil(pil_image, predictor, detector)
|
49 |
+
if lm is None:
|
50 |
+
return pil_image
|
51 |
+
# Define landmark regions
|
52 |
+
lm_chin = lm[0: 17] # left-right
|
53 |
+
lm_eyebrow_left = lm[17: 22] # left-right
|
54 |
+
lm_eyebrow_right = lm[22: 27] # left-right
|
55 |
+
lm_nose = lm[27: 31] # top-down
|
56 |
+
lm_nostrils = lm[31: 36] # top-down
|
57 |
+
lm_eye_left = lm[36: 42] # left-clockwise
|
58 |
+
lm_eye_right = lm[42: 48] # left-clockwise
|
59 |
+
lm_mouth_outer = lm[48: 60] # left-clockwise
|
60 |
+
lm_mouth_inner = lm[60: 68] # left-clockwise
|
61 |
+
|
62 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
63 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
64 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
65 |
+
eye_to_eye = eye_right - eye_left
|
66 |
+
mouth_left = lm_mouth_outer[0]
|
67 |
+
mouth_right = lm_mouth_outer[6]
|
68 |
+
mouth_avg = (mouth_left + mouth_right) * 0.5
|
69 |
+
eye_to_mouth = mouth_avg - eye_avg
|
70 |
+
|
71 |
+
# Compute oriented crop
|
72 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
73 |
+
x /= np.hypot(*x)
|
74 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
|
75 |
+
y = np.flipud(x) * [-1, 1]
|
76 |
+
c = eye_avg + eye_to_mouth * 0.1
|
77 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
78 |
+
qsize = np.hypot(*x) * 2
|
79 |
+
|
80 |
+
# Prepare image
|
81 |
+
img = pil_image.convert("RGB")
|
82 |
+
transform_size = 512
|
83 |
+
output_size = 512
|
84 |
+
enable_padding = True
|
85 |
+
|
86 |
+
# Shrink image for speed
|
87 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
88 |
+
if shrink > 1:
|
89 |
+
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
|
90 |
+
img = img.resize(rsize, Image.Resampling.LANCZOS)
|
91 |
+
quad /= shrink
|
92 |
+
qsize /= shrink
|
93 |
+
|
94 |
+
# Crop around face
|
95 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
96 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
97 |
+
int(np.ceil(max(quad[:, 1]))))
|
98 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
|
99 |
+
min(crop[3] + border, img.size[1]))
|
100 |
+
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
|
101 |
+
img = img.crop(crop)
|
102 |
+
quad -= crop[0:2]
|
103 |
+
|
104 |
+
# Pad
|
105 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
106 |
+
int(np.ceil(max(quad[:, 1]))))
|
107 |
+
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
|
108 |
+
max(pad[3] - img.size[1] + border, 0))
|
109 |
+
if enable_padding and max(pad) > border - 4:
|
110 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
111 |
+
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
112 |
+
h, w, _ = img.shape
|
113 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
114 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
|
115 |
+
1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
|
116 |
+
blur = qsize * 0.02
|
117 |
+
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
118 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
119 |
+
img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
|
120 |
+
quad += pad[:2]
|
121 |
+
|
122 |
+
# Transform image
|
123 |
+
img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
|
124 |
+
if output_size < transform_size:
|
125 |
+
img = img.resize((output_size, output_size), Image.Resampling.LANCZOS)
|
126 |
+
|
127 |
+
# Resize to final output
|
128 |
+
return img
|