ChenDY committed
Commit f106340 · 1 Parent(s): 044103b
.gitignore ADDED
@@ -0,0 +1,6 @@
+ .idea/
+
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
README.md CHANGED
@@ -1,12 +1,15 @@
  ---
- title: NAG FLUX.1-Kontext-Dev
- emoji: 🌖
- colorFrom: yellow
- colorTo: yellow
+ title: NAG FLUX.1 Kontext Dev
+ emoji: 🌍
+ colorFrom: blue
+ colorTo: indigo
  sdk: gradio
  sdk_version: 5.35.0
  app_file: app.py
  pinned: false
+ license: mit
+ short_description: Demo of Normalized Attention Guidance for FLUX.1-Kontext-dev
  ---

+ [[arXiv Paper]](https://arxiv.org/abs/2505.21179) [[Project Page]](https://chendaryen.github.io/NAG.github.io/)
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
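For quick reference, here is a minimal, hedged usage sketch of the NAG pipeline this Space wraps. It mirrors the setup in `app.py` below; the prompt and output filename are illustrative only.

```python
import torch

from src.pipeline_flux_kontext_nag import NAGFluxKontextPipeline
from src.transformer_flux import NAGFluxTransformer2DModel

# Same checkpoint and dtype as app.py below.
transformer = NAGFluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", subfolder="transformer", torch_dtype=torch.bfloat16,
)
pipe = NAGFluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16,
).to("cuda")

# nag_scale > 1 enables Normalized Attention Guidance; nag_negative_prompt steers it away.
image = pipe(
    prompt="A cute Godzilla wearing a pearl tiara and lace collar",  # illustrative prompt
    nag_negative_prompt="Low resolution, blurry, lack of details",
    nag_scale=5.0,
    guidance_scale=2.5,
    num_inference_steps=25,
    generator=torch.Generator(device="cuda").manual_seed(2025),
).images[0]
image.save("nag_sample.png")  # illustrative output path
```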
app.py ADDED
@@ -0,0 +1,250 @@
1
+ import random
2
+ import os
3
+
4
+ import spaces
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ import huggingface_hub
9
+ import gradio as gr
10
+
11
+ from src.pipeline_flux_kontext_nag import NAGFluxKontextPipeline
12
+ from src.transformer_flux import NAGFluxTransformer2DModel
13
+
14
+
15
+ MAX_SEED = np.iinfo(np.int32).max
16
+ MAX_IMAGE_SIZE = 2048
17
+ DEFAULT_GUIDANCE_SCALE = 2.5
18
+ DEFAULT_NEGATIVE_PROMPT = "Low resolution, blurry, lack of details"
19
+
20
+
21
+ transformer = NAGFluxTransformer2DModel.from_pretrained(
22
+ "black-forest-labs/FLUX.1-Kontext-dev",
23
+ subfolder="transformer",
24
+ torch_dtype=torch.bfloat16,
25
+ )
26
+ pipe = NAGFluxKontextPipeline.from_pretrained(
27
+ "black-forest-labs/FLUX.1-Kontext-dev",
28
+ transformer=transformer,
29
+ torch_dtype=torch.bfloat16,
30
+ )
31
+
32
+ device = "cuda"
33
+ pipe = pipe.to(device)
34
+
35
+ examples = [
36
+ ["./assets/monster.png", "Transform to 1960s pop art poster style.", "Use a bright pink, green and blue color palette.", 5],
37
+ ["./assets/rabbit.jpg", "Using this elegant style, create a portrait of a cute Godzilla wearing a pearl tiara and lace collar, maintaining the same refined quality and soft color tones.", DEFAULT_NEGATIVE_PROMPT, 5],
38
+ ]
39
+
40
+
41
+ def get_duration(
42
+ input_image,
43
+ prompt,
44
+ negative_prompt, guidance_scale,
45
+ nag_negative_prompt, nag_scale,
46
+ width, height,
47
+ num_inference_steps,
48
+ seed, randomize_seed,
49
+ compare,
50
+ ):
51
+ duration = int(num_inference_steps) * 1.5 + 5
52
+ if compare:
53
+ duration *= 1.7
54
+ return duration
55
+
56
+
57
+ @spaces.GPU(duration=get_duration)
58
+ def sample(
59
+ input_image,
60
+ prompt,
61
+ negative_prompt=None, guidance_scale=DEFAULT_GUIDANCE_SCALE,
62
+ nag_negative_prompt=None, nag_scale=5.0,
63
+ width=1024, height=1024,
64
+ num_inference_steps=25,
65
+ seed=2025, randomize_seed=False,
66
+ compare=True,
67
+ ):
68
+ prompt = prompt.strip()
69
+ negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None
70
+ guidance_scale = float(guidance_scale)
71
+ width, height = int(width), int(height)
72
+ num_inference_steps = int(num_inference_steps)
73
+
74
+ if randomize_seed:
75
+ seed = random.randint(0, MAX_SEED)
76
+ else:
77
+ seed = int(seed)
78
+
79
+ if input_image is not None:
80
+ input_image = input_image.convert("RGB")
81
+
82
+ generator = torch.Generator(device="cuda").manual_seed(seed)
83
+ if input_image is not None:
84
+ image_nag = pipe(
85
+ prompt=prompt,
86
+ image=input_image,
87
+ negative_prompt=negative_prompt,
88
+ guidance_scale=guidance_scale,
89
+ nag_negative_prompt=nag_negative_prompt,
90
+ nag_scale=nag_scale,
91
+ generator=generator,
92
+ width=input_image.size[0],
93
+ height=input_image.size[1],
94
+ num_inference_steps=num_inference_steps,
95
+ ).images[0]
96
+ else:
97
+ image_nag = pipe(
98
+ prompt=prompt,
99
+ negative_prompt=negative_prompt,
100
+ guidance_scale=guidance_scale,
101
+ nag_negative_prompt=nag_negative_prompt,
102
+ nag_scale=nag_scale,
103
+ generator=generator,
104
+ width=width,
105
+ height=height,
106
+ num_inference_steps=num_inference_steps,
107
+ ).images[0]
108
+
109
+ if compare:
110
+ generator = torch.Generator(device="cuda").manual_seed(seed)
111
+ if input_image is not None:
112
+ image_normal = pipe(
113
+ prompt=prompt,
114
+ image=input_image,
115
+ negative_prompt=negative_prompt,
116
+ guidance_scale=guidance_scale,
117
+ generator=generator,
118
+ width=input_image.size[0],
119
+ height=input_image.size[1],
120
+ num_inference_steps=num_inference_steps,
121
+ ).images[0]
122
+ else:
123
+ image_normal = pipe(
124
+ prompt=prompt,
125
+ negative_prompt=negative_prompt,
126
+ guidance_scale=guidance_scale,
127
+ generator=generator,
128
+ width=width,
129
+ height=height,
130
+ num_inference_steps=num_inference_steps,
131
+ ).images[0]
132
+ else:
133
+ image_normal = Image.new("RGB", image_nag.size, color=(0, 0, 0))
134
+
135
+ return (image_normal, image_nag), seed
136
+
137
+
138
+ def sample_example(
139
+ input_image,
140
+ prompt,
141
+ nag_negative_prompt,
142
+ nag_scale,
143
+ ):
144
+ outputs, seed = sample(
145
+ input_image=input_image,
146
+ prompt=prompt,
147
+ negative_prompt=None, guidance_scale=DEFAULT_GUIDANCE_SCALE,
148
+ nag_negative_prompt=nag_negative_prompt, nag_scale=nag_scale,
149
+ width=1024, height=1024,
150
+ num_inference_steps=25,
151
+ seed=2025, randomize_seed=False,
152
+ compare=True,
153
+ )
154
+ return outputs, DEFAULT_GUIDANCE_SCALE, 1024, 1024, 25, seed, True
155
+
156
+
157
+ css="""
158
+ #col-container {
159
+ margin: 0 auto;
160
+ max-width: 960px;
161
+ }
162
+ """
163
+
164
+
165
+ with gr.Blocks(css=css) as demo:
166
+ with gr.Column(elem_id="col-container"):
167
+ gr.Markdown('''# Normalized Attention Guidance (NAG) Flux-Kontext-Dev
168
+ NAG demos: [LTX Video Fast](https://huggingface.co/spaces/ChenDY/NAG_ltx-video-distilled), [Wan2.1-T2V-14B](https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast), [FLUX.1-dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev)
169
+
170
+ Implementation of [Normalized Attention Guidance](https://chendaryen.github.io/NAG.github.io/)
171
+
172
+ [Paper](https://arxiv.org/abs/2505.21179), [GitHub](https://github.com/ChenDarYen/Normalized-Attention-Guidance), [ComfyUI](https://github.com/ChenDarYen/ComfyUI-NAG)
173
+ ''')
174
+ with gr.Row():
175
+ with gr.Column():
176
+ input_image = gr.Image(label="Upload the image for editing", type="pil")
177
+ prompt = gr.Textbox(
178
+ label="Prompt",
179
+ max_lines=3,
180
+ placeholder="Enter your prompt",
181
+ )
182
+ nag_negative_prompt = gr.Textbox(
183
+ label="Negative Prompt for NAG",
184
+ value=DEFAULT_NEGATIVE_PROMPT,
185
+ max_lines=3,
186
+ )
187
+ nag_scale = gr.Slider(label="NAG Scale", minimum=1., maximum=20., step=0.25, value=5.)
188
+ compare = gr.Checkbox(label="Compare with baseline", info="If unchecked, only sample with NAG will be generated.", value=True)
189
+ button = gr.Button("Generate", min_width=120)
190
+ with gr.Accordion("Advanced Settings", open=False):
191
+ negative_prompt = gr.Textbox(label="Negative Prompt", value=None, visible=False)
192
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1., maximum=15., step=0.1, value=DEFAULT_GUIDANCE_SCALE)
193
+ with gr.Row():
194
+ width = gr.Slider(
195
+ label="Width",
196
+ minimum=256,
197
+ maximum=MAX_IMAGE_SIZE,
198
+ step=32,
199
+ value=1024,
200
+ )
201
+ height = gr.Slider(
202
+ label="Height",
203
+ minimum=256,
204
+ maximum=MAX_IMAGE_SIZE,
205
+ step=32,
206
+ value=1024,
207
+ )
208
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=25)
209
+ seed = gr.Slider(label="Seed", minimum=1, maximum=MAX_SEED, step=1, randomize=True)
210
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
211
+
212
+ with gr.Column():
213
+ output = gr.ImageSlider(label="Left: Baseline, Right: With NAG", interactive=False)
214
+
215
+ gr.Examples(
216
+ examples=examples,
217
+ fn=sample_example,
218
+ inputs=[
219
+ input_image,
220
+ prompt,
221
+ nag_negative_prompt,
222
+ nag_scale,
223
+ ],
224
+ outputs=[output, guidance_scale, width, height, num_inference_steps, seed, compare],
225
+ cache_examples="lazy",
226
+ )
227
+
228
+ gr.on(
229
+ triggers=[
230
+ button.click,
231
+ prompt.submit
232
+ ],
233
+ fn=sample,
234
+ inputs=[
235
+ input_image,
236
+ prompt,
237
+ negative_prompt, guidance_scale,
238
+ nag_negative_prompt, nag_scale,
239
+ width, height,
240
+ num_inference_steps,
241
+ seed, randomize_seed,
242
+ compare,
243
+ ],
244
+ outputs=[output, seed],
245
+ )
246
+
247
+
248
+ if __name__ == "__main__":
249
+ huggingface_hub.login(os.getenv('HF_TOKEN'))
250
+ demo.launch(share=True)
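As a rough sanity check of the ZeroGPU budget computed by `get_duration` above (an estimate of wall-clock seconds, not a guarantee):

```python
# Defaults from the UI: 25 inference steps, comparison enabled.
steps = 25
duration = steps * 1.5 + 5         # 42.5 s for the NAG-only pass
duration_compare = duration * 1.7  # 72.25 s when the baseline image is also generated
```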
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ accelerate
+ git+https://github.com/huggingface/diffusers.git
+ torch
+ transformers
+ sentencepiece
src/__init__.py ADDED
File without changes
src/attention_flux_nag.py ADDED
@@ -0,0 +1,176 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from diffusers.models.attention_processor import Attention
8
+ from diffusers.models.embeddings import apply_rotary_emb
9
+
10
+
11
+ class NAGFluxAttnProcessor2_0:
12
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
13
+
14
+ def __init__(
15
+ self,
16
+ nag_scale: float = 1.0,
17
+ nag_tau=2.5,
18
+ nag_alpha=0.25,
19
+ encoder_hidden_states_length: int = None,
20
+ ):
21
+ if not hasattr(F, "scaled_dot_product_attention"):
22
+ raise ImportError("NAGFluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
23
+ self.nag_scale = nag_scale
24
+ self.nag_tau = nag_tau
25
+ self.nag_alpha = nag_alpha
26
+ self.encoder_hidden_states_length = encoder_hidden_states_length
27
+
28
+ def __call__(
29
+ self,
30
+ attn: Attention,
31
+ hidden_states: torch.FloatTensor,
32
+ encoder_hidden_states: torch.FloatTensor = None,
33
+ attention_mask: Optional[torch.FloatTensor] = None,
34
+ image_rotary_emb: Optional[torch.Tensor] = None,
35
+ ) -> torch.FloatTensor:
36
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
37
+
38
+ if self.nag_scale > 1.:
39
+ if encoder_hidden_states is not None:
40
+ assert len(hidden_states) == batch_size * 0.5
41
+ apply_guidance = True
42
+ else:
43
+ apply_guidance = False
44
+
45
+ # `sample` projections.
46
+ query = attn.to_q(hidden_states)
47
+ key = attn.to_k(hidden_states)
48
+ value = attn.to_v(hidden_states)
49
+
50
+ # attention
51
+ if apply_guidance and encoder_hidden_states is not None:
52
+ query = query.tile(2, 1, 1)
53
+ key = key.tile(2, 1, 1)
54
+ value = value.tile(2, 1, 1)
55
+
56
+ inner_dim = key.shape[-1]
57
+ head_dim = inner_dim // attn.heads
58
+
59
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
60
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
61
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
62
+
63
+ if attn.norm_q is not None:
64
+ query = attn.norm_q(query)
65
+ if attn.norm_k is not None:
66
+ key = attn.norm_k(key)
67
+
68
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
69
+ if encoder_hidden_states is not None:
70
+ # `context` projections.
71
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
72
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
73
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
74
+
75
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
76
+ batch_size, -1, attn.heads, head_dim
77
+ ).transpose(1, 2)
78
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
79
+ batch_size, -1, attn.heads, head_dim
80
+ ).transpose(1, 2)
81
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
82
+ batch_size, -1, attn.heads, head_dim
83
+ ).transpose(1, 2)
84
+
85
+ if attn.norm_added_q is not None:
86
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
87
+ if attn.norm_added_k is not None:
88
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
89
+
90
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
91
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
92
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
93
+
94
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
95
+
96
+ else:
97
+ assert self.encoder_hidden_states_length is not None
98
+ encoder_hidden_states_length = self.encoder_hidden_states_length
99
+
100
+ if image_rotary_emb is not None:
101
+ query = apply_rotary_emb(query, image_rotary_emb)
102
+ key = apply_rotary_emb(key, image_rotary_emb)
103
+
104
+ if not apply_guidance:
105
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
106
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
107
+ hidden_states = hidden_states.to(query.dtype)
108
+
109
+ else:
110
+ origin_batch_size = batch_size // 2
111
+ query, query_negative = torch.chunk(query, 2, dim=0)
112
+ key, key_negative = torch.chunk(key, 2, dim=0)
113
+ value, value_negative = torch.chunk(value, 2, dim=0)
114
+
115
+ hidden_states_negative = F.scaled_dot_product_attention(query_negative, key_negative, value_negative, dropout_p=0.0, is_causal=False)
116
+ hidden_states_negative = hidden_states_negative.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim)
117
+ hidden_states_negative = hidden_states_negative.to(query.dtype)
118
+
119
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
120
+ hidden_states = hidden_states.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim)
121
+ hidden_states = hidden_states.to(query.dtype)
122
+
123
+ if encoder_hidden_states is not None:
124
+ encoder_hidden_states, hidden_states = (
125
+ hidden_states[:, : encoder_hidden_states.shape[1]],
126
+ hidden_states[:, encoder_hidden_states.shape[1] :],
127
+ )
128
+
129
+ if apply_guidance:
130
+ encoder_hidden_states_negative, hidden_states_negative = (
131
+ hidden_states_negative[:, : encoder_hidden_states.shape[1]],
132
+ hidden_states_negative[:, encoder_hidden_states.shape[1]:],
133
+ )
134
+ hidden_states_positive = hidden_states
135
+ hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
136
+ norm_positive = torch.norm(hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
137
+ norm_guidance = torch.norm(hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
138
+
139
+ scale = norm_guidance / norm_positive
140
+ hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
141
+
142
+ hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
143
+
144
+ encoder_hidden_states = torch.cat((encoder_hidden_states, encoder_hidden_states_negative), dim=0)
145
+
146
+ # linear proj
147
+ hidden_states = attn.to_out[0](hidden_states)
148
+ # dropout
149
+ hidden_states = attn.to_out[1](hidden_states)
150
+
151
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152
+
153
+ return hidden_states, encoder_hidden_states
154
+
155
+ else:
156
+ if apply_guidance:
157
+ image_hidden_states_negative = hidden_states_negative[:, encoder_hidden_states_length:]
158
+ image_hidden_states = hidden_states[:, encoder_hidden_states_length:]
159
+
160
+ image_hidden_states_positive = image_hidden_states
161
+ image_hidden_states_guidance = image_hidden_states_positive * self.nag_scale - image_hidden_states_negative * (self.nag_scale - 1)
162
+ norm_positive = torch.norm(image_hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape)
163
+ norm_guidance = torch.norm(image_hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape)
164
+
165
+ scale = norm_guidance / norm_positive
166
+ image_hidden_states_guidance = image_hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
167
+ # scale = torch.nan_to_num(scale, 10)
168
+ # image_hidden_states_guidance[scale > self.nag_tau] = image_hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau
169
+
170
+ image_hidden_states = image_hidden_states_guidance * self.nag_alpha + image_hidden_states_positive * (1 - self.nag_alpha)
171
+
172
+ hidden_states_negative[:, encoder_hidden_states_length:] = image_hidden_states
173
+ hidden_states[:, encoder_hidden_states_length:] = image_hidden_states
174
+ hidden_states = torch.cat((hidden_states, hidden_states_negative), dim=0)
175
+
176
+ return hidden_states
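Restated as equations (a paraphrase of the processor above, not a quote from the paper): with positive and NAG-negative attention outputs $z^{+}$ and $z^{-}$, scale $s$ (`nag_scale`), clamp $\tau$ (`nag_tau`), and blend $\alpha$ (`nag_alpha`), the guided output is

```latex
\begin{aligned}
\tilde{z} &= s\, z^{+} - (s - 1)\, z^{-}, \\
r &= \frac{\lVert \tilde{z} \rVert_2}{\lVert z^{+} \rVert_2}, \qquad
\hat{z} = \tilde{z}\, \frac{\min(r, \tau)}{r}, \\
z_{\text{out}} &= \alpha\, \hat{z} + (1 - \alpha)\, z^{+},
\end{aligned}
```

with the $\ell_2$ norms taken over the feature (last) dimension, matching the `dim=-1` norms in the code.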
src/normalization.py ADDED
@@ -0,0 +1,54 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
5
+
6
+
7
+ class TruncAdaLayerNorm(AdaLayerNorm):
8
+ def forward(
9
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
10
+ ) -> torch.Tensor:
11
+ batch_size = x.shape[0]
12
+ return self.forward_old(
13
+ x,
14
+ temb[:batch_size] if temb is not None else None,
15
+ )
16
+
17
+
18
+ class TruncAdaLayerNormContinuous(AdaLayerNormContinuous):
19
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
20
+ batch_size = x.shape[0]
21
+ return self.forward_old(x, conditioning_embedding[:batch_size])
22
+
23
+
24
+ class TruncAdaLayerNormZero(AdaLayerNormZero):
25
+ def forward(
26
+ self,
27
+ x: torch.Tensor,
28
+ timestep: Optional[torch.Tensor] = None,
29
+ class_labels: Optional[torch.LongTensor] = None,
30
+ hidden_dtype: Optional[torch.dtype] = None,
31
+ emb: Optional[torch.Tensor] = None,
32
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
33
+ batch_size = x.shape[0]
34
+ return self.forward_old(
35
+ x,
36
+ timestep[:batch_size] if timestep is not None else None,
37
+ class_labels[:batch_size] if class_labels is not None else None,
38
+ hidden_dtype,
39
+ emb[:batch_size] if emb is not None else None,
40
+ )
41
+
42
+
43
+ class TruncSD35AdaLayerNormZeroX(SD35AdaLayerNormZeroX):
44
+ def forward(
45
+ self,
46
+ hidden_states: torch.Tensor,
47
+ emb: Optional[torch.Tensor] = None,
48
+ ) -> Tuple[torch.Tensor, ...]:
49
+ batch_size = hidden_states.shape[0]
50
+ return self.forward_old(
51
+ hidden_states,
52
+ emb[:batch_size] if emb is not None else None,
53
+ )
54
+
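These `Trunc*` wrappers exist because NAG doubles the conditioning batch (positive plus NAG-negative prompt embeddings) while the latent batch keeps its original size, so modulation inputs such as `temb` are sliced down to `x.shape[0]`. A minimal sketch of how the pipeline swaps them in at runtime, following the `forward_old` convention used in `pipeline_flux_kontext_nag.py` (the helper name here is illustrative):

```python
import types

from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous

from src.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous


def patch_norm_layers(transformer):
    """Replace AdaLayerNorm forwards with the truncating variants, keeping the originals as forward_old."""
    for sub_mod in transformer.modules():
        if not hasattr(sub_mod, "forward_old"):
            sub_mod.forward_old = sub_mod.forward
        if isinstance(sub_mod, AdaLayerNormZero):
            sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod)
        elif isinstance(sub_mod, AdaLayerNormContinuous):
            sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod)
```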
src/pipeline_flux_kontext_nag.py ADDED
@@ -0,0 +1,521 @@
1
+ import math
2
+ from typing import Union, List, Optional, Dict, Any, Callable
3
+ import types
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from diffusers import FluxKontextPipeline
9
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.utils import is_torch_xla_available, logging
12
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
13
+ from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
14
+ from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous
15
+
16
+ from src.attention_flux_nag import NAGFluxAttnProcessor2_0
17
+ from src.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous
18
+
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ class NAGFluxKontextPipeline(FluxKontextPipeline):
31
+ @property
32
+ def do_normalized_attention_guidance(self):
33
+ return self._nag_scale > 1
34
+
35
+ def _set_nag_attn_processor(
36
+ self,
37
+ nag_scale,
38
+ encoder_hidden_states_length,
39
+ nag_tau=2.5,
40
+ nag_alpha=0.25,
41
+ ):
42
+ attn_procs = {}
43
+ for name in self.transformer.attn_processors.keys():
44
+ attn_procs[name] = NAGFluxAttnProcessor2_0(
45
+ nag_scale=nag_scale,
46
+ nag_tau=nag_tau,
47
+ nag_alpha=nag_alpha,
48
+ encoder_hidden_states_length=encoder_hidden_states_length,
49
+ )
50
+ self.transformer.set_attn_processor(attn_procs)
51
+
52
+ @torch.no_grad()
53
+ def __call__(
54
+ self,
55
+ image: Optional[PipelineImageInput] = None,
56
+ prompt: Union[str, List[str]] = None,
57
+ prompt_2: Optional[Union[str, List[str]]] = None,
58
+ negative_prompt: Union[str, List[str]] = None,
59
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
60
+ true_cfg_scale: float = 1.0,
61
+ height: Optional[int] = None,
62
+ width: Optional[int] = None,
63
+ num_inference_steps: int = 28,
64
+ sigmas: Optional[List[float]] = None,
65
+ guidance_scale: float = 3.5,
66
+ num_images_per_prompt: Optional[int] = 1,
67
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
68
+ latents: Optional[torch.FloatTensor] = None,
69
+ prompt_embeds: Optional[torch.FloatTensor] = None,
70
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
71
+ ip_adapter_image: Optional[PipelineImageInput] = None,
72
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
73
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
74
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
75
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
76
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
77
+ output_type: Optional[str] = "pil",
78
+ return_dict: bool = True,
79
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
80
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
81
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
82
+ max_sequence_length: int = 512,
83
+ max_area: int = 1024 ** 2,
84
+ _auto_resize: bool = True,
85
+
86
+ nag_scale: float = 1.0,
87
+ nag_tau: float = 2.5,
88
+ nag_alpha: float = 0.25,
89
+ nag_end: float = 0.25,
90
+ nag_negative_prompt: str = None,
91
+ nag_negative_prompt_2: str = None,
92
+ nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
93
+ nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
94
+ ):
95
+ r"""
96
+ Function invoked when calling the pipeline for generation.
97
+
98
+ Args:
99
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
100
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
101
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
102
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
103
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
104
+ latents as `image`, but if passing latents directly it is not encoded again.
105
+ prompt (`str` or `List[str]`, *optional*):
106
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
107
+ instead.
108
+ prompt_2 (`str` or `List[str]`, *optional*):
109
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
110
+ will be used instead.
111
+ negative_prompt (`str` or `List[str]`, *optional*):
112
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
113
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
114
+ not greater than `1`).
115
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
116
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
117
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
118
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
119
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
120
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
121
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
122
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
123
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
124
+ num_inference_steps (`int`, *optional*, defaults to 50):
125
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
126
+ expense of slower inference.
127
+ sigmas (`List[float]`, *optional*):
128
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
129
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
130
+ will be used.
131
+ guidance_scale (`float`, *optional*, defaults to 3.5):
132
+ Guidance scale as defined in [Classifier-Free Diffusion
133
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
134
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
135
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
136
+ the text `prompt`, usually at the expense of lower image quality.
137
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
138
+ The number of images to generate per prompt.
139
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
140
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
141
+ to make generation deterministic.
142
+ latents (`torch.FloatTensor`, *optional*):
143
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
144
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
145
+ tensor will be generated by sampling using the supplied random `generator`.
146
+ prompt_embeds (`torch.FloatTensor`, *optional*):
147
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
148
+ provided, text embeddings will be generated from `prompt` input argument.
149
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
150
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
151
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
152
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
153
+ Optional image input to work with IP Adapters.
154
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
155
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
156
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
157
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
158
+ negative_ip_adapter_image:
159
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
160
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
161
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
162
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
163
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
164
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
165
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
166
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
167
+ argument.
168
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
169
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
170
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
171
+ input argument.
172
+ output_type (`str`, *optional*, defaults to `"pil"`):
173
+ The output format of the generated image. Choose between
174
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
175
+ return_dict (`bool`, *optional*, defaults to `True`):
176
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
177
+ joint_attention_kwargs (`dict`, *optional*):
178
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
179
+ `self.processor` in
180
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
181
+ callback_on_step_end (`Callable`, *optional*):
182
+ A function that is called at the end of each denoising step during inference. The function is called
183
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
184
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
185
+ `callback_on_step_end_tensor_inputs`.
186
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
187
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
188
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
189
+ `._callback_tensor_inputs` attribute of your pipeline class.
190
+ max_sequence_length (`int` defaults to 512):
191
+ Maximum sequence length to use with the `prompt`.
192
+ max_area (`int`, defaults to `1024 ** 2`):
193
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
194
+ area while maintaining the aspect ratio.
195
+
196
+ Examples:
197
+
198
+ Returns:
199
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
200
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
201
+ images.
202
+ """
203
+
204
+ height = height or self.default_sample_size * self.vae_scale_factor
205
+ width = width or self.default_sample_size * self.vae_scale_factor
206
+
207
+ original_height, original_width = height, width
208
+ aspect_ratio = width / height
209
+ width = round((max_area * aspect_ratio) ** 0.5)
210
+ height = round((max_area / aspect_ratio) ** 0.5)
211
+
212
+ multiple_of = self.vae_scale_factor * 2
213
+ width = width // multiple_of * multiple_of
214
+ height = height // multiple_of * multiple_of
215
+
216
+ if height != original_height or width != original_width:
217
+ logger.warning(
218
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
219
+ )
220
+
221
+ # 1. Check inputs. Raise error if not correct
222
+ self.check_inputs(
223
+ prompt,
224
+ prompt_2,
225
+ height,
226
+ width,
227
+ negative_prompt=negative_prompt,
228
+ negative_prompt_2=negative_prompt_2,
229
+ prompt_embeds=prompt_embeds,
230
+ negative_prompt_embeds=negative_prompt_embeds,
231
+ pooled_prompt_embeds=pooled_prompt_embeds,
232
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
233
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
234
+ max_sequence_length=max_sequence_length,
235
+ )
236
+
237
+ self._guidance_scale = guidance_scale
238
+ self._joint_attention_kwargs = joint_attention_kwargs
239
+ self._current_timestep = None
240
+ self._interrupt = False
241
+ self._nag_scale = nag_scale
242
+
243
+ # 2. Define call parameters
244
+ if prompt is not None and isinstance(prompt, str):
245
+ batch_size = 1
246
+ elif prompt is not None and isinstance(prompt, list):
247
+ batch_size = len(prompt)
248
+ else:
249
+ batch_size = prompt_embeds.shape[0]
250
+
251
+ device = self._execution_device
252
+
253
+ lora_scale = (
254
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
255
+ )
256
+ has_neg_prompt = negative_prompt is not None or (
257
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
258
+ )
259
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
260
+ (
261
+ prompt_embeds,
262
+ pooled_prompt_embeds,
263
+ text_ids,
264
+ ) = self.encode_prompt(
265
+ prompt=prompt,
266
+ prompt_2=prompt_2,
267
+ prompt_embeds=prompt_embeds,
268
+ pooled_prompt_embeds=pooled_prompt_embeds,
269
+ device=device,
270
+ num_images_per_prompt=num_images_per_prompt,
271
+ max_sequence_length=max_sequence_length,
272
+ lora_scale=lora_scale,
273
+ )
274
+ if do_true_cfg:
275
+ (
276
+ negative_prompt_embeds,
277
+ negative_pooled_prompt_embeds,
278
+ negative_text_ids,
279
+ ) = self.encode_prompt(
280
+ prompt=negative_prompt,
281
+ prompt_2=negative_prompt_2,
282
+ prompt_embeds=negative_prompt_embeds,
283
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
284
+ device=device,
285
+ num_images_per_prompt=num_images_per_prompt,
286
+ max_sequence_length=max_sequence_length,
287
+ lora_scale=lora_scale,
288
+ )
289
+
290
+ if self.do_normalized_attention_guidance:
291
+ if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None:
292
+ if nag_negative_prompt is None:
293
+ if negative_prompt is not None:
294
+ if do_true_cfg:
295
+ nag_negative_prompt_embeds = negative_prompt_embeds
296
+ nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
297
+ else:
298
+ nag_negative_prompt = negative_prompt
299
+ nag_negative_prompt_2 = negative_prompt_2
300
+ else:
301
+ nag_negative_prompt = ""
302
+
303
+ if nag_negative_prompt is not None:
304
+ nag_negative_prompt_embeds, nag_negative_pooled_prompt_embeds = self.encode_prompt(
305
+ prompt=nag_negative_prompt,
306
+ prompt_2=nag_negative_prompt_2,
307
+ device=device,
308
+ num_images_per_prompt=num_images_per_prompt,
309
+ max_sequence_length=max_sequence_length,
310
+ lora_scale=lora_scale,
311
+ )[:2]
312
+
313
+ if self.do_normalized_attention_guidance:
314
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, nag_negative_pooled_prompt_embeds], dim=0)
315
+ prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
316
+
317
+ # 3. Preprocess image
318
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
319
+ img = image[0] if isinstance(image, list) else image
320
+ image_height, image_width = self.image_processor.get_default_height_width(img)
321
+ aspect_ratio = image_width / image_height
322
+ if _auto_resize:
323
+ # Kontext is trained on specific resolutions, using one of them is recommended
324
+ _, image_width, image_height = min(
325
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
326
+ )
327
+ image_width = image_width // multiple_of * multiple_of
328
+ image_height = image_height // multiple_of * multiple_of
329
+ image = self.image_processor.resize(image, image_height, image_width)
330
+ image = self.image_processor.preprocess(image, image_height, image_width)
331
+
332
+ # 4. Prepare latent variables
333
+ num_channels_latents = self.transformer.config.in_channels // 4
334
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
335
+ image,
336
+ batch_size * num_images_per_prompt,
337
+ num_channels_latents,
338
+ height,
339
+ width,
340
+ prompt_embeds.dtype,
341
+ device,
342
+ generator,
343
+ latents,
344
+ )
345
+ if image_ids is not None:
346
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
347
+
348
+ # 5. Prepare timesteps
349
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
350
+ image_seq_len = latents.shape[1]
351
+ mu = calculate_shift(
352
+ image_seq_len,
353
+ self.scheduler.config.get("base_image_seq_len", 256),
354
+ self.scheduler.config.get("max_image_seq_len", 4096),
355
+ self.scheduler.config.get("base_shift", 0.5),
356
+ self.scheduler.config.get("max_shift", 1.15),
357
+ )
358
+ timesteps, num_inference_steps = retrieve_timesteps(
359
+ self.scheduler,
360
+ num_inference_steps,
361
+ device,
362
+ sigmas=sigmas,
363
+ mu=mu,
364
+ )
365
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
366
+ self._num_timesteps = len(timesteps)
367
+
368
+ # handle guidance
369
+ if self.transformer.config.guidance_embeds:
370
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
371
+ guidance = guidance.expand(prompt_embeds.shape[0])
372
+ else:
373
+ guidance = None
374
+
375
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
376
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
377
+ ):
378
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
379
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
380
+
381
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
382
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
383
+ ):
384
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
385
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
386
+
387
+ if self.joint_attention_kwargs is None:
388
+ self._joint_attention_kwargs = {}
389
+
390
+ image_embeds = None
391
+ negative_image_embeds = None
392
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
393
+ image_embeds = self.prepare_ip_adapter_image_embeds(
394
+ ip_adapter_image,
395
+ ip_adapter_image_embeds,
396
+ device,
397
+ batch_size * num_images_per_prompt,
398
+ )
399
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
400
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
401
+ negative_ip_adapter_image,
402
+ negative_ip_adapter_image_embeds,
403
+ device,
404
+ batch_size * num_images_per_prompt,
405
+ )
406
+
407
+ origin_attn_procs = self.transformer.attn_processors
408
+ if self.do_normalized_attention_guidance:
409
+ self._set_nag_attn_processor(nag_scale, prompt_embeds.shape[1], nag_tau, nag_alpha)
410
+ attn_procs_recovered = False
411
+
412
+ for sub_mod in self.transformer.modules():
413
+ if not hasattr(sub_mod, "forward_old") :
414
+ sub_mod.forward_old = sub_mod.forward
415
+ if isinstance(sub_mod, AdaLayerNormZero):
416
+ sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod)
417
+ elif isinstance(sub_mod, AdaLayerNormContinuous):
418
+ sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod)
419
+
420
+
421
+ # 6. Denoising loop
422
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
423
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
424
+ self.scheduler.set_begin_index(0)
425
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
426
+ for i, t in enumerate(timesteps):
427
+ if self.interrupt:
428
+ continue
429
+
430
+ if t < (1 - nag_end) * 1000 and self.do_normalized_attention_guidance and not attn_procs_recovered:
431
+ self.transformer.set_attn_processor(origin_attn_procs)
432
+ if guidance is not None:
433
+ guidance = guidance[:len(latents)]
434
+ pooled_prompt_embeds = pooled_prompt_embeds[:len(latents)]
435
+ prompt_embeds = prompt_embeds[:len(latents)]
436
+ attn_procs_recovered = True
437
+
438
+ self._current_timestep = t
439
+ if image_embeds is not None:
440
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
441
+
442
+ latent_model_input = latents
443
+ if image_latents is not None:
444
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
445
+ timestep = t.expand(prompt_embeds.shape[0]).to(latents.dtype)
446
+
447
+ noise_pred = self.transformer(
448
+ hidden_states=latent_model_input,
449
+ timestep=timestep / 1000,
450
+ guidance=guidance,
451
+ pooled_projections=pooled_prompt_embeds,
452
+ encoder_hidden_states=prompt_embeds,
453
+ txt_ids=text_ids,
454
+ img_ids=latent_ids,
455
+ joint_attention_kwargs=self.joint_attention_kwargs,
456
+ return_dict=False,
457
+ )[0]
458
+ noise_pred = noise_pred[:, : latents.size(1)]
459
+
460
+ if do_true_cfg:
461
+ if negative_image_embeds is not None:
462
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
463
+ neg_noise_pred = self.transformer(
464
+ hidden_states=latent_model_input,
465
+ timestep=timestep / 1000,
466
+ guidance=guidance,
467
+ pooled_projections=negative_pooled_prompt_embeds,
468
+ encoder_hidden_states=negative_prompt_embeds,
469
+ txt_ids=negative_text_ids,
470
+ img_ids=latent_ids,
471
+ joint_attention_kwargs=self.joint_attention_kwargs,
472
+ return_dict=False,
473
+ )[0]
474
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
475
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
476
+
477
+ # compute the previous noisy sample x_t -> x_t-1
478
+ latents_dtype = latents.dtype
479
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
480
+
481
+ if latents.dtype != latents_dtype:
482
+ if torch.backends.mps.is_available():
483
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
484
+ latents = latents.to(latents_dtype)
485
+
486
+ if callback_on_step_end is not None:
487
+ callback_kwargs = {}
488
+ for k in callback_on_step_end_tensor_inputs:
489
+ callback_kwargs[k] = locals()[k]
490
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
491
+
492
+ latents = callback_outputs.pop("latents", latents)
493
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
494
+
495
+ # call the callback, if provided
496
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
497
+ progress_bar.update()
498
+
499
+ if XLA_AVAILABLE:
500
+ xm.mark_step()
501
+
502
+ self._current_timestep = None
503
+
504
+ if output_type == "latent":
505
+ image = latents
506
+ else:
507
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
508
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
509
+ image = self.vae.decode(latents, return_dict=False)[0]
510
+ image = self.image_processor.postprocess(image, output_type=output_type)
511
+
512
+ if self.do_normalized_attention_guidance and not attn_procs_recovered:
513
+ self.transformer.set_attn_processor(origin_attn_procs)
514
+
515
+ # Offload all models
516
+ self.maybe_free_model_hooks()
517
+
518
+ if not return_dict:
519
+ return (image,)
520
+
521
+ return FluxPipelineOutput(images=image)
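One scheduling detail from the denoising loop above: the NAG attention processors are only kept while the timestep is still high, and are swapped back to the original processors once `t` drops below `(1 - nag_end) * 1000`. A small restatement of that check (the exact number of affected steps depends on the shifted sigma schedule):

```python
# With the default nag_end = 0.25, NAG stays active while t >= 750.
nag_end = 0.25
threshold = (1 - nag_end) * 1000  # 750.0

def nag_active(t: float) -> bool:
    """True while the NAG attention processors are still installed for this timestep."""
    return t >= threshold
```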
src/transformer_flux.py ADDED
@@ -0,0 +1,187 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
7
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
8
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
9
+
10
+
11
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
12
+
13
+
14
+ class NAGFluxTransformer2DModel(FluxTransformer2DModel):
15
+ def forward(
16
+ self,
17
+ hidden_states: torch.Tensor,
18
+ encoder_hidden_states: torch.Tensor = None,
19
+ pooled_projections: torch.Tensor = None,
20
+ timestep: torch.LongTensor = None,
21
+ img_ids: torch.Tensor = None,
22
+ txt_ids: torch.Tensor = None,
23
+ guidance: torch.Tensor = None,
24
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
25
+ controlnet_block_samples=None,
26
+ controlnet_single_block_samples=None,
27
+ return_dict: bool = True,
28
+ controlnet_blocks_repeat: bool = False,
29
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
30
+ """
31
+ The [`FluxTransformer2DModel`] forward method.
32
+
33
+ Args:
34
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
35
+ Input `hidden_states`.
36
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
37
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
38
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
39
+ from the embeddings of input conditions.
40
+ timestep ( `torch.LongTensor`):
41
+ Used to indicate denoising step.
42
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
43
+ A list of tensors that if specified are added to the residuals of transformer blocks.
44
+ joint_attention_kwargs (`dict`, *optional*):
45
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
46
+ `self.processor` in
47
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
48
+ return_dict (`bool`, *optional*, defaults to `True`):
49
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
50
+ tuple.
51
+
52
+ Returns:
53
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
54
+ `tuple` where the first element is the sample tensor.
55
+ """
56
+ if joint_attention_kwargs is not None:
57
+ joint_attention_kwargs = joint_attention_kwargs.copy()
58
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
59
+ else:
60
+ lora_scale = 1.0
61
+
62
+ if USE_PEFT_BACKEND:
63
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
64
+ scale_lora_layers(self, lora_scale)
65
+ else:
66
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
67
+ logger.warning(
68
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
69
+ )
70
+
71
+ do_nag = hidden_states.shape[0] != encoder_hidden_states.shape[0]
72
+
73
+ hidden_states = self.x_embedder(hidden_states)
74
+
75
+ timestep = timestep.to(hidden_states.dtype) * 1000
76
+ if guidance is not None:
77
+ guidance = guidance.to(hidden_states.dtype) * 1000
78
+ else:
79
+ guidance = None
80
+
81
+ temb = (
82
+ self.time_text_embed(timestep, pooled_projections)
83
+ if guidance is None
84
+ else self.time_text_embed(timestep, guidance, pooled_projections)
85
+ )
86
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
87
+
88
+ if txt_ids.ndim == 3:
89
+ logger.warning(
90
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
91
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
92
+ )
93
+ txt_ids = txt_ids[0]
94
+ if img_ids.ndim == 3:
95
+ logger.warning(
96
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
97
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
98
+ )
99
+ img_ids = img_ids[0]
100
+
101
+ ids = torch.cat((txt_ids, img_ids), dim=0)
102
+ image_rotary_emb = self.pos_embed(ids)
103
+
104
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
105
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
106
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
107
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
108
+
109
+ for index_block, block in enumerate(self.transformer_blocks):
110
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
111
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
112
+ block,
113
+ hidden_states,
114
+ encoder_hidden_states,
115
+ temb,
116
+ image_rotary_emb,
117
+ )
118
+
119
+ else:
120
+ encoder_hidden_states, hidden_states = block(
121
+ hidden_states=hidden_states,
122
+ encoder_hidden_states=encoder_hidden_states,
123
+ temb=temb,
124
+ image_rotary_emb=image_rotary_emb,
125
+ joint_attention_kwargs=joint_attention_kwargs,
126
+ )
127
+
128
+ # controlnet residual
129
+ if controlnet_block_samples is not None:
130
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
131
+ interval_control = int(np.ceil(interval_control))
132
+ # For Xlabs ControlNet.
133
+ if controlnet_blocks_repeat:
134
+ hidden_states = (
135
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
136
+ )
137
+ else:
138
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
139
+
140
+ if do_nag:
141
+ hidden_states = hidden_states.tile(2, 1, 1)
142
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
143
+
144
+ for index_block, block in enumerate(self.single_transformer_blocks):
145
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
146
+ hidden_states = self._gradient_checkpointing_func(
147
+ block,
148
+ hidden_states,
149
+ temb,
150
+ image_rotary_emb,
151
+ )
152
+
153
+ else:
154
+ hidden_states = block(
155
+ hidden_states=hidden_states,
156
+ temb=temb,
157
+ image_rotary_emb=image_rotary_emb,
158
+ joint_attention_kwargs=joint_attention_kwargs,
159
+ )
160
+
161
+ # controlnet residual
162
+ if controlnet_single_block_samples is not None:
163
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
164
+ interval_control = int(np.ceil(interval_control))
165
+ controlnet_single_block_sample = controlnet_single_block_samples[index_block // interval_control]
166
+ if do_nag:
167
+ controlnet_single_block_sample = controlnet_single_block_sample.tile(2, 1, 1)
168
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
169
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] + controlnet_single_block_sample
170
+ )
171
+
172
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
173
+
174
+ if do_nag:
175
+ hidden_states = torch.chunk(hidden_states, 2, dim=0)[0]
176
+
177
+ hidden_states = self.norm_out(hidden_states, temb)
178
+ output = self.proj_out(hidden_states)
179
+
180
+ if USE_PEFT_BACKEND:
181
+ # remove `lora_scale` from each PEFT layer
182
+ unscale_lora_layers(self, lora_scale)
183
+
184
+ if not return_dict:
185
+ return (output,)
186
+
187
+ return Transformer2DModelOutput(sample=output)