alexnasa committed
Commit e8fbc15 · verified · 1 Parent(s): 7537b7c

Delete src

src/adapters/__init__.py DELETED
File without changes
src/adapters/mod_adapters.py DELETED
@@ -1,243 +0,0 @@
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Dict, List, Optional, Set, Tuple, Union
- from dataclasses import dataclass
- from inspect import isfunction
-
- import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange, repeat
-
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-
- from src.utils.data_utils import pad_to_square, pad_to_target
-
- from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPVisionModel
-
- from collections import OrderedDict
-
- class SquaredReLU(nn.Module):
-     def forward(self, x: torch.Tensor):
-         return torch.square(torch.relu(x))
-
- class AdaLayerNorm(nn.Module):
-     def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None, ln_bias=True):
-         super().__init__()
-
-         if time_embedding_dim is None:
-             time_embedding_dim = embedding_dim
-
-         self.silu = nn.SiLU()
-         self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
-         nn.init.zeros_(self.linear.weight)
-         nn.init.zeros_(self.linear.bias)
-
-         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6, bias=ln_bias)
-
-     def forward(
-         self, x: torch.Tensor, timestep_embedding: torch.Tensor
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         emb = self.linear(self.silu(timestep_embedding))
-         shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
-         x = self.norm(x) * (1 + scale) + shift
-         return x
-
- class PerceiverAttentionBlock(nn.Module):
-     def __init__(
-         self, d_model: int, n_heads: int,
-         time_embedding_dim: Optional[int] = None,
-         double_kv: Optional[bool] = True,
-     ):
-         super().__init__()
-         self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
-         self.n_heads = n_heads
-
-         self.mlp = nn.Sequential(
-             OrderedDict(
-                 [
-                     ("c_fc", nn.Linear(d_model, d_model * 4)),
-                     ("sq_relu", SquaredReLU()),
-                     ("c_proj", nn.Linear(d_model * 4, d_model)),
-                 ]
-             )
-         )
-         self.double_kv = double_kv
-         self.ln_1 = AdaLayerNorm(d_model, time_embedding_dim)
-         self.ln_2 = AdaLayerNorm(d_model, time_embedding_dim)
-         self.ln_ff = AdaLayerNorm(d_model, time_embedding_dim)
-
-     def attention(self, q: torch.Tensor, kv: torch.Tensor, attn_mask: torch.Tensor = None):
-         attn_output, attn_output_weights = self.attn(q, kv, kv, need_weights=False, key_padding_mask=attn_mask)
-         return attn_output
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         latents: torch.Tensor,
-         timestep_embedding: torch.Tensor = None,
-         attn_mask: torch.Tensor = None
-     ):
-         normed_latents = self.ln_1(latents, timestep_embedding)
-         normed_x = self.ln_2(x, timestep_embedding)
-         if self.double_kv:
-             kv = torch.cat([normed_latents, normed_x], dim=1)
-         else:
-             kv = normed_x
-         attn = self.attention(
-             q=normed_latents,
-             kv=kv,
-             attn_mask=attn_mask,
-         )
-         if attn_mask is not None:
-             query_padding_mask = attn_mask.chunk(2, -1)[0].unsqueeze(-1)  # (B, 2S) -> (B, S, 1)
-             latents = latents + attn * (~query_padding_mask).to(attn)
-         else:
-             latents = latents + attn
-         latents = latents + self.mlp(self.ln_ff(latents, timestep_embedding))
-         return latents
-
-
- class CLIPModAdapter(ModelMixin, ConfigMixin):
-     @register_to_config
-     def __init__(
-         self,
-         out_dim=3072,
-         width=1024,
-         pblock_width=512,
-         layers=6,
-         pblock_layers=1,
-         heads=8,
-         input_text_dim=4096,
-         input_image_dim=1024,
-         pblock_single_blocks=0,
-     ):
-         super().__init__()
-         self.out_dim = out_dim
-
-         self.net = TextImageResampler(
-             width=width,
-             layers=layers,
-             heads=heads,
-             input_text_dim=input_text_dim,
-             input_image_dim=input_image_dim,
-             time_embedding_dim=64,
-             output_dim=out_dim,
-         )
-         self.net2 = TextImageResampler(
-             width=pblock_width,
-             layers=pblock_layers,
-             heads=heads,
-             input_text_dim=input_text_dim,
-             input_image_dim=input_image_dim,
-             time_embedding_dim=64,
-             output_dim=out_dim*(19+pblock_single_blocks),
-         )
-
-     def enable_gradient_checkpointing(self):
-         self.gradient_checkpointing = True
-         self.net.enable_gradient_checkpointing()
-         self.net2.enable_gradient_checkpointing()
-
-
-     def forward(self, t_emb, llm_hidden_states, clip_outputs):
-         if len(llm_hidden_states.shape) > 3:
-             llm_hidden_states = llm_hidden_states[..., -1, :]
-         batch_size, seq_length = llm_hidden_states.shape[:2]
-
-         img_cls_feat = clip_outputs["image_embeds"]  # (B, 768)
-         img_last_feat = clip_outputs["last_hidden_state"]  # (B, 257, 1024)
-         img_layer_feats = clip_outputs["hidden_states"]  # [(B, 257, 1024) * 25]
-         img_second_last_feat = img_layer_feats[-2]  # (B, 257, 1024)
-
-         img_hidden_states = img_second_last_feat  # (B, 257, 1024)
-
-         x = self.net(llm_hidden_states, img_hidden_states)  # (B, S, 3072)
-         x2 = self.net2(llm_hidden_states, img_hidden_states).view(batch_size, seq_length, -1, self.out_dim)  # (B, S, N, 3072)
-         return x, x2
-
-
- class TextImageResampler(nn.Module):
-     def __init__(
-         self,
-         width: int = 768,
-         layers: int = 6,
-         heads: int = 8,
-         output_dim: int = 3072,
-         input_text_dim: int = 4096,
-         input_image_dim: int = 1024,
-         time_embedding_dim: int = 64,
-     ):
-         super().__init__()
-         self.output_dim = output_dim
-         self.input_text_dim = input_text_dim
-         self.input_image_dim = input_image_dim
-         self.time_embedding_dim = time_embedding_dim
-
-         self.text_proj_in = nn.Linear(input_text_dim, width)
-         self.image_proj_in = nn.Linear(input_image_dim, width)
-
-         self.perceiver_blocks = nn.Sequential(
-             *[
-                 PerceiverAttentionBlock(
-                     width, heads, time_embedding_dim=self.time_embedding_dim
-                 )
-                 for _ in range(layers)
-             ]
-         )
-
-         self.proj_out = nn.Sequential(
-             nn.Linear(width, output_dim), nn.LayerNorm(output_dim)
-         )
-
-         self.gradient_checkpointing = False
-
-     def enable_gradient_checkpointing(self):
-         self.gradient_checkpointing = True
-
-
-     def forward(
-         self,
-         text_hidden_states: torch.Tensor,
-         image_hidden_states: torch.Tensor,
-     ):
-         timestep_embedding = torch.zeros((text_hidden_states.shape[0], 1, self.time_embedding_dim)).to(text_hidden_states)
-
-         text_hidden_states = self.text_proj_in(text_hidden_states)
-         image_hidden_states = self.image_proj_in(image_hidden_states)
-
-         for p_block in self.perceiver_blocks:
-             if self.gradient_checkpointing:
-                 def create_custom_forward(module):
-                     def custom_forward(*inputs):
-                         return module(*inputs)
-                     return custom_forward
-
-                 text_hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(p_block),
-                     image_hidden_states,
-                     text_hidden_states,
-                     timestep_embedding
-                 )
-             else:
-                 text_hidden_states = p_block(image_hidden_states, text_hidden_states, timestep_embedding=timestep_embedding)
-
-         text_hidden_states = self.proj_out(text_hidden_states)
-
-         return text_hidden_states
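
The deleted adapter above is self-contained enough to smoke-test in isolation. The following is a hedged, hypothetical sketch (not code from this repository): it instantiates CLIPModAdapter with its default configuration and feeds it dummy tensors whose shapes follow the comments in forward(); the import path is simply the location the file had before this commit.

    import torch
    from src.adapters.mod_adapters import CLIPModAdapter  # path prior to this commit

    adapter = CLIPModAdapter(out_dim=3072, width=1024, layers=6, heads=8)

    B, S = 1, 32
    llm_hidden_states = torch.randn(B, S, 4096)   # projected by text_proj_in (input_text_dim=4096)
    clip_outputs = {
        "image_embeds": torch.randn(B, 768),       # unpacked but not used downstream
        "last_hidden_state": torch.randn(B, 257, 1024),
        "hidden_states": [torch.randn(B, 257, 1024) for _ in range(25)],  # [-2] is consumed
    }

    x, x2 = adapter(None, llm_hidden_states, clip_outputs)  # t_emb is not used by forward()
    print(x.shape, x2.shape)  # (B, S, 3072) and (B, S, 19, 3072) with pblock_single_blocks=0
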
src/environmentdata.py DELETED
@@ -1,55 +0,0 @@
-
- from huggingface_hub import snapshot_download
- import os
-
- # FLUX.1-dev
- snapshot_download(
-     repo_id="black-forest-labs/FLUX.1-dev",
-     local_dir="./FLUX.1-dev",
-     local_dir_use_symlinks=False
- )
-
- # Florence-2-large
- snapshot_download(
-     repo_id="microsoft/Florence-2-large",
-     local_dir="./Florence-2-large",
-     local_dir_use_symlinks=False
- )
-
- # CLIP ViT Large
- snapshot_download(
-     repo_id="openai/clip-vit-large-patch14",
-     local_dir="./clip-vit-large-patch14",
-     local_dir_use_symlinks=False
- )
-
- # DINO ViT-s16
- snapshot_download(
-     repo_id="facebook/dino-vits16",
-     local_dir="./dino-vits16",
-     local_dir_use_symlinks=False
- )
-
- # mPLUG Visual Question Answering
- snapshot_download(
-     repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
-     local_dir="./mplug_visual-question-answering_coco_large_en",
-     local_dir_use_symlinks=False
- )
-
- # XVerse
- snapshot_download(
-     repo_id="ByteDance/XVerse",
-     local_dir="./XVerse",
-     local_dir_use_symlinks=False
- )
-
-
- os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
- os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
- os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
- os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
- os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
- os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
- os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
-
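
For orientation only, a hypothetical consumer-side sketch (not from the repository) of how the variables exported at the end of this script would typically be read back; note that the script downloads into ./<name> directories while the variables point under ./checkpoints/, so the snapshots are evidently expected to be moved there or the paths adjusted.

    import os
    from transformers import CLIPVisionModelWithProjection

    # variable name and default mirror the deleted script
    clip_path = os.environ.get("CLIP_MODEL_PATH", "./checkpoints/clip-vit-large-patch14")
    clip_model = CLIPVisionModelWithProjection.from_pretrained(clip_path)
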
src/flux/block.py DELETED
@@ -1,814 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- from typing import List, Union, Optional, Tuple, Dict, Any, Callable
18
- from diffusers.models.attention_processor import Attention, F
19
- from .lora_controller import enable_lora
20
- from einops import rearrange
21
- import math
22
- from diffusers.models.embeddings import apply_rotary_emb
23
-
24
-
25
- def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tuple[torch.Tensor, torch.Tensor]:
26
- # Efficient implementation equivalent to the following:
27
- L, S = query.size(-2), key.size(-2)
28
- B = query.size(0)
29
- scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
30
- attn_bias = torch.zeros(B, 1, L, S, dtype=query.dtype, device=query.device)
31
- if is_causal:
32
- assert attn_mask is None
33
- assert False
34
- temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
35
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
36
- attn_bias.to(query.dtype)
37
-
38
- if attn_mask is not None:
39
- if attn_mask.dtype == torch.bool:
40
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
41
- else:
42
- attn_bias += attn_mask
43
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
44
- attn_weight += attn_bias.to(attn_weight.device)
45
- attn_weight = torch.softmax(attn_weight, dim=-1)
46
-
47
- return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
48
-
49
- def attn_forward(
50
- attn: Attention,
51
- hidden_states: torch.FloatTensor,
52
- encoder_hidden_states: torch.FloatTensor = None,
53
- condition_latents: torch.FloatTensor = None,
54
- text_cond_mask: Optional[torch.FloatTensor] = None,
55
- attention_mask: Optional[torch.FloatTensor] = None,
56
- image_rotary_emb: Optional[torch.Tensor] = None,
57
- cond_rotary_emb: Optional[torch.Tensor] = None,
58
- model_config: Optional[Dict[str, Any]] = {},
59
- store_attn_map: bool = False,
60
- latent_height: Optional[int] = None,
61
- timestep: Optional[torch.Tensor] = None,
62
- last_attn_map: Optional[torch.Tensor] = None,
63
- condition_sblora_weight: Optional[float] = None,
64
- latent_sblora_weight: Optional[float] = None,
65
- ) -> torch.FloatTensor:
66
- batch_size, _, _ = (
67
- hidden_states.shape
68
- if encoder_hidden_states is None
69
- else encoder_hidden_states.shape
70
- )
71
-
72
- is_sblock = encoder_hidden_states is None
73
- is_dblock = not is_sblock
74
-
75
- with enable_lora(
76
- (attn.to_q, attn.to_k, attn.to_v),
77
- (is_dblock and model_config["latent_lora"]) or (is_sblock and model_config["sblock_lora"]), latent_sblora_weight=latent_sblora_weight
78
- ):
79
- query = attn.to_q(hidden_states)
80
- key = attn.to_k(hidden_states)
81
- value = attn.to_v(hidden_states)
82
-
83
- inner_dim = key.shape[-1]
84
- head_dim = inner_dim // attn.heads
85
-
86
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
87
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
88
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89
-
90
- if attn.norm_q is not None:
91
- query = attn.norm_q(query)
92
- if attn.norm_k is not None:
93
- key = attn.norm_k(key)
94
-
95
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
96
- if encoder_hidden_states is not None:
97
- # `context` projections.
98
- with enable_lora((attn.add_q_proj, attn.add_k_proj, attn.add_v_proj), model_config["text_lora"]):
99
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
100
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
101
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
102
-
103
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
104
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
105
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
106
-
107
- if attn.norm_added_q is not None:
108
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
109
- if attn.norm_added_k is not None:
110
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
111
-
112
- # attention
113
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
114
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
115
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
116
-
117
- if image_rotary_emb is not None:
118
- query = apply_rotary_emb(query, image_rotary_emb)
119
- key = apply_rotary_emb(key, image_rotary_emb)
120
-
121
- if condition_latents is not None:
122
- assert condition_latents.shape[0] == batch_size
123
- cond_length = condition_latents.shape[1]
124
-
125
- cond_lora_activate = (is_dblock and model_config["use_condition_dblock_lora"]) or (is_sblock and model_config["use_condition_sblock_lora"])
126
- with enable_lora(
127
- (attn.to_q, attn.to_k, attn.to_v),
128
- dit_activated=not cond_lora_activate, cond_activated=cond_lora_activate, latent_sblora_weight=condition_sblora_weight #TODO implementation for condition lora not share
129
- ):
130
- cond_query = attn.to_q(condition_latents)
131
- cond_key = attn.to_k(condition_latents)
132
- cond_value = attn.to_v(condition_latents)
133
-
134
- cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
135
- cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
136
- cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
137
- if attn.norm_q is not None:
138
- cond_query = attn.norm_q(cond_query)
139
- if attn.norm_k is not None:
140
- cond_key = attn.norm_k(cond_key)
141
-
142
- if cond_rotary_emb is not None:
143
- cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
144
- cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
145
-
146
- if model_config.get("text_cond_attn", False):
147
- if encoder_hidden_states is not None:
148
- assert text_cond_mask is not None
149
- img_length = hidden_states.shape[1]
150
- seq_length = encoder_hidden_states_query_proj.shape[2]
151
- assert len(text_cond_mask.shape) == 2 or len(text_cond_mask.shape) == 3
152
- if len(text_cond_mask.shape) == 2:
153
- text_cond_mask = text_cond_mask.unsqueeze(-1)
154
- N = text_cond_mask.shape[-1] # num_condition
155
- else:
156
- raise NotImplementedError()
157
-
158
- query = torch.cat([query, cond_query], dim=2) # (B, 24, S+HW+NC)
159
- key = torch.cat([key, cond_key], dim=2)
160
- value = torch.cat([value, cond_value], dim=2)
161
-
162
- assert query.shape[2] == key.shape[2]
163
- assert query.shape[2] == cond_length + img_length + seq_length
164
-
165
- attention_mask = torch.ones(batch_size, 1, query.shape[2], key.shape[2], device=query.device, dtype=torch.bool)
166
- attention_mask[..., -cond_length:, :-cond_length] = False
167
- attention_mask[..., :-cond_length, -cond_length:] = False
168
-
169
- if encoder_hidden_states is not None:
170
- tokens_per_cond = cond_length // N
171
- for i in range(batch_size):
172
- for j in range(N):
173
- start = seq_length + img_length + tokens_per_cond * j
174
- attention_mask[i, 0, :seq_length, start:start+tokens_per_cond] = text_cond_mask[i, :, j].unsqueeze(-1)
175
-
176
- elif model_config.get("union_cond_attn", False):
177
- query = torch.cat([query, cond_query], dim=2) # (B, 24, S+HW+NC)
178
- key = torch.cat([key, cond_key], dim=2)
179
- value = torch.cat([value, cond_value], dim=2)
180
-
181
- attention_mask = torch.ones(batch_size, 1, query.shape[2], key.shape[2], device=query.device, dtype=torch.bool)
182
- cond_length = condition_latents.shape[1]
183
- assert len(text_cond_mask.shape) == 2 or len(text_cond_mask.shape) == 3
184
- if len(text_cond_mask.shape) == 2:
185
- text_cond_mask = text_cond_mask.unsqueeze(-1)
186
- N = text_cond_mask.shape[-1] # num_condition
187
- tokens_per_cond = cond_length // N
188
-
189
- seq_length = 0
190
- if encoder_hidden_states is not None:
191
- seq_length = encoder_hidden_states_query_proj.shape[2]
192
- img_length = hidden_states.shape[1]
193
- else:
194
- seq_length = 128 # TODO, pass it here
195
- img_length = hidden_states.shape[1] - seq_length
196
-
197
- if not model_config.get("cond_cond_cross_attn", True):
198
- # no cross attention between different conds
199
- cond_start = seq_length + img_length
200
- attention_mask[:, :, cond_start:, cond_start:] = False
201
-
202
- for j in range(N):
203
- start = cond_start + tokens_per_cond * j
204
- end = cond_start + tokens_per_cond * (j + 1)
205
- attention_mask[..., start:end, start:end] = True
206
-
207
- # double block
208
- if encoder_hidden_states is not None:
209
-
210
- # no cross attention
211
- attention_mask[..., :-cond_length, -cond_length:] = False
212
-
213
- if model_config.get("use_attention_double", False) and last_attn_map is not None:
214
- attention_mask = torch.zeros(batch_size, 1, query.shape[2], key.shape[2], device=query.device, dtype=torch.bfloat16)
215
- last_attn_map = last_attn_map.to(query.device)
216
- attention_mask[..., seq_length:-cond_length, :seq_length] = torch.log(last_attn_map/last_attn_map.mean()*model_config["use_atten_lambda"]).view(-1, seq_length)
217
-
218
- # single block
219
- else:
220
- # print(last_attn_map)
221
- if model_config.get("use_attention_single", False) and last_attn_map is not None:
222
- attention_mask = torch.zeros(batch_size, 1, query.shape[2], key.shape[2], device=query.device, dtype=torch.bfloat16)
223
- attention_mask[..., :seq_length, -cond_length:] = float('-inf')
224
- # make sure use_atten_lambda is a list (broadcast a single value across all N+1 slots)
225
- use_atten_lambdas = model_config["use_atten_lambda"] if len(model_config["use_atten_lambda"])!=1 else model_config["use_atten_lambda"] * (N+1)
226
- attention_mask[..., -cond_length:, seq_length:-cond_length] = math.log(use_atten_lambdas[0])
227
- last_attn_map = last_attn_map.to(query.device)
228
-
229
- cond2latents = []
230
- for i in range(batch_size):
231
- AM = last_attn_map[i] # (H, W, S)
232
- for j in range(N):
233
- start = seq_length + img_length + tokens_per_cond * j
234
- mask = text_cond_mask[i, :, j] # (S,)
235
- weighted_AM = AM * mask.unsqueeze(0).unsqueeze(0) # expand mask dims to match AM
236
-
237
- cond2latent = weighted_AM.mean(-1)
238
- if model_config.get("attention_norm", "mean") == "max":
239
- cond2latent = cond2latent / cond2latent.max() # normalize by its max
240
- else:
241
- cond2latent = cond2latent / cond2latent.mean() # normalize by its mean
242
- cond2latent = cond2latent.view(-1,) # (WH,)
243
-
244
- # use the lambda value of the corresponding condition
245
- current_lambda = use_atten_lambdas[j+1]
246
- # write cond2latent into attention_mask[i, 0, seq_length:-cond_length, start:start+tokens_per_cond]
247
- attention_mask[i, 0, seq_length:-cond_length, start:start+tokens_per_cond] = torch.log(current_lambda * cond2latent.unsqueeze(-1))
248
-
249
- # where text_cond_mask[i, :, j] is True, set those positions to the current lambda value
250
- cond = mask.unsqueeze(-1).expand(-1, tokens_per_cond)
251
- sub_mask = attention_mask[i, 0, :seq_length, start:start+tokens_per_cond]
252
- attention_mask[i, 0, :seq_length, start:start+tokens_per_cond] = torch.where(cond, math.log(current_lambda), sub_mask)
253
- cond2latents.append(
254
- cond2latent.reshape(latent_height, -1).detach().cpu()
255
- )
256
- if store_attn_map:
257
- if not hasattr(attn, "cond2latents"):
258
- attn.cond2latents = []
259
- attn.cond_timesteps = []
260
- attn.cond2latents.append(torch.stack(cond2latents, dim=0)) # (N, H, W)
261
- attn.cond_timesteps.append(timestep.cpu())
262
-
263
- pass
264
- else:
265
- raise NotImplementedError()
266
- if hasattr(attn, "c_factor"):
267
- assert False
268
- attention_mask = torch.zeros(
269
- query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
270
- )
271
- bias = torch.log(attn.c_factor[0])
272
- attention_mask[-cond_length:, :-cond_length] = bias
273
- attention_mask[:-cond_length, -cond_length:] = bias
274
-
275
- ####################################################################################################
276
- if store_attn_map and encoder_hidden_states is not None:
277
- seq_length = encoder_hidden_states_query_proj.shape[2]
278
- img_length = hidden_states.shape[1]
279
- hidden_states, attention_probs = scaled_dot_product_attention(
280
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
281
- )
282
- # (B, 24, S+HW, S+HW) -> (B, 24, HW, S)
283
- t2i_attention_probs = attention_probs[:, :, seq_length:seq_length+img_length, :seq_length]
284
- # (B, 24, S+HW, S+HW) -> (B, 24, S, HW) -> (B, 24, HW, S)
285
- i2t_attention_probs = attention_probs[:, :, :seq_length, seq_length:seq_length+img_length].transpose(-1, -2)
286
-
287
- if not hasattr(attn, "attn_maps"):
288
- attn.attn_maps = []
289
- attn.timestep = []
290
-
291
- attn.attn_maps.append(
292
- (
293
- rearrange(t2i_attention_probs, 'B attn_head (H W) attn_dim -> B attn_head H W attn_dim', H=latent_height),
294
- rearrange(i2t_attention_probs, 'B attn_head (H W) attn_dim -> B attn_head H W attn_dim', H=latent_height),
295
- )
296
- )
297
-
298
- attn.timestep.append(timestep.cpu())
299
- has_nan = torch.isnan(hidden_states).any().item()
300
- if has_nan:
301
- print("[attn_forward] detect nan hidden_states in store_attn_map")
302
- else:
303
- hidden_states = F.scaled_dot_product_attention(
304
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
305
- )
306
- has_nan = torch.isnan(hidden_states).any().item()
307
- if has_nan:
308
- print("[attn_forward] detect nan hidden_states")
309
- ####################################################################################################
310
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
311
-
312
- if encoder_hidden_states is not None:
313
- if condition_latents is not None:
314
- encoder_hidden_states, hidden_states, condition_latents = (
315
- hidden_states[:, : encoder_hidden_states.shape[1]],
316
- hidden_states[
317
- :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
318
- ],
319
- hidden_states[:, -condition_latents.shape[1] :],
320
- )
321
- if model_config.get("latent_cond_by_text_attn", False):
322
- # hidden_states += add_latent # (B, HW, D)
323
- hidden_states = new_hidden_states # (B, HW, D)
324
-
325
- else:
326
- encoder_hidden_states, hidden_states = (
327
- hidden_states[:, : encoder_hidden_states.shape[1]],
328
- hidden_states[:, encoder_hidden_states.shape[1] :],
329
- )
330
-
331
-
332
- with enable_lora((attn.to_out[0],), model_config["latent_lora"]):
333
- hidden_states = attn.to_out[0](hidden_states) # linear proj
334
- hidden_states = attn.to_out[1](hidden_states) # dropout
335
- with enable_lora((attn.to_add_out,), model_config["text_lora"]):
336
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
337
-
338
- if condition_latents is not None:
339
- cond_lora_activate = model_config["use_condition_dblock_lora"]
340
- with enable_lora(
341
- (attn.to_out[0],),
342
- dit_activated=not cond_lora_activate, cond_activated=cond_lora_activate,
343
- ):
344
- condition_latents = attn.to_out[0](condition_latents)
345
- condition_latents = attn.to_out[1](condition_latents)
346
-
347
-
348
- return (
349
- (hidden_states, encoder_hidden_states, condition_latents)
350
- if condition_latents is not None
351
- else (hidden_states, encoder_hidden_states)
352
- )
353
- elif condition_latents is not None:
354
- hidden_states, condition_latents = (
355
- hidden_states[:, : -condition_latents.shape[1]],
356
- hidden_states[:, -condition_latents.shape[1] :],
357
- )
358
- return hidden_states, condition_latents
359
- else:
360
- return hidden_states
361
-
362
-
363
- def set_delta_by_start_end(
364
- start_ends,
365
- src_delta_emb, src_delta_emb_pblock,
366
- delta_emb, delta_emb_pblock, delta_emb_mask,
367
- ):
368
- for (i, j, src_s, src_e, tar_s, tar_e) in start_ends:
369
- if src_delta_emb is not None:
370
- delta_emb[i, tar_s:tar_e] = src_delta_emb[j, src_s:src_e]
371
- if src_delta_emb_pblock is not None:
372
- delta_emb_pblock[i, tar_s:tar_e] = src_delta_emb_pblock[j, src_s:src_e]
373
- delta_emb_mask[i, tar_s:tar_e] = True
374
- return delta_emb, delta_emb_pblock, delta_emb_mask
375
-
376
- def norm1_context_forward(
377
- self,
378
- x: torch.Tensor,
379
- condition_latents: Optional[torch.Tensor] = None,
380
- timestep: Optional[torch.Tensor] = None,
381
- class_labels: Optional[torch.LongTensor] = None,
382
- hidden_dtype: Optional[torch.dtype] = None,
383
- emb: Optional[torch.Tensor] = None,
384
- delta_emb: Optional[torch.Tensor] = None,
385
- delta_emb_cblock: Optional[torch.Tensor] = None,
386
- delta_emb_mask: Optional[torch.Tensor] = None,
387
- delta_start_ends = None,
388
- mod_adapter = None,
389
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
390
- batch_size, seq_length = x.shape[:2]
391
-
392
- if mod_adapter is not None:
393
- assert False
394
-
395
- if delta_emb is None:
396
- emb = self.linear(self.silu(emb)) # (B, 3072) -> (B, 18432)
397
- emb = emb.unsqueeze(1) # (B, 1, 18432)
398
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) # (B, 1, 3072)
399
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, 1, 3072)
400
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
401
- else:
402
- # (B, 3072) > (B, 18432) -> (B, S, 18432)
403
- emb_orig = self.linear(self.silu(emb)).unsqueeze(1).expand((-1, seq_length, -1))
404
- # (B, 3072) -> (B, 1, 3072) -> (B, S, 3072) -> (B, S, 18432)
405
- if delta_emb_cblock is None:
406
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb))
407
- else:
408
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb + delta_emb_cblock))
409
- emb = torch.where(delta_emb_mask.unsqueeze(-1), emb_new, emb_orig) # (B, S, 18432)
410
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) # (B, S, 3072)
411
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, S, 3072)
412
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
413
-
414
-
415
- def norm1_forward(
416
- self,
417
- x: torch.Tensor,
418
- timestep: Optional[torch.Tensor] = None,
419
- class_labels: Optional[torch.LongTensor] = None,
420
- hidden_dtype: Optional[torch.dtype] = None,
421
- emb: Optional[torch.Tensor] = None,
422
- delta_emb: Optional[torch.Tensor] = None,
423
- delta_emb_cblock: Optional[torch.Tensor] = None,
424
- delta_emb_mask: Optional[torch.Tensor] = None,
425
- t2i_attn_map: Optional[torch.Tensor] = None,
426
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
427
- if delta_emb is None:
428
- emb = self.linear(self.silu(emb)) # (B, 3072) -> (B, 18432)
429
- emb = emb.unsqueeze(1) # (B, 1, 18432)
430
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) # (B, 1, 3072)
431
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, 1, 3072)
432
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
433
- else:
434
- raise NotImplementedError()
435
- batch_size, HW = x.shape[:2]
436
- seq_length = t2i_attn_map.shape[-1]
437
- # (B, 3072) > (B, 18432) -> (B, S, 18432)
438
- emb_orig = self.linear(self.silu(emb)).unsqueeze(1).expand((-1, seq_length, -1))
439
- # (B, 3072) -> (B, 1, 3072) -> (B, S, 3072) -> (B, S, 18432)
440
- if delta_emb_cblock is None:
441
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb))
442
- else:
443
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb + delta_emb_cblock))
444
- # attn_weight (B, HW, S)
445
- emb = torch.where(delta_emb_mask.unsqueeze(-1), emb_new, emb_orig) # (B, S, 18432)
446
- emb = t2i_attn_map @ emb # (B, HW, 18432)
447
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) # (B, HW, 3072)
448
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, HW, 3072)
449
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
450
-
451
-
452
- def block_forward(
453
- self,
454
- hidden_states: torch.FloatTensor,
455
- encoder_hidden_states: torch.FloatTensor,
456
- condition_latents: torch.FloatTensor,
457
- temb: torch.FloatTensor,
458
- cond_temb: torch.FloatTensor,
459
- text_cond_mask: Optional[torch.FloatTensor] = None,
460
- delta_emb: Optional[torch.FloatTensor] = None,
461
- delta_emb_cblock: Optional[torch.FloatTensor] = None,
462
- delta_emb_mask: Optional[torch.Tensor] = None,
463
- delta_start_ends = None,
464
- cond_rotary_emb=None,
465
- image_rotary_emb=None,
466
- model_config: Optional[Dict[str, Any]] = {},
467
- store_attn_map: bool = False,
468
- use_text_mod: bool = True,
469
- use_img_mod: bool = False,
470
- mod_adapter = None,
471
- latent_height: Optional[int] = None,
472
- timestep: Optional[torch.Tensor] = None,
473
- last_attn_map: Optional[torch.Tensor] = None,
474
- ):
475
- batch_size = hidden_states.shape[0]
476
- use_cond = condition_latents is not None
477
-
478
- train_partial_latent_lora = model_config.get("train_partial_latent_lora", False)
479
- train_partial_text_lora = model_config.get("train_partial_text_lora", False)
480
- if train_partial_latent_lora:
481
- train_partial_latent_lora_layers = model_config.get("train_partial_latent_lora_layers", "")
482
- activate_norm1 = activate_ff = True
483
- if "norm1" not in train_partial_latent_lora_layers:
484
- activate_norm1 = False
485
- if "ff" not in train_partial_latent_lora_layers:
486
- activate_ff = False
487
-
488
- if train_partial_text_lora:
489
- train_partial_text_lora_layers = model_config.get("train_partial_text_lora_layers", "")
490
- activate_norm1_context = activate_ff_context = True
491
- if "norm1" not in train_partial_text_lora_layers:
492
- activate_norm1_context = False
493
- if "ff" not in train_partial_text_lora_layers:
494
- activate_ff_context = False
495
-
496
- if use_cond:
497
- cond_lora_activate = model_config["use_condition_dblock_lora"]
498
- with enable_lora(
499
- (self.norm1.linear,),
500
- dit_activated=activate_norm1 if train_partial_latent_lora else not cond_lora_activate, cond_activated=cond_lora_activate,
501
- ):
502
- norm_condition_latents, cond_gate_msa, cond_shift_mlp, cond_scale_mlp, cond_gate_mlp = (
503
- norm1_forward(
504
- self.norm1,
505
- condition_latents,
506
- emb=cond_temb,
507
- )
508
- )
509
- delta_emb_img = delta_emb_img_cblock = None
510
- if use_img_mod and use_text_mod:
511
- if delta_emb is not None:
512
- delta_emb_img, delta_emb = delta_emb.chunk(2, dim=-1)
513
- if delta_emb_cblock is not None:
514
- delta_emb_img_cblock, delta_emb_cblock = delta_emb_cblock.chunk(2, dim=-1)
515
-
516
- with enable_lora((self.norm1.linear,), activate_norm1 if train_partial_latent_lora else model_config["latent_lora"]):
517
- if use_img_mod and encoder_hidden_states is not None:
518
- with torch.no_grad():
519
- attn = self.attn
520
-
521
- norm_img = self.norm1(hidden_states, emb=temb)[0]
522
- norm_text = self.norm1_context(encoder_hidden_states, emb=temb)[0]
523
-
524
- img_query = attn.to_q(norm_img)
525
- img_key = attn.to_k(norm_img)
526
- text_query = attn.add_q_proj(norm_text)
527
- text_key = attn.add_k_proj(norm_text)
528
-
529
- inner_dim = img_key.shape[-1]
530
- head_dim = inner_dim // attn.heads
531
-
532
- img_query = img_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # (B, N, HW, D)
533
- img_key = img_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # (B, N, HW, D)
534
- text_query = text_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # (B, N, S, D)
535
- text_key = text_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # (B, N, S, D)
536
-
537
- if attn.norm_q is not None:
538
- img_query = attn.norm_q(img_query)
539
- if attn.norm_added_q is not None:
540
- text_query = attn.norm_added_q(text_query)
541
- if attn.norm_k is not None:
542
- img_key = attn.norm_k(img_key)
543
- if attn.norm_added_k is not None:
544
- text_key = attn.norm_added_k(text_key)
545
-
546
- query = torch.cat([text_query, img_query], dim=2) # (B, N, S+HW, D)
547
- key = torch.cat([text_key, img_key], dim=2) # (B, N, S+HW, D)
548
- if image_rotary_emb is not None:
549
- query = apply_rotary_emb(query, image_rotary_emb)
550
- key = apply_rotary_emb(key, image_rotary_emb)
551
-
552
- seq_length = text_query.shape[2]
553
-
554
- scale_factor = 1 / math.sqrt(query.size(-1))
555
- t2i_attn_map = query @ key.transpose(-2, -1) * scale_factor # (B, N, S+HW, S+HW)
556
- t2i_attn_map = t2i_attn_map.mean(1)[:, seq_length:, :seq_length] # (B, S+HW, S+HW) -> (B, HW, S)
557
- t2i_attn_map = torch.softmax(t2i_attn_map, dim=-1) # (B, HW, S)
558
-
559
- else:
560
- t2i_attn_map = None
561
-
562
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
563
- norm1_forward(
564
- self.norm1,
565
- hidden_states,
566
- emb=temb,
567
- delta_emb=delta_emb_img,
568
- delta_emb_cblock=delta_emb_img_cblock,
569
- delta_emb_mask=delta_emb_mask,
570
- t2i_attn_map=t2i_attn_map,
571
- )
572
- )
573
- # Modulation for double block
574
- with enable_lora((self.norm1_context.linear,), activate_norm1_context if train_partial_text_lora else model_config["text_lora"]):
575
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
576
- norm1_context_forward(
577
- self.norm1_context,
578
- encoder_hidden_states,
579
- emb=temb,
580
- delta_emb=delta_emb if use_text_mod else None,
581
- delta_emb_cblock=delta_emb_cblock if use_text_mod else None,
582
- delta_emb_mask=delta_emb_mask if use_text_mod else None,
583
- delta_start_ends=delta_start_ends if use_text_mod else None,
584
- mod_adapter=mod_adapter,
585
- condition_latents=condition_latents,
586
- )
587
- )
588
-
589
- # Attention.
590
- result = attn_forward(
591
- self.attn,
592
- model_config=model_config,
593
- hidden_states=norm_hidden_states,
594
- encoder_hidden_states=norm_encoder_hidden_states,
595
- condition_latents=norm_condition_latents if use_cond else None,
596
- text_cond_mask=text_cond_mask if use_cond else None,
597
- image_rotary_emb=image_rotary_emb,
598
- cond_rotary_emb=cond_rotary_emb if use_cond else None,
599
- store_attn_map=store_attn_map,
600
- latent_height=latent_height,
601
- timestep=timestep,
602
- last_attn_map=last_attn_map,
603
- )
604
- attn_output, context_attn_output = result[:2]
605
- cond_attn_output = result[2] if use_cond else None
606
-
607
- # Process attention outputs for the `hidden_states`.
608
- # 1. hidden_states
609
- attn_output = gate_msa * attn_output # NOTE: changed by img mod
610
- hidden_states = hidden_states + attn_output
611
- # 2. encoder_hidden_states
612
- context_attn_output = c_gate_msa * context_attn_output # NOTE: changed by delta_temb
613
- encoder_hidden_states = encoder_hidden_states + context_attn_output
614
- # 3. condition_latents
615
- if use_cond:
616
- cond_attn_output = cond_gate_msa * cond_attn_output # NOTE: changed by img mod
617
- condition_latents = condition_latents + cond_attn_output
618
- if model_config.get("add_cond_attn", False):
619
- hidden_states += cond_attn_output
620
-
621
- # LayerNorm + MLP.
622
- # 1. hidden_states
623
- norm_hidden_states = self.norm2(hidden_states)
624
- norm_hidden_states = (
625
- norm_hidden_states * (1 + scale_mlp) + shift_mlp # NOTE: changed by img mod
626
- )
627
- # 2. encoder_hidden_states
628
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
629
- norm_encoder_hidden_states = (
630
- norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp # NOTE: changed by delta_temb
631
- )
632
- # 3. condition_latents
633
- if use_cond:
634
- norm_condition_latents = self.norm2(condition_latents)
635
- norm_condition_latents = (
636
- norm_condition_latents * (1 + cond_scale_mlp) + cond_shift_mlp # NOTE: changed by img mod
637
- )
638
-
639
- # Feed-forward.
640
- with enable_lora((self.ff.net[2],), activate_ff if train_partial_latent_lora else model_config["latent_lora"]):
641
- # 1. hidden_states
642
- ff_output = self.ff(norm_hidden_states)
643
- ff_output = gate_mlp * ff_output # NOTE: changed by img mod
644
- # 2. encoder_hidden_states
645
- with enable_lora((self.ff_context.net[2],), activate_ff_context if train_partial_text_lora else model_config["text_lora"]):
646
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
647
- context_ff_output = c_gate_mlp * context_ff_output # NOTE: changed by delta_temb
648
- # 3. condition_latents
649
- if use_cond:
650
- cond_lora_activate = model_config["use_condition_dblock_lora"]
651
- with enable_lora(
652
- (self.ff.net[2],),
653
- dit_activated=activate_ff if train_partial_latent_lora else not cond_lora_activate, cond_activated=cond_lora_activate,
654
- ):
655
- cond_ff_output = self.ff(norm_condition_latents)
656
- cond_ff_output = cond_gate_mlp * cond_ff_output # NOTE: changed by img mod
657
-
658
- # Process feed-forward outputs.
659
- hidden_states = hidden_states + ff_output
660
- encoder_hidden_states = encoder_hidden_states + context_ff_output
661
- if use_cond:
662
- condition_latents = condition_latents + cond_ff_output
663
-
664
- # Clip to avoid overflow.
665
- if encoder_hidden_states.dtype == torch.float16:
666
- encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
667
-
668
- return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
669
-
670
- def single_norm_forward(
671
- self,
672
- x: torch.Tensor,
673
- timestep: Optional[torch.Tensor] = None,
674
- class_labels: Optional[torch.LongTensor] = None,
675
- hidden_dtype: Optional[torch.dtype] = None,
676
- emb: Optional[torch.Tensor] = None,
677
- delta_emb: Optional[torch.Tensor] = None,
678
- delta_emb_cblock: Optional[torch.Tensor] = None,
679
- delta_emb_mask: Optional[torch.Tensor] = None,
680
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
681
- if delta_emb is None:
682
- emb = self.linear(self.silu(emb)) # (B, 3072) -> (B, 9216)
683
- emb = emb.unsqueeze(1) # (B, 1, 9216)
684
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) # (B, 1, 3072)
685
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, S, 3072) * (B, 1, 3072)
686
- return x, gate_msa
687
- else:
688
- img_text_seq_length = x.shape[1] # S+
689
- text_seq_length = delta_emb_mask.shape[1] # S
690
- # (B, 3072) -> (B, 9216) -> (B, S+, 9216)
691
- emb_orig = self.linear(self.silu(emb)).unsqueeze(1).expand((-1, img_text_seq_length, -1))
692
- # (B, 3072) -> (B, 1, 3072) -> (B, S, 3072) -> (B, S, 9216)
693
- if delta_emb_cblock is None:
694
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb))
695
- else:
696
- emb_new = self.linear(self.silu(emb.unsqueeze(1) + delta_emb + delta_emb_cblock))
697
-
698
- emb_text = torch.where(delta_emb_mask.unsqueeze(-1), emb_new, emb_orig[:, :text_seq_length]) # (B, S, 9216)
699
- emb_img = emb_orig[:, text_seq_length:] # (B, s, 9216)
700
- emb = torch.cat([emb_text, emb_img], dim=1) # (B, S+, 9216)
701
-
702
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) # (B, S+, 3072)
703
- x = self.norm(x) * (1 + scale_msa) + shift_msa # (B, S+, 3072)
704
- return x, gate_msa
705
-
706
-
707
- def single_block_forward(
708
- self,
709
- hidden_states: torch.FloatTensor,
710
- temb: torch.FloatTensor,
711
- image_rotary_emb=None,
712
- condition_latents: torch.FloatTensor = None,
713
- text_cond_mask: torch.FloatTensor = None,
714
- cond_temb: torch.FloatTensor = None,
715
- delta_emb: Optional[torch.FloatTensor] = None,
716
- delta_emb_cblock: Optional[torch.FloatTensor] = None,
717
- delta_emb_mask: Optional[torch.Tensor] = None,
718
- use_text_mod: bool = True,
719
- use_img_mod: bool = False,
720
- cond_rotary_emb=None,
721
- latent_height: Optional[int] = None,
722
- timestep: Optional[torch.Tensor] = None,
723
- store_attn_map: bool = False,
724
- model_config: Optional[Dict[str, Any]] = {},
725
- last_attn_map: Optional[torch.Tensor] = None,
726
- latent_sblora_weight=None,
727
- condition_sblora_weight=None,
728
- ):
729
-
730
- using_cond = condition_latents is not None
731
- residual = hidden_states
732
-
733
- train_partial_lora = model_config.get("train_partial_lora", False)
734
- if train_partial_lora:
735
- train_partial_lora_layers = model_config.get("train_partial_lora_layers", "")
736
- activate_norm = activate_projmlp = activate_projout = True
737
-
738
- if "norm" not in train_partial_lora_layers:
739
- activate_norm = False
740
- if "projmlp" not in train_partial_lora_layers:
741
- activate_projmlp = False
742
- if "projout" not in train_partial_lora_layers:
743
- activate_projout = False
744
-
745
- with enable_lora((self.norm.linear,), activate_norm if train_partial_lora else model_config["sblock_lora"], latent_sblora_weight=latent_sblora_weight):
746
- # Modulation for single block
747
- norm_hidden_states, gate = single_norm_forward(
748
- self.norm,
749
- hidden_states,
750
- emb=temb,
751
- delta_emb=delta_emb if use_text_mod else None,
752
- delta_emb_cblock=delta_emb_cblock if use_text_mod else None,
753
- delta_emb_mask=delta_emb_mask if use_text_mod else None,
754
- )
755
- with enable_lora((self.proj_mlp,), activate_projmlp if train_partial_lora else model_config["sblock_lora"], latent_sblora_weight=latent_sblora_weight):
756
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
757
- if using_cond:
758
- cond_lora_activate = model_config["use_condition_sblock_lora"]
759
- with enable_lora(
760
- (self.norm.linear,),
761
- dit_activated=activate_norm if train_partial_lora else not cond_lora_activate, cond_activated=cond_lora_activate, latent_sblora_weight=condition_sblora_weight
762
- ):
763
- residual_cond = condition_latents
764
- norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
765
- with enable_lora(
766
- (self.proj_mlp,),
767
- dit_activated=activate_projmlp if train_partial_lora else not cond_lora_activate, cond_activated=cond_lora_activate, latent_sblora_weight=condition_sblora_weight
768
- ):
769
- mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
770
-
771
- attn_output = attn_forward(
772
- self.attn,
773
- model_config=model_config,
774
- hidden_states=norm_hidden_states,
775
- image_rotary_emb=image_rotary_emb,
776
- last_attn_map=last_attn_map,
777
- latent_height=latent_height,
778
- store_attn_map=store_attn_map,
779
- timestep=timestep,
780
- latent_sblora_weight=latent_sblora_weight,
781
- condition_sblora_weight=condition_sblora_weight,
782
- **(
783
- {
784
- "condition_latents": norm_condition_latents,
785
- "cond_rotary_emb": cond_rotary_emb if using_cond else None,
786
- "text_cond_mask": text_cond_mask if using_cond else None,
787
- }
788
- if using_cond
789
- else {}
790
- ),
791
- )
792
- if using_cond:
793
- attn_output, cond_attn_output = attn_output
794
-
795
- with enable_lora((self.proj_out,), activate_projout if train_partial_lora else model_config["sblock_lora"], latent_sblora_weight=latent_sblora_weight):
796
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
797
- # gate = (B, 1, 3072) or (B, S+, 3072)
798
- hidden_states = gate * self.proj_out(hidden_states)
799
- hidden_states = residual + hidden_states
800
- if using_cond:
801
- cond_lora_activate = model_config["use_condition_sblock_lora"]
802
- with enable_lora(
803
- (self.proj_out,),
804
- dit_activated=activate_projout if train_partial_lora else not cond_lora_activate, cond_activated=cond_lora_activate, latent_sblora_weight=condition_sblora_weight
805
- ):
806
- condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
807
- cond_gate = cond_gate.unsqueeze(1)
808
- condition_latents = cond_gate * self.proj_out(condition_latents)
809
- condition_latents = residual_cond + condition_latents
810
-
811
- if hidden_states.dtype == torch.float16:
812
- hidden_states = hidden_states.clip(-65504, 65504)
813
-
814
- return hidden_states if not using_cond else (hidden_states, condition_latents)
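
As an illustrative aside (not part of the deleted file): the hand-rolled scaled_dot_product_attention at the top of block.py exists so that attn_forward can keep the per-head attention weights when store_attn_map is set, something F.scaled_dot_product_attention does not expose. A minimal sanity check of that equivalence, assuming zero dropout and no mask:

    import torch
    import torch.nn.functional as F

    B, H, L, D = 2, 4, 16, 64
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

    out_ref = F.scaled_dot_product_attention(q, k, v)        # fused kernel, no weights returned

    scale = 1 / D ** 0.5
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)  # (B, H, L, L) attention map
    out_manual = weights @ v

    print(torch.allclose(out_ref, out_manual, atol=1e-5))    # True up to numerical noise
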
src/flux/condition.py DELETED
@@ -1,133 +0,0 @@
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import torch
- from typing import Optional, Union, List, Tuple
- from diffusers.pipelines import FluxPipeline
- from PIL import Image, ImageFilter
- import numpy as np
- import cv2
-
- import src.flux.pipeline_tools
-
- # condition_dict = {
- #     "depth": 0,
- #     "canny": 1,
- #     "subject": 4,
- #     "coloring": 6,
- #     "deblurring": 7,
- #     "depth_pred": 8,
- #     "fill": 9,
- #     "sr": 10,
- # }
-
-
- # class Condition(object):
- #     def __init__(
- #         self,
- #         condition_type: str,
- #         raw_img: Union[Image.Image, torch.Tensor] = None,
- #         condition: Union[Image.Image, torch.Tensor] = None,
- #         mask=None,
- #         position_delta=None,
- #     ) -> None:
- #         self.condition_type = condition_type
- #         assert raw_img is not None or condition is not None
- #         if raw_img is not None:
- #             self.condition = self.get_condition(condition_type, raw_img)
- #         else:
- #             self.condition = condition
- #         self.position_delta = position_delta
- #         # TODO: Add mask support
- #         assert mask is None, "Mask not supported yet"
-
- #     def get_condition(
- #         self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
- #     ) -> Union[Image.Image, torch.Tensor]:
- #         """
- #         Returns the condition image.
- #         """
- #         if condition_type == "depth":
- #             from transformers import pipeline
-
- #             depth_pipe = pipeline(
- #                 task="depth-estimation",
- #                 model="LiheYoung/depth-anything-small-hf",
- #                 device="cuda",
- #             )
- #             source_image = raw_img.convert("RGB")
- #             condition_img = depth_pipe(source_image)["depth"].convert("RGB")
- #             return condition_img
- #         elif condition_type == "canny":
- #             img = np.array(raw_img)
- #             edges = cv2.Canny(img, 100, 200)
- #             edges = Image.fromarray(edges).convert("RGB")
- #             return edges
- #         elif condition_type == "subject":
- #             return raw_img
- #         elif condition_type == "coloring":
- #             return raw_img.convert("L").convert("RGB")
- #         elif condition_type == "deblurring":
- #             condition_image = (
- #                 raw_img.convert("RGB")
- #                 .filter(ImageFilter.GaussianBlur(10))
- #                 .convert("RGB")
- #             )
- #             return condition_image
- #         elif condition_type == "fill":
- #             return raw_img.convert("RGB")
- #         return self.condition
-
- #     @property
- #     def type_id(self) -> int:
- #         """
- #         Returns the type id of the condition.
- #         """
- #         return condition_dict[self.condition_type]
-
- #     @classmethod
- #     def get_type_id(cls, condition_type: str) -> int:
- #         """
- #         Returns the type id of the condition.
- #         """
- #         return condition_dict[condition_type]
-
- #     def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
- #         """
- #         Encodes the condition into tokens, ids and type_id.
- #         """
- #         if self.condition_type in [
- #             "depth",
- #             "canny",
- #             "subject",
- #             "coloring",
- #             "deblurring",
- #             "depth_pred",
- #             "fill",
- #             "sr",
- #         ]:
- #             tokens, ids = encode_vae_images(pipe, self.condition)
- #         else:
- #             raise NotImplementedError(
- #                 f"Condition type {self.condition_type} not implemented"
- #             )
- #         if self.position_delta is None and self.condition_type == "subject":
- #             self.position_delta = [0, -self.condition.size[0] // 16]
- #         if self.position_delta is not None:
- #             ids[:, 1] += self.position_delta[0]
- #             ids[:, 2] += self.position_delta[1]
- #             print(f"[Condition.encode] position_delta={self.position_delta}")
- #         type_id = torch.ones_like(ids[:, :1]) * self.type_id
- #         return tokens, ids, type_id
src/flux/generate.py DELETED
@@ -1,838 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import yaml, os
18
- from PIL import Image
19
- from diffusers.pipelines import FluxPipeline
20
- from typing import List, Union, Optional, Dict, Any, Callable
21
- from src.flux.transformer import tranformer_forward
22
- import src.flux.condition
23
-
24
- # # from diffusers.pipelines.flux.pipeline_flux import (
25
- # # FluxPipelineOutput,
26
- # # calculate_shift,
27
- # # retrieve_timesteps,
28
- # # np,
29
- # # )
30
- # from src.flux.pipeline_tools import (
31
- # encode_prompt_with_clip_t5, tokenize_t5_prompt, clear_attn_maps, encode_vae_images
32
- # )
33
-
34
- # from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, decode_vae_images, \
35
- # save_attention_maps, gather_attn_maps, clear_attn_maps, load_dit_lora, quantization
36
-
37
- # from src.utils.data_utils import pad_to_square, pad_to_target, pil2tensor, get_closest_ratio, get_aspect_ratios
38
- # from src.utils.modulation_utils import get_word_index, unpad_input_ids
39
-
40
- # def get_config(config_path: str = None):
41
- # config_path = config_path or os.environ.get("XFL_CONFIG")
42
- # if not config_path:
43
- # return {}
44
- # with open(config_path, "r") as f:
45
- # config = yaml.safe_load(f)
46
- # return config
47
-
48
-
49
- # def prepare_params(
50
- # prompt: Union[str, List[str]] = None,
51
- # prompt_2: Optional[Union[str, List[str]]] = None,
52
- # height: Optional[int] = 512,
53
- # width: Optional[int] = 512,
54
- # num_inference_steps: int = 28,
55
- # timesteps: List[int] = None,
56
- # guidance_scale: float = 3.5,
57
- # num_images_per_prompt: Optional[int] = 1,
58
- # generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59
- # latents: Optional[torch.FloatTensor] = None,
60
- # prompt_embeds: Optional[torch.FloatTensor] = None,
61
- # pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
62
- # output_type: Optional[str] = "pil",
63
- # return_dict: bool = True,
64
- # joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
- # callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
66
- # callback_on_step_end_tensor_inputs: List[str] = ["latents"],
67
- # max_sequence_length: int = 512,
68
- # verbose: bool = False,
69
- # **kwargs: dict,
70
- # ):
71
- # return (
72
- # prompt,
73
- # prompt_2,
74
- # height,
75
- # width,
76
- # num_inference_steps,
77
- # timesteps,
78
- # guidance_scale,
79
- # num_images_per_prompt,
80
- # generator,
81
- # latents,
82
- # prompt_embeds,
83
- # pooled_prompt_embeds,
84
- # output_type,
85
- # return_dict,
86
- # joint_attention_kwargs,
87
- # callback_on_step_end,
88
- # callback_on_step_end_tensor_inputs,
89
- # max_sequence_length,
90
- # verbose,
91
- # )
92
-
93
-
94
- # def seed_everything(seed: int = 42):
95
- # torch.backends.cudnn.deterministic = True
96
- # torch.manual_seed(seed)
97
- # np.random.seed(seed)
98
-
99
-
100
- # @torch.no_grad()
101
- # def generate(
102
- # pipeline: FluxPipeline,
103
- # vae_conditions: List[Condition] = None,
104
- # config_path: str = None,
105
- # model_config: Optional[Dict[str, Any]] = {},
106
- # vae_condition_scale: float = 1.0,
107
- # default_lora: bool = False,
108
- # condition_pad_to: str = "square",
109
- # condition_size: int = 512,
110
- # text_cond_mask: Optional[torch.FloatTensor] = None,
111
- # delta_emb: Optional[torch.FloatTensor] = None,
112
- # delta_emb_pblock: Optional[torch.FloatTensor] = None,
113
- # delta_emb_mask: Optional[torch.FloatTensor] = None,
114
- # delta_start_ends = None,
115
- # condition_latents = None,
116
- # condition_ids = None,
117
- # mod_adapter = None,
118
- # store_attn_map: bool = False,
119
- # vae_skip_iter: str = None,
120
- # control_weight_lambda: str = None,
121
- # double_attention: bool = False,
122
- # single_attention: bool = False,
123
- # ip_scale: str = None,
124
- # use_latent_sblora_control: bool = False,
125
- # latent_sblora_scale: str = None,
126
- # use_condition_sblora_control: bool = False,
127
- # condition_sblora_scale: str = None,
128
- # idips = None,
129
- # **params: dict,
130
- # ):
131
- # model_config = model_config or get_config(config_path).get("model", {})
132
-
133
- # vae_skip_iter = model_config.get("vae_skip_iter", vae_skip_iter)
134
- # double_attention = model_config.get("double_attention", double_attention)
135
- # single_attention = model_config.get("single_attention", single_attention)
136
- # control_weight_lambda = model_config.get("control_weight_lambda", control_weight_lambda)
137
- # ip_scale = model_config.get("ip_scale", ip_scale)
138
- # use_latent_sblora_control = model_config.get("use_latent_sblora_control", use_latent_sblora_control)
139
- # use_condition_sblora_control = model_config.get("use_condition_sblora_control", use_condition_sblora_control)
140
-
141
- # latent_sblora_scale = model_config.get("latent_sblora_scale", latent_sblora_scale)
142
- # condition_sblora_scale = model_config.get("condition_sblora_scale", condition_sblora_scale)
143
-
144
- # model_config["use_attention_double"] = False
145
- # model_config["use_attention_single"] = False
146
- # use_attention = False
147
-
148
- # if idips is not None:
149
- # if control_weight_lambda != "no":
150
- # parts = control_weight_lambda.split(',')
151
- # new_parts = []
152
- # for part in parts:
153
- # if ':' in part:
154
- # left, right = part.split(':')
155
- # values = right.split('/')
156
- # # keep the global value
157
- # global_value = values[0]
158
- # id_value = values[1]
159
- # ip_value = values[2]
160
- # new_values = [global_value]
161
- # for is_id in idips:
162
- # if is_id:
163
- # new_values.append(id_value)
164
- # else:
165
- # new_values.append(ip_value)
166
- # new_part = f"{left}:{('/'.join(new_values))}"
167
- # new_parts.append(new_part)
168
- # else:
169
- # new_parts.append(part)
170
- # control_weight_lambda = ','.join(new_parts)
171
-
172
- # if vae_condition_scale != 1:
173
- # for name, module in pipeline.transformer.named_modules():
174
- # if not name.endswith(".attn"):
175
- # continue
176
- # module.c_factor = torch.ones(1, 1) * vae_condition_scale
177
-
178
- # self = pipeline
179
- # (
180
- # prompt,
181
- # prompt_2,
182
- # height,
183
- # width,
184
- # num_inference_steps,
185
- # timesteps,
186
- # guidance_scale,
187
- # num_images_per_prompt,
188
- # generator,
189
- # latents,
190
- # prompt_embeds,
191
- # pooled_prompt_embeds,
192
- # output_type,
193
- # return_dict,
194
- # joint_attention_kwargs,
195
- # callback_on_step_end,
196
- # callback_on_step_end_tensor_inputs,
197
- # max_sequence_length,
198
- # verbose,
199
- # ) = prepare_params(**params)
200
-
201
- # height = height or self.default_sample_size * self.vae_scale_factor
202
- # width = width or self.default_sample_size * self.vae_scale_factor
203
-
204
- # # 1. Check inputs. Raise error if not correct
205
- # self.check_inputs(
206
- # prompt,
207
- # prompt_2,
208
- # height,
209
- # width,
210
- # prompt_embeds=prompt_embeds,
211
- # pooled_prompt_embeds=pooled_prompt_embeds,
212
- # callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
213
- # max_sequence_length=max_sequence_length,
214
- # )
215
-
216
- # self._guidance_scale = guidance_scale
217
- # self._joint_attention_kwargs = joint_attention_kwargs
218
- # self._interrupt = False
219
-
220
- # # 2. Define call parameters
221
- # if prompt is not None and isinstance(prompt, str):
222
- # batch_size = 1
223
- # elif prompt is not None and isinstance(prompt, list):
224
- # batch_size = len(prompt)
225
- # else:
226
- # batch_size = prompt_embeds.shape[0]
227
-
228
- # device = self._execution_device
229
-
230
- # lora_scale = (
231
- # self.joint_attention_kwargs.get("scale", None)
232
- # if self.joint_attention_kwargs is not None
233
- # else None
234
- # )
235
- # (
236
- # t5_prompt_embeds,
237
- # pooled_prompt_embeds,
238
- # text_ids,
239
- # ) = encode_prompt_with_clip_t5(
240
- # self=self,
241
- # prompt="" if self.text_encoder_2 is None else prompt,
242
- # prompt_2=None,
243
- # prompt_embeds=prompt_embeds,
244
- # pooled_prompt_embeds=pooled_prompt_embeds,
245
- # device=device,
246
- # num_images_per_prompt=num_images_per_prompt,
247
- # max_sequence_length=max_sequence_length,
248
- # lora_scale=lora_scale,
249
- # )
250
-
251
- # # 4. Prepare latent variables
252
- # num_channels_latents = self.transformer.config.in_channels // 4
253
- # latents, latent_image_ids = self.prepare_latents(
254
- # batch_size * num_images_per_prompt,
255
- # num_channels_latents,
256
- # height,
257
- # width,
258
- # pooled_prompt_embeds.dtype,
259
- # device,
260
- # generator,
261
- # latents,
262
- # )
263
-
264
- # latent_height = height // 16
265
-
266
- # # 5. Prepare timesteps
267
- # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
268
- # image_seq_len = latents.shape[1]
269
- # mu = calculate_shift(
270
- # image_seq_len,
271
- # self.scheduler.config.base_image_seq_len,
272
- # self.scheduler.config.max_image_seq_len,
273
- # self.scheduler.config.base_shift,
274
- # self.scheduler.config.max_shift,
275
- # )
276
- # timesteps, num_inference_steps = retrieve_timesteps(
277
- # self.scheduler,
278
- # num_inference_steps,
279
- # device,
280
- # timesteps,
281
- # sigmas,
282
- # mu=mu,
283
- # )
284
- # num_warmup_steps = max(
285
- # len(timesteps) - num_inference_steps * self.scheduler.order, 0
286
- # )
287
- # self._num_timesteps = len(timesteps)
288
-
289
- # attn_map = None
290
-
291
- # # 6. Denoising loop
292
- # with self.progress_bar(total=num_inference_steps) as progress_bar:
293
- # totalsteps = timesteps[0]
294
- # if control_weight_lambda is not None:
295
- # print("control_weight_lambda", control_weight_lambda)
296
- # control_weight_lambda_schedule = []
297
- # for scale_str in control_weight_lambda.split(','):
298
- # time_region, scale = scale_str.split(':')
299
- # start, end = time_region.split('-')
300
- # scales = [float(s) for s in scale.split('/')]
301
- # control_weight_lambda_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, scales])
302
-
303
- # if ip_scale is not None:
304
- # print("ip_scale", ip_scale)
305
- # ip_scale_schedule = []
306
- # for scale_str in ip_scale.split(','):
307
- # time_region, scale = scale_str.split(':')
308
- # start, end = time_region.split('-')
309
- # ip_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
310
-
311
- # if use_latent_sblora_control:
312
- # if latent_sblora_scale is not None:
313
- # print("latent_sblora_scale", latent_sblora_scale)
314
- # latent_sblora_scale_schedule = []
315
- # for scale_str in latent_sblora_scale.split(','):
316
- # time_region, scale = scale_str.split(':')
317
- # start, end = time_region.split('-')
318
- # latent_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
319
-
320
- # if use_condition_sblora_control:
321
- # if condition_sblora_scale is not None:
322
- # print("condition_sblora_scale", condition_sblora_scale)
323
- # condition_sblora_scale_schedule = []
324
- # for scale_str in condition_sblora_scale.split(','):
325
- # time_region, scale = scale_str.split(':')
326
- # start, end = time_region.split('-')
327
- # condition_sblora_scale_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
328
-
329
-
330
- # if vae_skip_iter is not None:
331
- # print("vae_skip_iter", vae_skip_iter)
332
- # vae_skip_iter_schedule = []
333
- # for scale_str in vae_skip_iter.split(','):
334
- # time_region, scale = scale_str.split(':')
335
- # start, end = time_region.split('-')
336
- # vae_skip_iter_schedule.append([(1-float(start))*totalsteps, (1-float(end))*totalsteps, float(scale)])
337
-
338
- # if control_weight_lambda is not None and attn_map is None:
339
- # batch_size = latents.shape[0]
340
- # latent_width = latents.shape[1]//latent_height
341
- # attn_map = torch.ones(batch_size, latent_height, latent_width, 128, device=latents.device, dtype=torch.bfloat16)
342
- # print("contol_weight_only", attn_map.shape)
343
-
344
- # self.scheduler.set_begin_index(0)
345
- # self.scheduler._init_step_index(0)
346
- # for i, t in enumerate(timesteps):
347
-
348
- # if control_weight_lambda is not None:
349
- # cur_control_weight_lambda = []
350
- # for start, end, scale in control_weight_lambda_schedule:
351
- # if t <= start and t >= end:
352
- # cur_control_weight_lambda = scale
353
- # break
354
- # print(f"timestep:{t}, cur_control_weight_lambda:{cur_control_weight_lambda}")
355
-
356
- # if cur_control_weight_lambda:
357
- # model_config["use_attention_single"] = True
358
- # use_attention = True
359
- # model_config["use_atten_lambda"] = cur_control_weight_lambda
360
- # else:
361
- # model_config["use_attention_single"] = False
362
- # use_attention = False
363
-
364
- # if self.interrupt:
365
- # continue
366
-
367
- # if isinstance(delta_emb, list):
368
- # cur_delta_emb = delta_emb[i]
369
- # cur_delta_emb_pblock = delta_emb_pblock[i]
370
- # cur_delta_emb_mask = delta_emb_mask[i]
371
- # else:
372
- # cur_delta_emb = delta_emb
373
- # cur_delta_emb_pblock = delta_emb_pblock
374
- # cur_delta_emb_mask = delta_emb_mask
375
-
376
-
377
- # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
378
- # timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
379
- # prompt_embeds = t5_prompt_embeds
380
- # text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=prompt_embeds.dtype)
381
-
382
- # # handle guidance
383
- # if self.transformer.config.guidance_embeds:
384
- # guidance = torch.tensor([guidance_scale], device=device)
385
- # guidance = guidance.expand(latents.shape[0])
386
- # else:
387
- # guidance = None
388
- # self.transformer.enable_lora()
389
-
390
- # lora_weight = 1
391
- # if ip_scale is not None:
392
- # lora_weight = 0
393
- # for start, end, scale in ip_scale_schedule:
394
- # if t <= start and t >= end:
395
- # lora_weight = scale
396
- # break
397
- # if lora_weight != 1: print(f"timestep:{t}, lora_weights:{lora_weight}")
398
-
399
- # latent_sblora_weight = None
400
- # if use_latent_sblora_control:
401
- # if latent_sblora_scale is not None:
402
- # latent_sblora_weight = 0
403
- # for start, end, scale in latent_sblora_scale_schedule:
404
- # if t <= start and t >= end:
405
- # latent_sblora_weight = scale
406
- # break
407
- # if latent_sblora_weight != 1: print(f"timestep:{t}, latent_sblora_weight:{latent_sblora_weight}")
408
-
409
- # condition_sblora_weight = None
410
- # if use_condition_sblora_control:
411
- # if condition_sblora_scale is not None:
412
- # condition_sblora_weight = 0
413
- # for start, end, scale in condition_sblora_scale_schedule:
414
- # if t <= start and t >= end:
415
- # condition_sblora_weight = scale
416
- # break
417
- # if condition_sblora_weight !=1: print(f"timestep:{t}, condition_sblora_weight:{condition_sblora_weight}")
418
-
419
- # vae_skip_iter_t = False
420
- # if vae_skip_iter is not None:
421
- # for start, end, scale in vae_skip_iter_schedule:
422
- # if t <= start and t >= end:
423
- # vae_skip_iter_t = bool(scale)
424
- # break
425
- # if vae_skip_iter_t:
426
- # print(f"timestep:{t}, skip vae:{vae_skip_iter_t}")
427
-
428
- # noise_pred = tranformer_forward(
429
- # self.transformer,
430
- # model_config=model_config,
431
- # # Inputs of the condition (new feature)
432
- # text_cond_mask=text_cond_mask,
433
- # delta_emb=cur_delta_emb,
434
- # delta_emb_pblock=cur_delta_emb_pblock,
435
- # delta_emb_mask=cur_delta_emb_mask,
436
- # delta_start_ends=delta_start_ends,
437
- # condition_latents=None if vae_skip_iter_t else condition_latents,
438
- # condition_ids=None if vae_skip_iter_t else condition_ids,
439
- # condition_type_ids=None,
440
- # # Inputs to the original transformer
441
- # hidden_states=latents,
442
- # # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
443
- # timestep=timestep,
444
- # guidance=guidance,
445
- # pooled_projections=pooled_prompt_embeds,
446
- # encoder_hidden_states=prompt_embeds,
447
- # txt_ids=text_ids,
448
- # img_ids=latent_image_ids,
449
- # joint_attention_kwargs={'scale': lora_weight, "latent_sblora_weight": latent_sblora_weight, "condition_sblora_weight": condition_sblora_weight},
450
- # store_attn_map=use_attention,
451
- # last_attn_map=attn_map if cur_control_weight_lambda else None,
452
- # use_text_mod=model_config["modulation"]["use_text_mod"],
453
- # use_img_mod=model_config["modulation"]["use_img_mod"],
454
- # mod_adapter=mod_adapter,
455
- # latent_height=latent_height,
456
- # return_dict=False,
457
- # )[0]
458
-
459
- # if use_attention:
460
- # attn_maps, _ = gather_attn_maps(self.transformer, clear=True)
461
-
462
- # # compute the previous noisy sample x_t -> x_t-1
463
- # latents_dtype = latents.dtype
464
- # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
465
-
466
- # if latents.dtype != latents_dtype:
467
- # if torch.backends.mps.is_available():
468
- # # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
469
- # latents = latents.to(latents_dtype)
470
-
471
- # if callback_on_step_end is not None:
472
- # callback_kwargs = {}
473
- # for k in callback_on_step_end_tensor_inputs:
474
- # callback_kwargs[k] = locals()[k]
475
- # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
476
-
477
- # latents = callback_outputs.pop("latents", latents)
478
- # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
479
-
480
- # # call the callback, if provided
481
- # if i == len(timesteps) - 1 or (
482
- # (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
483
- # ):
484
- # progress_bar.update()
485
-
486
- # if output_type == "latent":
487
- # image = latents
488
-
489
- # else:
490
- # latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
491
- # latents = (
492
- # latents / self.vae.config.scaling_factor
493
- # ) + self.vae.config.shift_factor
494
- # image = self.vae.decode(latents, return_dict=False)[0]
495
- # image = self.image_processor.postprocess(image, output_type=output_type)
496
-
497
- # # Offload all models
498
- # self.maybe_free_model_hooks()
499
-
500
- # self.transformer.enable_lora()
501
-
502
- # if vae_condition_scale != 1:
503
- # for name, module in pipeline.transformer.named_modules():
504
- # if not name.endswith(".attn"):
505
- # continue
506
- # del module.c_factor
507
-
508
- # if not return_dict:
509
- # return (image,)
510
-
511
- # return FluxPipelineOutput(images=image)
512
-
513
-
514
- # @torch.no_grad()
515
- # def generate_from_test_sample(
516
- # test_sample, pipe, config,
517
- # num_images=1,
518
- # vae_skip_iter: str = None,
519
- # target_height: int = None,
520
- # target_width: int = None,
521
- # seed: int = 42,
522
- # control_weight_lambda: str = None,
523
- # double_attention: bool = False,
524
- # single_attention: bool = False,
525
- # ip_scale: str = None,
526
- # use_latent_sblora_control: bool = False,
527
- # latent_sblora_scale: str = None,
528
- # use_condition_sblora_control: bool = False,
529
- # condition_sblora_scale: str = None,
530
- # use_idip = False,
531
- # **kargs
532
- # ):
533
- # target_size = config["train"]["dataset"]["val_target_size"]
534
- # condition_size = config["train"]["dataset"].get("val_condition_size", target_size//2)
535
- # condition_pad_to = config["train"]["dataset"]["condition_pad_to"]
536
- # pos_offset_type = config["model"].get("pos_offset_type", "width")
537
- # seed = config["model"].get("seed", seed)
538
-
539
- # device = pipe._execution_device
540
-
541
- # condition_imgs = test_sample['input_images']
542
- # position_delta = test_sample['position_delta']
543
- # prompt = test_sample['prompt']
544
- # original_image = test_sample.get('original_image', None)
545
- # condition_type = test_sample.get('condition_type', "subject")
546
- # modulation_input = test_sample.get('modulation', None)
547
-
548
- # delta_start_ends = None
549
- # condition_latents = condition_ids = None
550
- # text_cond_mask = None
551
-
552
- # delta_embs = None
553
- # delta_embs_pblock = None
554
- # delta_embs_mask = None
555
-
556
- # try:
557
- # max_length = config["model"]["modulation"]["max_text_len"]
558
- # except Exception as e:
559
- # print(e)
560
- # max_length = 512
561
-
562
- # if modulation_input is None or len(modulation_input) == 0:
563
- # delta_emb = delta_emb_pblock = delta_emb_mask = None
564
- # else:
565
- # dtype = torch.bfloat16
566
- # batch_size = 1
567
- # N = config["model"]["modulation"].get("per_block_adapter_single_blocks", 0) + 19
568
- # guidance = torch.tensor([3.5]).to(device).expand(batch_size)
569
- # out_dim = config["model"]["modulation"]["out_dim"]
570
-
571
- # tar_text_inputs = tokenize_t5_prompt(pipe, prompt, max_length)
572
- # tar_padding_mask = tar_text_inputs.attention_mask.to(device).bool()
573
- # tar_tokens = tar_text_inputs.input_ids.to(device)
574
- # if config["model"]["modulation"]["eos_exclude"]:
575
- # tar_padding_mask[tar_tokens == 1] = False
576
-
577
- # def get_start_end_by_pompt_matching(src_prompts, tar_prompts):
578
- # text_cond_mask = torch.zeros(batch_size, max_length, device=device, dtype=torch.bool)
579
- # tar_prompt_input_ids = tokenize_t5_prompt(pipe, tar_prompts, max_length).input_ids
580
- # src_prompt_count = 1
581
- # start_ends = []
582
- # for i, (src_prompt, tar_prompt, tar_prompt_tokens) in enumerate(zip(src_prompts, tar_prompts, tar_prompt_input_ids)):
583
- # try:
584
- # tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_prompt_tokens, src_prompt, src_prompt_count, max_length, verbose=False)
585
- # start_ends.append([tar_start, tar_end])
586
- # text_cond_mask[i, tar_start:tar_end] = True
587
- # except Exception as e:
588
- # print(e)
589
- # return start_ends, text_cond_mask
590
-
591
- # def encode_mod_image(pil_images):
592
- # if config["model"]["modulation"]["use_dit"]:
593
- # raise NotImplementedError()
594
- # else:
595
- # pil_images = [pad_to_square(img).resize((224, 224)) for img in pil_images]
596
- # if config["model"]["modulation"]["use_vae"]:
597
- # raise NotImplementedError()
598
- # else:
599
- # clip_pixel_values = pipe.clip_processor(
600
- # text=None, images=pil_images, do_resize=False, do_center_crop=False, return_tensors="pt",
601
- # ).pixel_values.to(dtype=dtype, device=device)
602
- # clip_outputs = pipe.clip_model(clip_pixel_values, output_hidden_states=True, interpolate_pos_encoding=True, return_dict=True)
603
- # return clip_outputs
604
-
605
- # def rgba_to_white_background(input_path, background=(255,255,255)):
606
- # with Image.open(input_path).convert("RGBA") as img:
607
- # img_np = np.array(img)
608
- # alpha = img_np[:, :, 3] / 255.0 # normalize the alpha channel to [0, 1]
609
- # rgb = img_np[:, :, :3].astype(float) # extract the RGB channels
610
-
611
- # background_np = np.full_like(rgb, background, dtype=float) # build the background from the given color
612
-
613
- # # alpha blend: foreground*alpha + background*(1-alpha)
614
- # result_np = rgb * alpha[..., np.newaxis] + \
615
- # background_np * (1 - alpha[..., np.newaxis])
616
-
617
- # result = Image.fromarray(result_np.astype(np.uint8), "RGB")
618
- # return result
619
- # def get_mod_emb(modulation_input, timestep):
620
- # delta_emb = torch.zeros((batch_size, max_length, out_dim), dtype=dtype, device=device)
621
- # delta_emb_pblock = torch.zeros((batch_size, max_length, N, out_dim), dtype=dtype, device=device)
622
- # delta_emb_mask = torch.zeros((batch_size, max_length), dtype=torch.bool, device=device)
623
- # delta_start_ends = None
624
- # condition_latents = condition_ids = None
625
- # text_cond_mask = None
626
-
627
- # if modulation_input[0]["type"] == "adapter":
628
- # num_inputs = len(modulation_input[0]["src_inputs"])
629
- # src_prompts = [x["caption"] for x in modulation_input[0]["src_inputs"]]
630
- # src_text_inputs = tokenize_t5_prompt(pipe, src_prompts, max_length)
631
- # src_input_ids = unpad_input_ids(src_text_inputs.input_ids, src_text_inputs.attention_mask)
632
- # tar_input_ids = unpad_input_ids(tar_text_inputs.input_ids, tar_text_inputs.attention_mask)
633
- # src_prompt_embeds = pipe._get_t5_prompt_embeds(prompt=src_prompts, max_sequence_length=max_length, device=device) # (M, 512, 4096)
634
-
635
- # pil_images = [rgba_to_white_background(x["image_path"]) for x in modulation_input[0]["src_inputs"]]
636
-
637
- # src_ds_scales = [x.get("downsample_scale", 1.0) for x in modulation_input[0]["src_inputs"]]
638
- # resized_pil_images = []
639
- # for img, ds_scale in zip(pil_images, src_ds_scales):
640
- # img = pad_to_square(img)
641
- # if ds_scale < 1.0:
642
- # assert ds_scale > 0
643
- # img = img.resize((int(224 * ds_scale), int(224 * ds_scale))).resize((224, 224))
644
- # resized_pil_images.append(img)
645
- # pil_images = resized_pil_images
646
-
647
- # img_encoded = encode_mod_image(pil_images)
648
- # delta_start_ends = []
649
- # text_cond_mask = torch.zeros(num_inputs, max_length, device=device, dtype=torch.bool)
650
- # if config["model"]["modulation"]["pass_vae"]:
651
- # pil_images = [pad_to_square(img).resize((condition_size, condition_size)) for img in pil_images]
652
- # with torch.no_grad():
653
- # batch_tensor = torch.stack([pil2tensor(x) for x in pil_images])
654
- # x_0, img_ids = encode_vae_images(pipe, batch_tensor) # (N, 256, 64)
655
-
656
- # condition_latents = x_0.clone().detach().reshape(1, -1, 64) # (1, N256, 64)
657
- # condition_ids = img_ids.clone().detach()
658
- # condition_ids = condition_ids.unsqueeze(0).repeat_interleave(num_inputs, dim=0) # (N, 256, 3)
659
- # for i in range(num_inputs):
660
- # condition_ids[i, :, 1] += 0 if pos_offset_type == "width" else -(batch_tensor.shape[-1]//16) * (i + 1)
661
- # condition_ids[i, :, 2] += -(batch_tensor.shape[-1]//16) * (i + 1)
662
- # condition_ids = condition_ids.reshape(-1, 3) # (N256, 3)
663
-
664
- # if config["model"]["modulation"]["use_dit"]:
665
- # raise NotImplementedError()
666
- # else:
667
- # src_delta_embs = [] # [(512, 3072)]
668
- # src_delta_emb_pblock = []
669
- # for i in range(num_inputs):
670
- # if isinstance(img_encoded, dict):
671
- # _src_clip_outputs = {}
672
- # for key in img_encoded:
673
- # if torch.is_tensor(img_encoded[key]):
674
- # _src_clip_outputs[key] = img_encoded[key][i:i+1]
675
- # else:
676
- # _src_clip_outputs[key] = [x[i:i+1] for x in img_encoded[key]]
677
- # _img_encoded = _src_clip_outputs
678
- # else:
679
- # _img_encoded = img_encoded[i:i+1]
680
-
681
- # x1, x2 = pipe.modulation_adapters[0](timestep, src_prompt_embeds[i:i+1], _img_encoded)
682
- # src_delta_embs.append(x1[0]) # (512, 3072)
683
- # src_delta_emb_pblock.append(x2[0]) # (512, N, 3072)
684
-
685
- # for input_args in modulation_input[0]["use_words"]:
686
- # src_word_count = 1
687
- # if len(input_args) == 3:
688
- # src_input_index, src_word, tar_word = input_args
689
- # tar_word_count = 1
690
- # else:
691
- # src_input_index, src_word, tar_word, tar_word_count = input_args[:4]
692
- # src_prompt = src_prompts[src_input_index]
693
- # tar_prompt = prompt
694
-
695
- # src_start, src_end = get_word_index(pipe, src_prompt, src_input_ids[src_input_index], src_word, src_word_count, max_length, verbose=False)
696
- # tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_input_ids[0], tar_word, tar_word_count, max_length, verbose=False)
697
- # if delta_emb is not None:
698
- # delta_emb[:, tar_start:tar_end] = src_delta_embs[src_input_index][src_start:src_end] # (B, 512, 3072)
699
- # if delta_emb_pblock is not None:
700
- # delta_emb_pblock[:, tar_start:tar_end] = src_delta_emb_pblock[src_input_index][src_start:src_end] # (B, 512, N, 3072)
701
- # delta_emb_mask[:, tar_start:tar_end] = True
702
- # text_cond_mask[src_input_index, tar_start:tar_end] = True
703
- # delta_start_ends.append([0, src_input_index, src_start, src_end, tar_start, tar_end])
704
- # text_cond_mask = text_cond_mask.transpose(0, 1).unsqueeze(0)
705
-
706
- # else:
707
- # raise NotImplementedError()
708
- # return delta_emb, delta_emb_pblock, delta_emb_mask, \
709
- # text_cond_mask, delta_start_ends, condition_latents, condition_ids
710
-
711
- # num_inference_steps = 28 # FIXME: hardcoded here
712
- # num_channels_latents = pipe.transformer.config.in_channels // 4
713
-
714
- # # set timesteps
715
- # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
716
- # mu = calculate_shift(
717
- # num_channels_latents,
718
- # pipe.scheduler.config.base_image_seq_len,
719
- # pipe.scheduler.config.max_image_seq_len,
720
- # pipe.scheduler.config.base_shift,
721
- # pipe.scheduler.config.max_shift,
722
- # )
723
- # timesteps, num_inference_steps = retrieve_timesteps(
724
- # pipe.scheduler,
725
- # num_inference_steps,
726
- # device,
727
- # None,
728
- # sigmas,
729
- # mu=mu,
730
- # )
731
-
732
- # if modulation_input is not None:
733
- # delta_embs = []
734
- # delta_embs_pblock = []
735
- # delta_embs_mask = []
736
- # for i, t in enumerate(timesteps):
737
- # t = t.expand(1).to(torch.bfloat16) / 1000
738
- # (
739
- # delta_emb, delta_emb_pblock, delta_emb_mask,
740
- # text_cond_mask, delta_start_ends,
741
- # condition_latents, condition_ids
742
- # ) = get_mod_emb(modulation_input, t)
743
- # delta_embs.append(delta_emb)
744
- # delta_embs_pblock.append(delta_emb_pblock)
745
- # delta_embs_mask.append(delta_emb_mask)
746
-
747
- # if original_image is not None:
748
- # raise NotImplementedError()
749
- # (target_height, target_width), closest_ratio = get_closest_ratio(original_image.height, original_image.width, train_aspect_ratios)
750
- # elif modulation_input is None or len(modulation_input) == 0:
751
- # delta_emb = delta_emb_pblock = delta_emb_mask = None
752
- # else:
753
- # for i, t in enumerate(timesteps):
754
- # t = t.expand(1).to(torch.bfloat16) / 1000
755
- # (
756
- # delta_emb, delta_emb_pblock, delta_emb_mask,
757
- # text_cond_mask, delta_start_ends,
758
- # condition_latents, condition_ids
759
- # ) = get_mod_emb(modulation_input, t)
760
- # delta_embs.append(delta_emb)
761
- # delta_embs_pblock.append(delta_emb_pblock)
762
- # delta_embs_mask.append(delta_emb_mask)
763
-
764
- # if target_height is None or target_width is None:
765
- # target_height = target_width = target_size
766
-
767
- # if condition_pad_to == "square":
768
- # condition_imgs = [pad_to_square(x) for x in condition_imgs]
769
- # elif condition_pad_to == "target":
770
- # condition_imgs = [pad_to_target(x, (target_size, target_size)) for x in condition_imgs]
771
- # condition_imgs = [x.resize((condition_size, condition_size)).convert("RGB") for x in condition_imgs]
772
- # # TODO: fix position_delta
773
- # conditions = [
774
- # Condition(
775
- # condition_type=condition_type,
776
- # condition=x,
777
- # position_delta=position_delta,
778
- # ) for x in condition_imgs
779
- # ]
780
- # # vlm_images = condition_imgs if config["model"]["use_vlm"] else []
781
-
782
- # use_perblock_adapter = False
783
- # try:
784
- # if config["model"]["modulation"]["use_perblock_adapter"]:
785
- # use_perblock_adapter = True
786
- # except Exception as e:
787
- # pass
788
-
789
- # results = []
790
- # for i in range(num_images):
791
- # clear_attn_maps(pipe.transformer)
792
- # generator = torch.Generator(device=device)
793
- # generator.manual_seed(seed + i)
794
- # if modulation_input is None or len(modulation_input) == 0:
795
- # idips = None
796
- # else:
797
- # idips = ["human" in p["image_path"] for p in modulation_input[0]["src_inputs"]]
798
- # if len(modulation_input[0]["use_words"][0])==5:
799
- # print("use idips in use_words")
800
- # idips = [x[-1] for x in modulation_input[0]["use_words"]]
801
- # result_img = generate(
802
- # pipe,
803
- # prompt=prompt,
804
- # max_sequence_length=max_length,
805
- # vae_conditions=conditions,
806
- # generator=generator,
807
- # model_config=config["model"],
808
- # height=target_height,
809
- # width=target_width,
810
- # condition_pad_to=condition_pad_to,
811
- # condition_size=condition_size,
812
- # text_cond_mask=text_cond_mask,
813
- # delta_emb=delta_embs,
814
- # delta_emb_pblock=delta_embs_pblock if use_perblock_adapter else None,
815
- # delta_emb_mask=delta_embs_mask,
816
- # delta_start_ends=delta_start_ends,
817
- # condition_latents=condition_latents,
818
- # condition_ids=condition_ids,
819
- # mod_adapter=pipe.modulation_adapters[0] if config["model"]["modulation"]["use_dit"] else None,
820
- # vae_skip_iter=vae_skip_iter,
821
- # control_weight_lambda=control_weight_lambda,
822
- # double_attention=double_attention,
823
- # single_attention=single_attention,
824
- # ip_scale=ip_scale,
825
- # use_latent_sblora_control=use_latent_sblora_control,
826
- # latent_sblora_scale=latent_sblora_scale,
827
- # use_condition_sblora_control=use_condition_sblora_control,
828
- # condition_sblora_scale=condition_sblora_scale,
829
- # idips=idips if use_idip else None,
830
- # **kargs,
831
- # ).images[0]
832
-
833
- # final_image = result_img
834
- # results.append(final_image)
835
-
836
- # if num_images == 1:
837
- # return results[0]
838
- # return results
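The deleted generate() above drives several controls (control_weight_lambda, ip_scale, vae_skip_iter, the sblora scales) with the same compact schedule syntax: comma-separated "start-end:value" chunks, where start/end are fractions of the denoising run and the value applies while the (descending) timestep falls inside that window; control_weight_lambda additionally allows slash-separated value lists. A self-contained sketch of that parsing, with illustrative helper names and assuming timesteps run from total_steps down to 0:

def parse_schedule(spec: str, total_steps: float):
    # "0-0.1:1,0.1-1:0" -> [(t_start, t_end, value), ...] on a descending timestep axis
    schedule = []
    for chunk in spec.split(','):
        region, value = chunk.split(':')
        start, end = region.split('-')
        schedule.append(((1 - float(start)) * total_steps,
                         (1 - float(end)) * total_steps,
                         float(value)))
    return schedule

def lookup(schedule, t: float, default=None):
    # timesteps decrease over the run, so each window is [t_end, t_start]
    for t_start, t_end, value in schedule:
        if t_end <= t <= t_start:
            return value
    return default

# Example: skip the VAE condition branch only during the first 10% of steps.
vae_skip = parse_schedule("0-0.1:1,0.1-1:0", total_steps=1000.0)
print(lookup(vae_skip, t=950.0))  # 1.0 -> skip
print(lookup(vae_skip, t=500.0))  # 0.0 -> keep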
src/flux/lora_controller.py DELETED
@@ -1,99 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from peft.tuners.tuners_utils import BaseTunerLayer
17
- from typing import List, Any, Optional, Type
18
-
19
-
20
- class enable_lora:
21
- def __init__(self, lora_modules: List[BaseTunerLayer], dit_activated: bool, cond_activated: bool=False, latent_sblora_weight: float=None, condition_sblora_weight: float=None) -> None:
22
- self.dit_activated = dit_activated
23
- self.cond_activated = cond_activated
24
- self.latent_sblora_weight = latent_sblora_weight
25
- self.condition_sblora_weight = condition_sblora_weight
26
- # assert not (dit_activated and cond_activated)
27
-
28
- self.lora_modules: List[BaseTunerLayer] = [
29
- each for each in lora_modules if isinstance(each, BaseTunerLayer)
30
- ]
31
-
32
- self.scales = [
33
- {
34
- active_adapter: lora_module.scaling[active_adapter] if active_adapter in lora_module.scaling else 1
35
- for active_adapter in lora_module.active_adapters
36
- } for lora_module in self.lora_modules
37
- ]
38
-
39
-
40
- def __enter__(self) -> None:
41
- for i, lora_module in enumerate(self.lora_modules):
42
- if not isinstance(lora_module, BaseTunerLayer):
43
- continue
44
- for active_adapter in lora_module.active_adapters:
45
- if active_adapter == "default":
46
- if self.dit_activated:
47
- lora_module.scaling[active_adapter] = self.scales[0]["default"] if self.latent_sblora_weight is None else self.latent_sblora_weight
48
- else:
49
- lora_module.scaling[active_adapter] = 0
50
- else:
51
- assert active_adapter == "condition"
52
- if self.cond_activated:
53
- lora_module.scaling[active_adapter] = self.scales[0]["condition"] if self.condition_sblora_weight is None else self.condition_sblora_weight
54
- else:
55
- lora_module.scaling[active_adapter] = 0
56
-
57
- def __exit__(
58
- self,
59
- exc_type: Optional[Type[BaseException]],
60
- exc_val: Optional[BaseException],
61
- exc_tb: Optional[Any],
62
- ) -> None:
63
- for i, lora_module in enumerate(self.lora_modules):
64
- if not isinstance(lora_module, BaseTunerLayer):
65
- continue
66
- for active_adapter in lora_module.active_adapters:
67
- lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
68
-
69
- class set_lora_scale:
70
- def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
71
- self.lora_modules: List[BaseTunerLayer] = [
72
- each for each in lora_modules if isinstance(each, BaseTunerLayer)
73
- ]
74
- self.scales = [
75
- {
76
- active_adapter: lora_module.scaling[active_adapter]
77
- for active_adapter in lora_module.active_adapters
78
- }
79
- for lora_module in self.lora_modules
80
- ]
81
- self.scale = scale
82
-
83
- def __enter__(self) -> None:
84
- for lora_module in self.lora_modules:
85
- if not isinstance(lora_module, BaseTunerLayer):
86
- continue
87
- lora_module.scale_layer(self.scale)
88
-
89
- def __exit__(
90
- self,
91
- exc_type: Optional[Type[BaseException]],
92
- exc_val: Optional[BaseException],
93
- exc_tb: Optional[Any],
94
- ) -> None:
95
- for i, lora_module in enumerate(self.lora_modules):
96
- if not isinstance(lora_module, BaseTunerLayer):
97
- continue
98
- for active_adapter in lora_module.active_adapters:
99
- lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
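A hypothetical usage sketch for the enable_lora context manager above. The surrounding objects are assumptions: `transformer` stands for a peft-wrapped FLUX transformer that already carries the "default" (DiT) and "condition" LoRA adapters, as loaded elsewhere in this repository.

from peft.tuners.tuners_utils import BaseTunerLayer

lora_layers = [m for m in transformer.modules() if isinstance(m, BaseTunerLayer)]

# Inside the block only the "default" adapter contributes (scaled to 0.8) and the
# "condition" adapter is zeroed; the original scalings are restored on __exit__.
with enable_lora(lora_layers, dit_activated=True, cond_activated=False,
                 latent_sblora_weight=0.8):
    print(lora_layers[0].scaling)   # e.g. {"default": 0.8, "condition": 0}
print(lora_layers[0].scaling)       # original per-adapter scalings restored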
src/flux/pipeline_tools.py DELETED
@@ -1,685 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Union
18
- import os
19
- import torch
20
- from torch import Tensor
21
- import torch.nn.functional as F
22
- from diffusers.pipelines import FluxPipeline
23
- from diffusers.utils import logging
24
- from diffusers.loaders import TextualInversionLoaderMixin
25
- from diffusers.pipelines.flux.pipeline_flux import FluxLoraLoaderMixin
26
- from diffusers.models.transformers.transformer_flux import (
27
- USE_PEFT_BACKEND,
28
- scale_lora_layers,
29
- unscale_lora_layers,
30
- logger,
31
- )
32
- from torchvision.transforms import ToPILImage
33
- from peft.tuners.tuners_utils import BaseTunerLayer
34
- # from optimum.quanto import (
35
- # freeze, quantize, QTensor, qfloat8, qint8, qint4, qint2,
36
- # )
37
- import re
38
- import safetensors
39
- from src.adapters.mod_adapters import CLIPModAdapter
40
- from peft import LoraConfig, set_peft_model_state_dict
41
- from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPVisionModel
42
-
43
-
44
- def encode_vae_images(pipeline: FluxPipeline, images: Tensor):
45
- images = pipeline.image_processor.preprocess(images)
46
- images = images.to(pipeline.device).to(pipeline.dtype)
47
- images = pipeline.vae.encode(images).latent_dist.sample()
48
- images = (
49
- images - pipeline.vae.config.shift_factor
50
- ) * pipeline.vae.config.scaling_factor
51
- images_tokens = pipeline._pack_latents(images, *images.shape)
52
- images_ids = pipeline._prepare_latent_image_ids(
53
- images.shape[0],
54
- images.shape[2],
55
- images.shape[3],
56
- pipeline.device,
57
- pipeline.dtype,
58
- )
59
- if images_tokens.shape[1] != images_ids.shape[0]:
60
- images_ids = pipeline._prepare_latent_image_ids(
61
- images.shape[0],
62
- images.shape[2] // 2,
63
- images.shape[3] // 2,
64
- pipeline.device,
65
- pipeline.dtype,
66
- )
67
- return images_tokens, images_ids
68
-
69
- def decode_vae_images(pipeline: FluxPipeline, latents: Tensor, height, width, output_type: Optional[str] = "pil"):
70
- latents = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
71
- latents = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
72
- image = pipeline.vae.decode(latents, return_dict=False)[0]
73
- return pipeline.image_processor.postprocess(image, output_type=output_type)
74
-
75
-
76
- def _get_clip_prompt_embeds(
77
- self,
78
- prompt: Union[str, List[str]],
79
- num_images_per_prompt: int = 1,
80
- device: Optional[torch.device] = None,
81
- ):
82
- device = device or self._execution_device
83
-
84
- prompt = [prompt] if isinstance(prompt, str) else prompt
85
- batch_size = len(prompt)
86
-
87
- if isinstance(self, TextualInversionLoaderMixin):
88
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
89
-
90
- text_inputs = self.tokenizer(
91
- prompt,
92
- padding="max_length",
93
- max_length=self.tokenizer_max_length,
94
- truncation=True,
95
- return_overflowing_tokens=False,
96
- return_length=False,
97
- return_tensors="pt",
98
- )
99
-
100
- text_input_ids = text_inputs.input_ids
101
-
102
- prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
103
-
104
- # Use pooled output of CLIPTextModel
105
- prompt_embeds = prompt_embeds.pooler_output
106
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
107
-
108
- # duplicate text embeddings for each generation per prompt, using mps friendly method
109
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
110
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
111
-
112
- return prompt_embeds
113
-
114
- def encode_prompt_with_clip_t5(
115
- self,
116
- prompt: Union[str, List[str]],
117
- prompt_2: Union[str, List[str]],
118
- device: Optional[torch.device] = None,
119
- num_images_per_prompt: int = 1,
120
- prompt_embeds: Optional[torch.FloatTensor] = None,
121
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
122
- max_sequence_length: int = 512,
123
- lora_scale: Optional[float] = None,
124
- ):
125
- r"""
126
-
127
- Args:
128
- prompt (`str` or `List[str]`, *optional*):
129
- prompt to be encoded
130
- prompt_2 (`str` or `List[str]`, *optional*):
131
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
132
- used in all text-encoders
133
- device: (`torch.device`):
134
- torch device
135
- num_images_per_prompt (`int`):
136
- number of images that should be generated per prompt
137
- prompt_embeds (`torch.FloatTensor`, *optional*):
138
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
139
- provided, text embeddings will be generated from `prompt` input argument.
140
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
141
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
142
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
143
- lora_scale (`float`, *optional*):
144
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
145
- """
146
- device = device or self._execution_device
147
-
148
- # set lora scale so that monkey patched LoRA
149
- # function of text encoder can correctly access it
150
- if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
151
- self._lora_scale = lora_scale
152
-
153
- # dynamically adjust the LoRA scale
154
- if self.text_encoder is not None and USE_PEFT_BACKEND:
155
- scale_lora_layers(self.text_encoder, lora_scale)
156
- if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
157
- scale_lora_layers(self.text_encoder_2, lora_scale)
158
-
159
- prompt = [prompt] if isinstance(prompt, str) else prompt
160
-
161
- if prompt_embeds is None:
162
- prompt_2 = prompt_2 or prompt
163
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
164
-
165
- # We only use the pooled prompt output from the CLIPTextModel
166
- pooled_prompt_embeds = _get_clip_prompt_embeds(
167
- self=self,
168
- prompt=prompt,
169
- device=device,
170
- num_images_per_prompt=num_images_per_prompt,
171
- )
172
- if self.text_encoder_2 is not None:
173
- prompt_embeds = self._get_t5_prompt_embeds(
174
- prompt=prompt_2,
175
- num_images_per_prompt=num_images_per_prompt,
176
- max_sequence_length=max_sequence_length,
177
- device=device,
178
- )
179
-
180
- if self.text_encoder is not None:
181
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
182
- # Retrieve the original scale by scaling back the LoRA layers
183
- unscale_lora_layers(self.text_encoder, lora_scale)
184
-
185
- if self.text_encoder_2 is not None:
186
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
187
- # Retrieve the original scale by scaling back the LoRA layers
188
- unscale_lora_layers(self.text_encoder_2, lora_scale)
189
-
190
- dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
191
- if self.text_encoder_2 is not None:
192
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
193
- else:
194
- text_ids = None
195
-
196
- return prompt_embeds, pooled_prompt_embeds, text_ids
197
-
198
-
199
-
200
- def prepare_text_input(
201
- pipeline: FluxPipeline,
202
- prompts,
203
- max_sequence_length=512,
204
- ):
205
- # Turn off warnings (CLIP overflow)
206
- logger.setLevel(logging.ERROR)
207
- (
208
- t5_prompt_embeds,
209
- pooled_prompt_embeds,
210
- text_ids,
211
- ) = encode_prompt_with_clip_t5(
212
- self=pipeline,
213
- prompt=prompts,
214
- prompt_2=None,
215
- prompt_embeds=None,
216
- pooled_prompt_embeds=None,
217
- device=pipeline.device,
218
- num_images_per_prompt=1,
219
- max_sequence_length=max_sequence_length,
220
- lora_scale=None,
221
- )
222
- # Turn on warnings
223
- logger.setLevel(logging.WARNING)
224
- return t5_prompt_embeds, pooled_prompt_embeds, text_ids
225
-
226
- def prepare_t5_input(
227
- pipeline: FluxPipeline,
228
- prompts,
229
- max_sequence_length=512,
230
- ):
231
- # Turn off warnings (CLIP overflow)
232
- logger.setLevel(logging.ERROR)
233
- (
234
- t5_prompt_embeds,
235
- pooled_prompt_embeds,
236
- text_ids,
237
- ) = encode_prompt_with_clip_t5(
238
- self=pipeline,
239
- prompt=prompts,
240
- prompt_2=None,
241
- prompt_embeds=None,
242
- pooled_prompt_embeds=None,
243
- device=pipeline.device,
244
- num_images_per_prompt=1,
245
- max_sequence_length=max_sequence_length,
246
- lora_scale=None,
247
- )
248
- # Turn on warnings
249
- logger.setLevel(logging.WARNING)
250
- return t5_prompt_embeds, pooled_prompt_embeds, text_ids
251
-
252
- def tokenize_t5_prompt(pipe, input_prompt, max_length, **kargs):
253
- return pipe.tokenizer_2(
254
- input_prompt,
255
- padding="max_length",
256
- max_length=max_length,
257
- truncation=True,
258
- return_length=False,
259
- return_overflowing_tokens=False,
260
- return_tensors="pt",
261
- **kargs,
262
- )
263
-
264
- def clear_attn_maps(transformer):
265
- for i, block in enumerate(transformer.transformer_blocks):
266
- if hasattr(block.attn, "attn_maps"):
267
- del block.attn.attn_maps
268
- del block.attn.timestep
269
- for i, block in enumerate(transformer.single_transformer_blocks):
270
- if hasattr(block.attn, "cond2latents"):
271
- del block.attn.cond2latents
272
-
273
- def gather_attn_maps(transformer, clear=False):
274
- t2i_attn_maps = {}
275
- i2t_attn_maps = {}
276
- for i, block in enumerate(transformer.transformer_blocks):
277
- name = f"block_{i}"
278
- if hasattr(block.attn, "attn_maps"):
279
- attention_maps = block.attn.attn_maps
280
- timesteps = block.attn.timestep # (B,)
281
- for (timestep, (t2i_attn_map, i2t_attn_map)) in zip(timesteps, attention_maps):
282
- timestep = str(timestep.item())
283
-
284
- t2i_attn_maps[timestep] = t2i_attn_maps.get(timestep, dict())
285
- t2i_attn_maps[timestep][name] = t2i_attn_maps[timestep].get(name, [])
286
- t2i_attn_maps[timestep][name].append(t2i_attn_map.cpu())
287
-
288
- i2t_attn_maps[timestep] = i2t_attn_maps.get(timestep, dict())
289
- i2t_attn_maps[timestep][name] = i2t_attn_maps[timestep].get(name, [])
290
- i2t_attn_maps[timestep][name].append(i2t_attn_map.cpu())
291
-
292
- if clear:
293
- del block.attn.attn_maps
294
-
295
- for timestep in t2i_attn_maps:
296
- for name in t2i_attn_maps[timestep]:
297
- t2i_attn_maps[timestep][name] = torch.cat(t2i_attn_maps[timestep][name], dim=0)
298
- i2t_attn_maps[timestep][name] = torch.cat(i2t_attn_maps[timestep][name], dim=0)
299
-
300
- return t2i_attn_maps, i2t_attn_maps
301
-
302
- def process_token(token, startofword):
303
- if '</w>' in token:
304
- token = token.replace('</w>', '')
305
- if startofword:
306
- token = '<' + token + '>'
307
- else:
308
- token = '-' + token + '>'
309
- startofword = True
310
- elif token not in ['<|startoftext|>', '<|endoftext|>']:
311
- if startofword:
312
- token = '<' + token + '-'
313
- startofword = False
314
- else:
315
- token = '-' + token + '-'
316
- return token, startofword
317
-
318
- def save_attention_image(attn_map, tokens, batch_dir, to_pil):
319
- startofword = True
320
- for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
321
- token, startofword = process_token(token, startofword)
322
- token = token.replace("/", "-")
323
- if token == '-<pad>-':
324
- continue
325
- a = a.to(torch.float32)
326
- a = a / a.max() * 255 / 256
327
- to_pil(a).save(os.path.join(batch_dir, f'{i}-{token}.png'))
328
-
329
- def save_attention_maps(attn_maps, pipe, prompts, base_dir='attn_maps'):
330
- to_pil = ToPILImage()
331
-
332
- token_ids = tokenize_t5_prompt(pipe, prompts, 512).input_ids # (B, 512)
333
- token_ids = [x for x in token_ids]
334
- total_tokens = [pipe.tokenizer_2.convert_ids_to_tokens(token_id) for token_id in token_ids]
335
-
336
- os.makedirs(base_dir, exist_ok=True)
337
-
338
- total_attn_map_shape = (256, 256)
339
- total_attn_map_number = 0
340
-
341
- # (B, 24, H, W, 512) -> (B, H, W, 512) -> (B, 512, H, W)
342
- print(attn_maps.keys())
343
- total_attn_map = list(list(attn_maps.values())[0].values())[0].sum(1)
344
- total_attn_map = total_attn_map.permute(0, 3, 1, 2)
345
- total_attn_map = torch.zeros_like(total_attn_map)
346
- total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
347
-
348
- for timestep, layers in attn_maps.items():
349
- timestep_dir = os.path.join(base_dir, f'{timestep}')
350
- os.makedirs(timestep_dir, exist_ok=True)
351
-
352
- for layer, attn_map in layers.items():
353
- layer_dir = os.path.join(timestep_dir, f'{layer}')
354
- os.makedirs(layer_dir, exist_ok=True)
355
-
356
- attn_map = attn_map.sum(1).squeeze(1).permute(0, 3, 1, 2)
357
-
358
- resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
359
- total_attn_map += resized_attn_map
360
- total_attn_map_number += 1
361
-
362
- for batch, (attn_map, tokens) in enumerate(zip(resized_attn_map, total_tokens)):
363
- save_attention_image(attn_map, tokens, layer_dir, to_pil)
364
-
365
- # for batch, (tokens, attn) in enumerate(zip(total_tokens, attn_map)):
366
- # batch_dir = os.path.join(layer_dir, f'batch-{batch}')
367
- # os.makedirs(batch_dir, exist_ok=True)
368
- # save_attention_image(attn, tokens, batch_dir, to_pil)
369
-
370
- total_attn_map /= total_attn_map_number
371
- for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
372
- batch_dir = os.path.join(base_dir, f'batch-{batch}')
373
- os.makedirs(batch_dir, exist_ok=True)
374
- save_attention_image(attn_map, tokens, batch_dir, to_pil)
375
-
376
- def gather_cond2latents(transformer, clear=False):
377
- c2l_attn_maps = {}
378
- # for i, block in enumerate(transformer.transformer_blocks):
379
- for i, block in enumerate(transformer.single_transformer_blocks):
380
- name = f"block_{i}"
381
- if hasattr(block.attn, "cond2latents"):
382
- attention_maps = block.attn.cond2latents
383
- timesteps = block.attn.cond_timesteps # (B,)
384
- for (timestep, c2l_attn_map) in zip(timesteps, attention_maps):
385
- timestep = str(timestep.item())
386
-
387
- c2l_attn_maps[timestep] = c2l_attn_maps.get(timestep, dict())
388
- c2l_attn_maps[timestep][name] = c2l_attn_maps[timestep].get(name, [])
389
- c2l_attn_maps[timestep][name].append(c2l_attn_map.cpu())
390
-
391
- if clear:
392
- # del block.attn.attn_maps
393
- del block.attn.cond2latents
394
- del block.attn.cond_timesteps
395
-
396
- for timestep in c2l_attn_maps:
397
- for name in c2l_attn_maps[timestep]:
398
- c2l_attn_maps[timestep][name] = torch.cat(c2l_attn_maps[timestep][name], dim=0)
399
-
400
- return c2l_attn_maps
401
-
402
- def save_cond2latent_image(attn_map, batch_dir, to_pil):
403
- for i, a in enumerate(attn_map): # (N, H, W)
404
- a = a.to(torch.float32)
405
- a = a / a.max() * 255 / 256
406
- to_pil(a).save(os.path.join(batch_dir, f'{i}.png'))
407
-
408
- def save_cond2latent(attn_maps, base_dir='attn_maps'):
409
- to_pil = ToPILImage()
410
-
411
- os.makedirs(base_dir, exist_ok=True)
412
-
413
- total_attn_map_shape = (256, 256)
414
- total_attn_map_number = 0
415
-
416
- # (N, H, W) -> (1, N, H, W)
417
- total_attn_map = list(list(attn_maps.values())[0].values())[0].unsqueeze(0)
418
- total_attn_map = torch.zeros_like(total_attn_map)
419
- total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
420
-
421
- for timestep, layers in attn_maps.items():
422
- cur_ts_attn_map = torch.zeros_like(total_attn_map)
423
- cur_ts_attn_map_number = 0
424
-
425
- timestep_dir = os.path.join(base_dir, f'{timestep}')
426
- os.makedirs(timestep_dir, exist_ok=True)
427
-
428
- for layer, attn_map in layers.items():
429
- # layer_dir = os.path.join(timestep_dir, f'{layer}')
430
- # os.makedirs(layer_dir, exist_ok=True)
431
-
432
- attn_map = attn_map.unsqueeze(0) # (1, N, H, W)
433
- resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
434
-
435
- cur_ts_attn_map += resized_attn_map
436
- cur_ts_attn_map_number += 1
437
-
438
- for batch, attn_map in enumerate(cur_ts_attn_map / cur_ts_attn_map_number):
439
- save_cond2latent_image(attn_map, timestep_dir, to_pil)
440
-
441
- total_attn_map += cur_ts_attn_map
442
- total_attn_map_number += cur_ts_attn_map_number
443
-
444
- total_attn_map /= total_attn_map_number
445
- for batch, attn_map in enumerate(total_attn_map):
446
- batch_dir = os.path.join(base_dir, f'batch-{batch}')
447
- os.makedirs(batch_dir, exist_ok=True)
448
- save_cond2latent_image(attn_map, batch_dir, to_pil)
449
-
450
- def quantization(pipe, qtype):
451
- if qtype != "None" and qtype != "":
452
- if qtype.endswith("quanto"):
453
- if qtype == "int2-quanto":
454
- quant_level = qint2
455
- elif qtype == "int4-quanto":
456
- quant_level = qint4
457
- elif qtype == "int8-quanto":
458
- quant_level = qint8
459
- elif qtype == "fp8-quanto":
460
- quant_level = qfloat8
461
- else:
462
- raise ValueError(f"Invalid quantisation level: {qtype}")
463
-
464
- extra_quanto_args = {}
465
- extra_quanto_args["exclude"] = [
466
- "*.norm",
467
- "*.norm1",
468
- "*.norm2",
469
- "*.norm2_context",
470
- "proj_out",
471
- "x_embedder",
472
- "norm_out",
473
- "context_embedder",
474
- ]
475
- try:
476
- quantize(pipe.transformer, weights=quant_level, **extra_quanto_args)
477
- quantize(pipe.text_encoder_2, weights=quant_level, **extra_quanto_args)
478
- print("[Quantization] Start freezing")
479
- freeze(pipe.transformer)
480
- freeze(pipe.text_encoder_2)
481
- print("[Quantization] Finished")
482
- except Exception as e:
483
- if "out of memory" in str(e).lower():
484
- print(
485
- "GPU ran out of memory during quantisation. Use --quantize_via=cpu to use the slower CPU method."
486
- )
487
- raise e
488
- else:
489
- assert qtype == "fp8-ao"
490
- from torchao.float8 import convert_to_float8_training, Float8LinearConfig
491
- def module_filter_fn(mod: torch.nn.Module, fqn: str):
492
- # don't convert the output module
493
- if fqn == "proj_out":
494
- return False
495
- # don't convert linear modules with weight dimensions not divisible by 16
496
- if isinstance(mod, torch.nn.Linear):
497
- if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
498
- return False
499
- return True
500
- convert_to_float8_training(
501
- pipe.transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
502
- )
503
-
504
- class CustomFluxPipeline:
505
- def __init__(
506
- self,
507
- config,
508
- device="cuda",
509
- ckpt_root=None,
510
- ckpt_root_condition=None,
511
- torch_dtype=torch.bfloat16,
512
- ):
513
- model_path = os.getenv("FLUX_MODEL_PATH", "black-forest-labs/FLUX.1-dev")
514
- print("[CustomFluxPipeline] Loading FLUX Pipeline")
515
- self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
516
-
517
- self.config = config
518
- self.device = device
519
- self.dtype = torch_dtype
520
- if config["model"].get("dit_quant", "None") != "None":
521
- quantization(self.pipe, config["model"]["dit_quant"])
522
-
523
- self.modulation_adapters = []
524
- self.pipe.modulation_adapters = []
525
-
526
- try:
527
- if config["model"]["modulation"]["use_clip"]:
528
- load_clip(self, config, torch_dtype, device, None, is_training=False)
529
- except Exception as e:
530
- print(e)
531
-
532
- if config["model"]["use_dit_lora"] or config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]:
533
- if ckpt_root_condition is None and (config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]):
534
- ckpt_root_condition = ckpt_root
535
- load_dit_lora(self, self.pipe, config, torch_dtype, device, f"{ckpt_root}", f"{ckpt_root_condition}", is_training=False)
536
-
537
- def add_modulation_adapter(self, modulation_adapter):
538
- self.modulation_adapters.append(modulation_adapter)
539
- self.pipe.modulation_adapters.append(modulation_adapter)
540
-
541
- def clear_modulation_adapters(self):
542
- self.modulation_adapters = []
543
- self.pipe.modulation_adapters = []
544
- torch.cuda.empty_cache()
545
-
546
- def load_clip(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
547
- model_path = os.getenv("CLIP_MODEL_PATH", "openai/clip-vit-large-patch14")
548
- clip_model = CLIPVisionModelWithProjection.from_pretrained(model_path).to(device, dtype=torch_dtype)
549
- clip_processor = CLIPProcessor.from_pretrained(model_path)
550
- self.pipe.clip_model = clip_model
551
- self.pipe.clip_processor = clip_processor
552
-
553
- def load_dit_lora(self, pipe, config, torch_dtype, device, ckpt_dir=None, condition_ckpt_dir=None, is_training=False):
554
-
555
- if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"] and not config["model"]["use_dit_lora"]:
556
- print("[load_dit_lora] no dit lora, no condition lora")
557
- return []
558
-
559
- adapter_names = ["default", "condition"]
560
-
561
- if condition_ckpt_dir is None:
562
- condition_ckpt_dir = ckpt_dir
563
-
564
- if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"]:
565
- print("[load_dit_lora] no condition lora")
566
- adapter_names.pop(1)
567
- elif condition_ckpt_dir is not None and os.path.exists(os.path.join(condition_ckpt_dir, "pytorch_lora_weights_condition.safetensors")):
568
- assert "condition" in adapter_names
569
- print(f"[load_dit_lora] load condition lora from {condition_ckpt_dir}")
570
- pipe.transformer.load_lora_adapter(condition_ckpt_dir, use_safetensors=True, adapter_name="condition", weight_name="pytorch_lora_weights_condition.safetensors") # TODO: check if they are trainable
571
- else:
572
- assert is_training
573
- assert "condition" in adapter_names
574
- print("[load_dit_lora] init new condition lora")
575
- pipe.transformer.add_adapter(LoraConfig(**config["model"]["condition_lora_config"]), adapter_name="condition")
576
-
577
- if not config["model"]["use_dit_lora"]:
578
- print("[load_dit_lora] no dit lora")
579
- adapter_names.pop(0)
580
- elif ckpt_dir is not None and os.path.exists(os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")):
581
- assert "default" in adapter_names
582
- print(f"[load_dit_lora] load dit lora from {ckpt_dir}")
583
- lora_file = os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")
584
- lora_state_dict = safetensors.torch.load_file(lora_file, device="cpu")
585
-
586
- single_lora_pattern = "(.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
587
- latent_lora_pattern = "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2)"
588
- use_pretrained_dit_single_lora = config["model"].get("use_pretrained_dit_single_lora", True)
589
- use_pretrained_dit_latent_lora = config["model"].get("use_pretrained_dit_latent_lora", True)
590
- if not use_pretrained_dit_single_lora or not use_pretrained_dit_latent_lora:
591
- lora_state_dict_keys = list(lora_state_dict.keys())
592
- for layer_name in lora_state_dict_keys:
593
- if not use_pretrained_dit_single_lora:
594
- if re.search(single_lora_pattern, layer_name):
595
- del lora_state_dict[layer_name]
596
- if not use_pretrained_dit_latent_lora:
597
- if re.search(latent_lora_pattern, layer_name):
598
- del lora_state_dict[layer_name]
599
- pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")
600
- set_peft_model_state_dict(pipe.transformer, lora_state_dict, adapter_name="default")
601
- else:
602
- pipe.transformer.load_lora_adapter(ckpt_dir, use_safetensors=True, adapter_name="default", weight_name="pytorch_lora_weights.safetensors") # TODO: check if they are trainable
603
- else:
604
- assert is_training
605
- assert "default" in adapter_names
606
- print("[load_dit_lora] init new dit lora")
607
- pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")
608
-
609
- assert len(adapter_names) <= 2 and len(adapter_names) > 0
610
- for name, module in pipe.transformer.named_modules():
611
- if isinstance(module, BaseTunerLayer):
612
- module.set_adapter(adapter_names)
613
-
614
- if "default" in adapter_names: assert config["model"]["use_dit_lora"]
615
- if "condition" in adapter_names: assert config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]
616
-
617
- lora_layers = list(filter(
618
- lambda p: p[1].requires_grad, pipe.transformer.named_parameters()
619
- ))
620
-
621
- lora_layers = [l[1] for l in lora_layers]
622
- return lora_layers
623
-
624
- def load_modulation_adapter(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
625
- adapter_type = config["model"]["modulation"]["adapter_type"]
626
-
627
- if ckpt_dir is not None and os.path.exists(ckpt_dir):
628
- print(f"loading modulation adapter from {ckpt_dir}")
629
- modulation_adapter = CLIPModAdapter.from_pretrained(
630
- ckpt_dir, subfolder="modulation_adapter", strict=False,
631
- low_cpu_mem_usage=False, device_map=None,
632
- ).to(device)
633
- else:
634
- print(f"Init new modulation adapter")
635
- adapter_layers = config["model"]["modulation"]["adapter_layers"]
636
- adapter_width = config["model"]["modulation"]["adapter_width"]
637
- pblock_adapter_layers = config["model"]["modulation"]["per_block_adapter_layers"]
638
- pblock_adapter_width = config["model"]["modulation"]["per_block_adapter_width"]
639
- pblock_adapter_single_blocks = config["model"]["modulation"]["per_block_adapter_single_blocks"]
640
- use_text_mod = config["model"]["modulation"]["use_text_mod"]
641
- use_img_mod = config["model"]["modulation"]["use_img_mod"]
642
-
643
- out_dim = config["model"]["modulation"]["out_dim"]
644
- if adapter_type == "clip_adapter":
645
- modulation_adapter = CLIPModAdapter(
646
- out_dim=out_dim,
647
- width=adapter_width,
648
- pblock_width=pblock_adapter_width,
649
- layers=adapter_layers,
650
- pblock_layers=pblock_adapter_layers,
651
- heads=8,
652
- input_text_dim=4096,
653
- input_image_dim=1024,
654
- pblock_single_blocks=pblock_adapter_single_blocks,
655
- )
656
- else:
657
- raise NotImplementedError()
658
-
659
- if is_training:
660
- modulation_adapter.train()
661
- try:
662
- modulation_adapter.enable_gradient_checkpointing()
663
- except Exception as e:
664
- print(e)
665
- if not config["model"]["modulation"]["use_perblock_adapter"]:
666
- try:
667
- modulation_adapter.net2.requires_grad_(False)
668
- except Exception as e:
669
- print(e)
670
- else:
671
- modulation_adapter.requires_grad_(False)
672
-
673
- modulation_adapter.to(device, dtype=torch_dtype)
674
- return modulation_adapter
675
-
676
-
677
- def load_ckpt(self, ckpt_dir, is_training=False):
678
- if self.config["model"]["use_dit_lora"]:
679
- self.pipe.transformer.delete_adapters(["subject"])
680
- lora_path = f"{ckpt_dir}/pytorch_lora_weights.safetensors"
681
- print(f"Loading DIT Lora from {lora_path}")
682
- self.pipe.load_lora_weights(lora_path, adapter_name="subject")
683
-
684
-
685
-
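
For reference, a minimal usage sketch of the pipeline wrapper deleted above (its file header appears earlier in this diff). The config path and checkpoint directories are placeholders, and the sketch assumes `CustomFluxPipeline` and `load_modulation_adapter` are importable from that module; none of this is a documented API.

import torch
from src.utils.data_utils import get_train_config

# Hypothetical config/checkpoint locations.
config = get_train_config("configs/example.yaml")

wrapper = CustomFluxPipeline(
    config,
    device="cuda",
    ckpt_root="checkpoints/step-10000",            # DiT LoRA weights
    ckpt_root_condition="checkpoints/step-10000",  # condition LoRA weights
    torch_dtype=torch.bfloat16,
)

# Optionally restore (or initialize) a modulation adapter with the module-level
# helper defined above and attach it to the wrapper.
adapter = load_modulation_adapter(wrapper, config, torch.bfloat16, "cuda",
                                  ckpt_dir="checkpoints/step-10000")
wrapper.add_modulation_adapter(adapter)
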
src/flux/transformer.py DELETED
@@ -1,363 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- from diffusers.pipelines import FluxPipeline
18
- from typing import List, Union, Optional, Dict, Any, Callable
19
- from .block import block_forward, single_block_forward
20
- from .lora_controller import enable_lora
21
- from diffusers.models.transformers.transformer_flux import (
22
- FluxTransformer2DModel,
23
- Transformer2DModelOutput,
24
- USE_PEFT_BACKEND,
25
- scale_lora_layers,
26
- unscale_lora_layers,
27
- logger,
28
- )
29
- import numpy as np
30
-
31
- import numpy as np
32
- import torch
33
- import torch.nn as nn
34
- import torch.nn.functional as F
35
-
36
-
37
- def prepare_params(
38
- hidden_states: torch.Tensor,
39
- encoder_hidden_states: torch.Tensor = None,
40
- pooled_projections: torch.Tensor = None,
41
- timestep: torch.LongTensor = None,
42
- img_ids: torch.Tensor = None,
43
- txt_ids: torch.Tensor = None,
44
- guidance: torch.Tensor = None,
45
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
46
- controlnet_block_samples=None,
47
- controlnet_single_block_samples=None,
48
- return_dict: bool = True,
49
- **kwargs: dict,
50
- ):
51
- return (
52
- hidden_states,
53
- encoder_hidden_states,
54
- pooled_projections,
55
- timestep,
56
- img_ids,
57
- txt_ids,
58
- guidance,
59
- joint_attention_kwargs,
60
- controlnet_block_samples,
61
- controlnet_single_block_samples,
62
- return_dict,
63
- )
64
-
65
- # Call sites below use the two-argument form, e.g. is_torch_version(">=", "1.11.0"),
66
- # so reuse the helper that diffusers already ships instead of a local redefinition.
67
- from diffusers.utils import is_torch_version
68
-
69
- def tranformer_forward(
70
- transformer: FluxTransformer2DModel,
71
- condition_latents: torch.Tensor,
72
- condition_ids: torch.Tensor,
73
- condition_type_ids: torch.Tensor,
74
- model_config: Optional[Dict[str, Any]] = {},
75
- c_t=0,
76
- text_cond_mask: Optional[torch.FloatTensor] = None,
77
- delta_emb: Optional[torch.FloatTensor] = None,
78
- delta_emb_pblock: Optional[torch.FloatTensor] = None,
79
- delta_emb_mask: Optional[torch.FloatTensor] = None,
80
- delta_start_ends = None,
81
- store_attn_map: bool = False,
82
- use_text_mod: bool = True,
83
- use_img_mod: bool = False,
84
- mod_adapter = None,
85
- latent_height: Optional[int] = None,
86
- last_attn_map = None,
87
- **params: dict,
88
- ):
89
- self = transformer
90
- use_condition = condition_latents is not None
91
-
92
- (
93
- hidden_states,
94
- encoder_hidden_states,
95
- pooled_projections,
96
- timestep,
97
- img_ids,
98
- txt_ids,
99
- guidance,
100
- joint_attention_kwargs,
101
- controlnet_block_samples,
102
- controlnet_single_block_samples,
103
- return_dict,
104
- ) = prepare_params(**params)
105
-
106
- if joint_attention_kwargs is not None:
107
- joint_attention_kwargs = joint_attention_kwargs.copy()
108
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
109
- latent_sblora_weight = joint_attention_kwargs.pop("latent_sblora_weight", None)
110
- condition_sblora_weight = joint_attention_kwargs.pop("condition_sblora_weight", None)
111
- else:
112
- lora_scale = 1.0
113
- latent_sblora_weight = None
114
- condition_sblora_weight = None
115
- if USE_PEFT_BACKEND:
116
- # weight the lora layers by setting `lora_scale` for each PEFT layer
117
- scale_lora_layers(self, lora_scale)
118
- else:
119
- if (
120
- joint_attention_kwargs is not None
121
- and joint_attention_kwargs.get("scale", None) is not None
122
- ):
123
- logger.warning(
124
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
125
- )
126
-
127
- train_partial_text_lora = model_config.get("train_partial_text_lora", False)
128
- train_partial_latent_lora = model_config.get("train_partial_latent_lora", False)
129
-
130
- if train_partial_text_lora or train_partial_latent_lora:
131
- train_partial_text_lora_layers = model_config.get("train_partial_text_lora_layers", "")
132
- train_partial_latent_lora_layers = model_config.get("train_partial_latent_lora_layers", "")
133
- activate_x_embedder = True
134
- if "x_embedder" not in train_partial_text_lora_layers or "x_embedder" not in train_partial_latent_lora_layers:
135
- activate_x_embedder = False
136
- if train_partial_text_lora or train_partial_latent_lora:
137
- activate_x_embedder_ = activate_x_embedder
138
- else:
139
- activate_x_embedder_ = model_config["latent_lora"] or model_config["text_lora"]
140
-
141
- with enable_lora((self.x_embedder,), activate_x_embedder_):
142
- hidden_states = self.x_embedder(hidden_states)
143
- cond_lora_activate = model_config["use_condition_dblock_lora"] or model_config["use_condition_sblock_lora"]
144
- with enable_lora(
145
- (self.x_embedder,),
146
- dit_activated=activate_x_embedder if train_partial_text_lora or train_partial_latent_lora else not cond_lora_activate, cond_activated=cond_lora_activate,
147
- ):
148
- condition_latents = self.x_embedder(condition_latents) if use_condition else None
149
-
150
- timestep = timestep.to(hidden_states.dtype) * 1000
151
-
152
- if guidance is not None:
153
- guidance = guidance.to(hidden_states.dtype) * 1000
154
- else:
155
- guidance = None
156
-
157
- temb = (
158
- self.time_text_embed(timestep, pooled_projections)
159
- if guidance is None
160
- else self.time_text_embed(timestep, guidance, pooled_projections)
161
- ) # (B, 3072)
162
-
163
- cond_temb = (
164
- self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
165
- if guidance is None
166
- else self.time_text_embed(
167
- torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
168
- )
169
- )
170
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
171
-
172
- if txt_ids.ndim == 3:
173
- logger.warning(
174
- "Passing `txt_ids` 3d torch.Tensor is deprecated."
175
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
176
- )
177
- txt_ids = txt_ids[0]
178
- if img_ids.ndim == 3:
179
- logger.warning(
180
- "Passing `img_ids` 3d torch.Tensor is deprecated."
181
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
182
- )
183
- img_ids = img_ids[0]
184
-
185
- ids = torch.cat((txt_ids, img_ids), dim=0)
186
- image_rotary_emb = self.pos_embed(ids)
187
- if use_condition:
188
- cond_rotary_emb = self.pos_embed(condition_ids)
189
-
190
- for index_block, block in enumerate(self.transformer_blocks):
191
- if delta_emb_pblock is None:
192
- delta_emb_cblock = None
193
- else:
194
- delta_emb_cblock = delta_emb_pblock[:, :, index_block]
195
- condition_pass_to_double = use_condition and (model_config["double_use_condition"] or model_config["single_use_condition"])
196
- if self.training and self.gradient_checkpointing:
197
- ckpt_kwargs: Dict[str, Any] = (
198
- {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
199
- )
200
-
201
- encoder_hidden_states, hidden_states, condition_latents = (
202
- torch.utils.checkpoint.checkpoint(
203
- block_forward,
204
- self=block,
205
- model_config=model_config,
206
- hidden_states=hidden_states,
207
- encoder_hidden_states=encoder_hidden_states,
208
- condition_latents=condition_latents if condition_pass_to_double else None,
209
- cond_temb=cond_temb if condition_pass_to_double else None,
210
- cond_rotary_emb=cond_rotary_emb if condition_pass_to_double else None,
211
- temb=temb,
212
- text_cond_mask=text_cond_mask,
213
- delta_emb=delta_emb,
214
- delta_emb_cblock=delta_emb_cblock,
215
- delta_emb_mask=delta_emb_mask,
216
- delta_start_ends=delta_start_ends,
217
- image_rotary_emb=image_rotary_emb,
218
- store_attn_map=store_attn_map,
219
- use_text_mod=use_text_mod,
220
- use_img_mod=use_img_mod,
221
- mod_adapter=mod_adapter,
222
- latent_height=latent_height,
223
- timestep=timestep,
224
- last_attn_map=last_attn_map,
225
- **ckpt_kwargs,
226
- )
227
- )
228
-
229
- else:
230
- encoder_hidden_states, hidden_states, condition_latents = block_forward(
231
- block,
232
- model_config=model_config,
233
- hidden_states=hidden_states,
234
- encoder_hidden_states=encoder_hidden_states,
235
- condition_latents=condition_latents if condition_pass_to_double else None,
236
- cond_temb=cond_temb if condition_pass_to_double else None,
237
- cond_rotary_emb=cond_rotary_emb if condition_pass_to_double else None,
238
- temb=temb,
239
- text_cond_mask=text_cond_mask,
240
- delta_emb=delta_emb,
241
- delta_emb_cblock=delta_emb_cblock,
242
- delta_emb_mask=delta_emb_mask,
243
- delta_start_ends=delta_start_ends,
244
- image_rotary_emb=image_rotary_emb,
245
- store_attn_map=store_attn_map,
246
- use_text_mod=use_text_mod,
247
- use_img_mod=use_img_mod,
248
- mod_adapter=mod_adapter,
249
- latent_height=latent_height,
250
- timestep=timestep,
251
- last_attn_map=last_attn_map,
252
- )
253
-
254
- # controlnet residual
255
- if controlnet_block_samples is not None:
256
- interval_control = len(self.transformer_blocks) / len(
257
- controlnet_block_samples
258
- )
259
- interval_control = int(np.ceil(interval_control))
260
- hidden_states = (
261
- hidden_states
262
- + controlnet_block_samples[index_block // interval_control]
263
- )
264
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
265
-
266
- for index_block, block in enumerate(self.single_transformer_blocks):
267
- if delta_emb_pblock is not None and delta_emb_pblock.shape[2] > 19+index_block:  # the first 19 entries of the per-block axis cover the double-stream blocks
268
- delta_emb_single = delta_emb
269
- delta_emb_cblock = delta_emb_pblock[:, :, index_block+19]
270
- else:
271
- delta_emb_single = None
272
- delta_emb_cblock = None
273
- if self.training and self.gradient_checkpointing:
274
- ckpt_kwargs: Dict[str, Any] = (
275
- {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
276
- )
277
- result = torch.utils.checkpoint.checkpoint(
278
- single_block_forward,
279
- self=block,
280
- model_config=model_config,
281
- hidden_states=hidden_states,
282
- temb=temb,
283
- delta_emb=delta_emb_single,
284
- delta_emb_cblock=delta_emb_cblock,
285
- delta_emb_mask=delta_emb_mask,
286
- use_text_mod=use_text_mod,
287
- use_img_mod=use_img_mod,
288
- image_rotary_emb=image_rotary_emb,
289
- last_attn_map=last_attn_map,
290
- latent_height=latent_height,
291
- timestep=timestep,
292
- store_attn_map=store_attn_map,
293
- **(
294
- {
295
- "condition_latents": condition_latents,
296
- "cond_temb": cond_temb,
297
- "cond_rotary_emb": cond_rotary_emb,
298
- "text_cond_mask": text_cond_mask,
299
- }
300
- if use_condition and model_config["single_use_condition"]
301
- else {}
302
- ),
303
- **ckpt_kwargs,
304
- )
305
-
306
- else:
307
- result = single_block_forward(
308
- block,
309
- model_config=model_config,
310
- hidden_states=hidden_states,
311
- temb=temb,
312
- delta_emb=delta_emb_single,
313
- delta_emb_cblock=delta_emb_cblock,
314
- delta_emb_mask=delta_emb_mask,
315
- use_text_mod=use_text_mod,
316
- use_img_mod=use_img_mod,
317
- image_rotary_emb=image_rotary_emb,
318
- last_attn_map=last_attn_map,
319
- latent_height=latent_height,
320
- timestep=timestep,
321
- store_attn_map=store_attn_map,
322
- latent_sblora_weight=latent_sblora_weight,
323
- condition_sblora_weight=condition_sblora_weight,
324
- **(
325
- {
326
- "condition_latents": condition_latents,
327
- "cond_temb": cond_temb,
328
- "cond_rotary_emb": cond_rotary_emb,
329
- "text_cond_mask": text_cond_mask,
330
- }
331
- if use_condition and model_config["single_use_condition"]
332
- else {}
333
- ),
334
- )
335
- if use_condition and model_config["single_use_condition"]:
336
- hidden_states, condition_latents = result
337
- else:
338
- hidden_states = result
339
-
340
- # controlnet residual
341
- if controlnet_single_block_samples is not None:
342
- interval_control = len(self.single_transformer_blocks) / len(
343
- controlnet_single_block_samples
344
- )
345
- interval_control = int(np.ceil(interval_control))
346
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
347
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
348
- + controlnet_single_block_samples[index_block // interval_control]
349
- )
350
-
351
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
352
-
353
- hidden_states = self.norm_out(hidden_states, temb)
354
- output = self.proj_out(hidden_states)
355
-
356
- if USE_PEFT_BACKEND:
357
- # remove `lora_scale` from each PEFT layer
358
- unscale_lora_layers(self, lora_scale)
359
-
360
- if not return_dict:
361
- return (output,)
362
- return Transformer2DModelOutput(sample=output)
363
-
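
Below is a sketch of how `tranformer_forward` would replace the stock `FluxTransformer2DModel.forward` call inside a denoising loop. Tensor names, shapes and the guidance value are illustrative (packed latents as prepared by `FluxPipeline`), and `config["model"]` is assumed to carry the keys the function reads above (`double_use_condition`, `single_use_condition`, `latent_lora`, `text_lora` and the condition-LoRA flags).

import torch

# Placeholder tensors, shaped the way FluxPipeline packs them:
# latents (B, N_img, 64), prompt_embeds (B, 512, 4096), pooled (B, 768),
# latent_image_ids / text_ids / cond_ids are (N, 3) rotary position ids.
noise_pred = tranformer_forward(
    pipe.transformer,
    condition_latents=cond_latents,          # packed condition-image latents
    condition_ids=cond_ids,
    condition_type_ids=None,
    model_config=config["model"],
    c_t=0,                                   # condition branch runs at timestep 0 (clean image)
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    timestep=t.expand(latents.shape[0]) / 1000,
    img_ids=latent_image_ids,
    txt_ids=text_ids,
    guidance=torch.full((latents.shape[0],), 3.5, device=latents.device),
    joint_attention_kwargs=None,
    return_dict=False,
)[0]
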
src/utils/data_utils.py DELETED
@@ -1,404 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import cv2
16
- import json
17
- import torch
18
- import random
19
- import base64
20
- import numpy as np
21
- from PIL import Image, ImageDraw
22
- from glob import glob
23
- from torchvision import transforms as T
24
- import os
25
- import gc, io, re  # io: in-memory buffers for image encode/decode; re: word matching in replace_first_occurrence
26
- from webdataset.filters import default_collation_fn, pipelinefilter
27
- import yaml
28
-
29
- def get_rank_and_worldsize():
30
- try:
31
- local_rank = int(os.environ.get("LOCAL_RANK"))
32
- global_rank = int(os.environ.get("RANK"))
33
- world_size = int(os.getenv('WORLD_SIZE', 1))
34
- except (TypeError, ValueError):
35
- local_rank = 0
36
- global_rank = 0
37
- world_size = 1
38
- return local_rank, global_rank, world_size
39
-
40
- def get_train_config(config_path=None):
41
- if config_path is None:
42
- config_path = os.environ.get("XFL_CONFIG")
43
- assert config_path is not None, "Please set the XFL_CONFIG environment variable"
44
- with open(config_path, "r") as f:
45
- config = yaml.safe_load(f)
46
- return config
47
-
48
- def calculate_aspect_ratios(resolution):
49
- ASPECT_RATIO = {
50
- '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
51
- '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
52
- '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
53
- '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
54
- '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
55
- '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
56
- '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
57
- '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
58
- '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
59
- '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
60
- }
61
- NEW_ASPECT_RATIO = {}
62
- for ratio in ASPECT_RATIO:
63
- height, width = ASPECT_RATIO[ratio]
64
- width = round(width / 256 * resolution)
65
- height = round(height / 256 * resolution)
66
- if width % 8 != 0:
67
- print(f"skip train resolution {width}, {height}")
68
- continue
69
- if height % 8 != 0:
70
- print(f"skip train resolution {width}, {height}")
71
- continue
72
- NEW_ASPECT_RATIO[ratio] = [height, width]
73
- return NEW_ASPECT_RATIO
74
-
75
- ASPECT_RATIO_256 = calculate_aspect_ratios(256)
76
- ASPECT_RATIO_384 = calculate_aspect_ratios(384)
77
- ASPECT_RATIO_512 = calculate_aspect_ratios(512)
78
- ASPECT_RATIO_768 = calculate_aspect_ratios(768)
79
- ASPECT_RATIO_1024 = calculate_aspect_ratios(1024)
80
-
81
- def get_closest_ratio(height: float, width: float, ratios: dict):
82
- aspect_ratio = height / width
83
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
84
- return ratios[closest_ratio], closest_ratio
85
-
86
-
87
- def _aspect_ratio_batched(
88
- data,
89
- batchsize=20,
90
- aspect_ratios=ASPECT_RATIO_512,
91
- batch_cross=False,
92
- collation_fn=default_collation_fn,
93
- partial=True,
94
- ):
95
- """Create batches of the given size.
96
-
97
- :param data: iterator
98
- :param batchsize: target batch size
99
- :param aspect_ratios: dict mapping ratio keys to [height, width] buckets
100
- :param batch_cross: if True, batch samples flagged `has_cross` separately from the rest
101
- :param partial: return partial batches
102
- :returns: iterator over collated batches
103
- """
104
- assert collation_fn is not None
105
- buckets = {
106
- ratio: {"cross": [], "no_cross": []} for ratio in aspect_ratios.keys()
107
- }
108
-
109
- def check(buckets):
110
- for ratio in buckets:
111
- for bucket_name in buckets[ratio]:
112
- bucket = buckets[ratio][bucket_name]
113
- assert len(bucket) < batchsize
114
-
115
- for sample in data:
116
- check(buckets)
117
- height, width = sample['original_sizes']
118
- (new_height, new_width), closest_ratio = get_closest_ratio(height, width, aspect_ratios)
119
-
120
- bucket_name = "cross" if sample["has_cross"] and batch_cross else "no_cross"
121
- bucket = buckets[closest_ratio][bucket_name]
122
- bucket.append(sample)
123
-
124
- if len(bucket) >= batchsize:
125
- try:
126
- batch = collation_fn(bucket)
127
- yield batch
128
- del batch
129
- except Exception as e:
130
- print(f"[aspect_ratio_batched] collation_fn batch failed due to error {e}")
131
- for sample in bucket:
132
- if "__key__" in sample:
133
- print("error sample key in batch:", sample["__key__"])
134
- if "__url__" in sample:
135
- print("error sample url in batch:", sample["__url__"])
136
- buckets[closest_ratio][bucket_name] = []
137
- del bucket
138
- gc.collect()
139
-
140
- # yield the rest data and reset the buckets
141
- for ratio in buckets.keys():
142
- for bucket_name in ["cross", "no_cross"]:
143
- bucket = buckets[ratio][bucket_name]
144
- if len(bucket) > 0:
145
- if len(bucket) == batchsize or partial:
146
- batch = collation_fn(bucket)
147
- yield batch
148
- del batch
149
- buckets[ratio][bucket_name] = []
150
- del bucket
151
-
152
- aspect_ratio_batched = pipelinefilter(_aspect_ratio_batched)
153
-
154
- def apply_aspect_ratio_batched(dataset, batchsize, aspect_ratios, batch_cross, collation_fn, partial=True):
155
- return dataset.compose(
156
- aspect_ratio_batched(
157
- batchsize,
158
- aspect_ratios=aspect_ratios,
159
- batch_cross=batch_cross,
160
- collation_fn=collation_fn,
161
- partial=partial
162
- )
163
- )
164
-
165
- def get_aspect_ratios(enable_aspect_ratio, resolution):
166
- if enable_aspect_ratio:
167
- # print("[Dataset] Multi Aspect Ratio Training Enabled")
168
- if resolution == 256:
169
- aspect_ratios = ASPECT_RATIO_256
170
- elif resolution == 384:
171
- aspect_ratios = ASPECT_RATIO_384
172
- elif resolution == 512:
173
- aspect_ratios = ASPECT_RATIO_512
174
- elif resolution == 768:
175
- aspect_ratios = ASPECT_RATIO_768
176
- elif resolution == 1024:
177
- aspect_ratios = ASPECT_RATIO_1024
178
- else:
179
- aspect_ratios = calculate_aspect_ratios(resolution)
180
- else:
181
- # print("[Dataset] Multi Aspect Ratio Training Disabled")
182
- aspect_ratios = {
183
- '1.0': [resolution, resolution]
184
- }
185
- return aspect_ratios
186
-
187
- def bbox_to_grid(bbox, image_size, output_size=(224, 224)):
188
- """
189
- Convert bounding box to a grid of points.
190
- Args:
191
- bbox (list of float): [xmin, ymin, xmax, ymax]
192
- output_size (tuple of int): (height, width) of the output grid
193
-
194
- Returns:
195
- torch.Tensor: Grid of points with shape (output_height, output_width, 2)
196
- """
197
- xmin, ymin, xmax, ymax = bbox
198
-
199
- # Create a meshgrid for the output grid
200
- h, w = output_size
201
- yy, xx = torch.meshgrid(
202
- torch.linspace(ymin, ymax, h),
203
- torch.linspace(xmin, xmax, w)
204
- )
205
- grid = torch.stack((xx, yy), -1)
206
-
207
- # Normalize grid to range [-1, 1]
208
- H, W = image_size
209
- grid[..., 0] = grid[..., 0] / (W - 1) * 2 - 1 # Normalize x to [-1, 1]
210
- grid[..., 1] = grid[..., 1] / (H - 1) * 2 - 1 # Normalize y to [-1, 1]
211
-
212
- return grid
213
-
214
- def random_crop_instance(instance, min_crop_ratio):
215
- assert 0 < min_crop_ratio <= 1
216
- crop_width_ratio = random.uniform(min_crop_ratio, 1)
217
- crop_height_ratio = random.uniform(min_crop_ratio, 1)
218
-
219
- orig_width, orig_height = instance.size
220
-
221
- crop_width = int(orig_width * crop_width_ratio)
222
- crop_height = int(orig_height * crop_height_ratio)
223
-
224
- crop_left = random.randint(0, orig_width - crop_width)
225
- crop_top = random.randint(0, orig_height - crop_height)
226
-
227
- crop_box = (crop_left, crop_top, crop_left + crop_width, crop_top + crop_height) # (left, upper, right, lower)
228
- return instance.crop(crop_box), crop_box
229
-
230
- pil2tensor = T.ToTensor()
231
- tensor2pil = T.ToPILImage()
232
-
233
- cv2pil = lambda x: Image.fromarray(cv2.cvtColor(x, cv2.COLOR_BGR2RGB))
234
- pil2cv2 = lambda x: cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR)
235
-
236
- def compute_psnr(x, y):
237
- y = y.resize(x.size)
238
- x = pil2tensor(x) * 255.
239
- y = pil2tensor(y) * 255.
240
- mse = torch.mean((x - y) ** 2)
241
- return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
242
-
243
- def replace_first_occurrence(sentence, word_or_phrase, replace_with):
244
- # Escape special characters in word_or_phrase for exact matching
245
- escaped_word_or_phrase = re.escape(word_or_phrase)
246
- pattern = r'\b' + escaped_word_or_phrase + r'\b'
247
-
248
- # Finding the first match
249
- match = next(re.finditer(pattern, sentence), None)
250
- if match:
251
- # Perform replacement
252
- result = re.sub(pattern, replace_with, sentence, count=1)
253
- replaced = True
254
- index = match.start()
255
- else:
256
- # No match found
257
- result = sentence
258
- replaced = False
259
- index = -1
260
-
261
- return result, replaced, index
262
-
263
-
264
- def decode_base64_to_image(base64_str):
265
- # Decode the base64 string to bytes
266
- img_bytes = base64.b64decode(base64_str)
267
- # Create a BytesIO buffer from the bytes
268
- img_buffer = io.BytesIO(img_bytes)
269
- # Open the image using Pillow
270
- image = Image.open(img_buffer)
271
- return image
272
-
273
- def jpeg_compression(pil_image, quality):
274
- buffer = io.BytesIO()
275
- pil_image.save(buffer, format="JPEG", quality=quality)
276
- return Image.open(io.BytesIO(buffer.getvalue()))
277
-
278
- def pad_to_square(pil_image):
279
- new_size = max(pil_image.width, pil_image.height)
280
- square_image = Image.new("RGB", (new_size, new_size), "white")
281
- left = (new_size - pil_image.width) // 2
282
- top = (new_size - pil_image.height) // 2
283
- square_image.paste(pil_image, (left, top))
284
- return square_image
285
-
286
- def pad_to_target(pil_image, target_size):
287
- original_width, original_height = pil_image.size
288
- target_width, target_height = target_size
289
-
290
- original_aspect_ratio = original_width / original_height
291
- target_aspect_ratio = target_width / target_height
292
-
293
- # Pad the image to the target aspect ratio
294
- if original_aspect_ratio > target_aspect_ratio:
295
- new_width = original_width
296
- new_height = int(new_width / target_aspect_ratio)
297
- else:
298
- new_height = original_height
299
- new_width = int(new_height * target_aspect_ratio)
300
-
301
- pad_image = Image.new("RGB", (new_width, new_height), "white")
302
- left = (new_width - original_width) // 2
303
- top = (new_height - original_height) // 2
304
- pad_image.paste(pil_image, (left, top))
305
-
306
- # Resize the image to the target size
307
- resized_image = pad_image.resize(target_size)
308
- return resized_image
309
-
310
- def image_grid(imgs, rows, cols):
311
- # assert len(imgs) == rows * cols
312
-
313
- w, h = imgs[0].size
314
- if imgs[0].mode == 'L':
315
- grid = Image.new('L', size=(cols * w, rows * h))
316
- else:
317
- grid = Image.new('RGB', size=(cols * w, rows * h))
318
-
319
- for i, img in enumerate(imgs):
320
- grid.paste(img, box=(i % cols * w, i // cols * h))
321
- return grid
322
-
323
- def split_grid(image):
324
- width = image.width // 2
325
- height = image.height // 2
326
-
327
- crop_tuples_list = [
328
- (0, 0, width, height),
329
- (width, 0, width*2, height),
330
- (0, height, width, height*2),
331
- (width, height, width*2, height*2),
332
- ]
333
- def crop_image(input_image, crop_tuple=None):
334
- if crop_tuple is None:
335
- return input_image
336
- return input_image.crop((crop_tuple[0], crop_tuple[1], crop_tuple[2], crop_tuple[3]))
337
-
338
- return [crop_image(image, crop_tuple) for crop_tuple in crop_tuples_list]
339
-
340
- def add_border(img, border_color, border_thickness):
341
- """
342
- Add a colored border to an image without changing its size.
343
-
344
- Parameters:
345
- border_color (tuple): Border color in RGB (e.g., (255, 0, 0) for red).
346
- border_thickness (int): Thickness of the border in pixels.
347
- """
348
- width, height = img.size
349
- img = img.copy()
350
- draw = ImageDraw.Draw(img)
351
- draw.rectangle((0, 0, width, border_thickness), fill=border_color)
352
- draw.rectangle((0, height - border_thickness, width, height), fill=border_color)
353
- draw.rectangle((0, 0, border_thickness, height), fill=border_color)
354
- draw.rectangle((width - border_thickness, 0, width, height), fill=border_color)
355
- return img
356
-
357
- def merge_bboxes(bboxes):
358
- if not bboxes:
359
- return None # Handle empty input
360
-
361
- # Extract all coordinates
362
- x_mins = [b[0] for b in bboxes]
363
- y_mins = [b[1] for b in bboxes]
364
- x_maxs = [b[2] for b in bboxes]
365
- y_maxs = [b[3] for b in bboxes]
366
-
367
- # Compute the merged box
368
- merged_box = (
369
- min(x_mins), # x_min
370
- min(y_mins), # y_min
371
- max(x_maxs), # x_max
372
- max(y_maxs) # y_max
373
- )
374
- return merged_box
375
-
376
-
377
- def flip_bbox_left_right(bbox, image_width):
378
- """
379
- Flips the bounding box horizontally on an image.
380
-
381
- Parameters:
382
- bbox (list of float): [x_min, y_min, x_max, y_max]
383
- image_width (int): The width of the image
384
-
385
- Returns:
386
- list of float: New bounding box after horizontal flip [x_min', y_min', x_max', y_max']
387
- """
388
- x_min, y_min, x_max, y_max = bbox
389
- new_x_min = image_width - x_max
390
- new_x_max = image_width - x_min
391
- new_bbox = [new_x_min, y_min, new_x_max, y_max]
392
- return new_bbox
393
-
394
- def json_load(path, encoding='ascii'):
395
- with open(path, 'r', encoding=encoding) as file:
396
- return json.load(file)
397
-
398
- def json_dump(obj, path, encoding='ascii', indent=4, create_dir=True, verbose=True, **kwargs):
399
- if create_dir and os.path.dirname(path) != '':
400
- os.makedirs(os.path.dirname(path), exist_ok=True)
401
- with open(path, 'w', encoding=encoding) as file:
402
- json.dump(obj, file, indent=indent, ensure_ascii=False, **kwargs)
403
- if verbose:
404
- print(type(obj), 'saved to', path)
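
A short sketch of the aspect-ratio bucketing helpers above: pick the training bucket closest to an image's aspect ratio and pad/resize the image into it. The input path is a placeholder.

from PIL import Image

ratios = get_aspect_ratios(enable_aspect_ratio=True, resolution=512)
img = Image.open("sample.jpg")  # placeholder input
(bucket_h, bucket_w), key = get_closest_ratio(img.height, img.width, ratios)
# pad_to_target takes (width, height)
resized = pad_to_target(img, (bucket_w, bucket_h))
print(f"bucket {key}: {img.size} -> {resized.size}")
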
src/utils/modulation_utils.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Copyright (c) Facebook, Inc. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- from src.flux.pipeline_tools import tokenize_t5_prompt
18
-
19
- def unpad_input_ids(input_ids, attention_mask):
20
- return [input_ids[i][attention_mask[i].bool()][:-1] for i in range(input_ids.shape[0])]
21
-
22
- def get_word_index(pipe, prompt, input_ids, word, word_count=1, max_length=512, verbose=True, reverse=False):
23
- word_inputs = tokenize_t5_prompt(pipe, word, max_length)
24
- word_ids = unpad_input_ids(word_inputs.input_ids, word_inputs.attention_mask)[0]
25
- if word_ids[0] == 3:
26
- word_ids = word_ids[1:] # remove prefix space
27
-
28
- if verbose:
29
- print(f"Trying to find {word} {word_ids.tolist()} in {input_ids.tolist()} where")
30
- print([(i, pipe.tokenizer_2.decode(input_ids[i])) for i in range(input_ids.shape[0])])
31
-
32
- count = 0
33
- if reverse:
34
- for i in range(input_ids.shape[0] - word_ids.shape[0],-1,-1):
35
- if torch.equal(input_ids[i:i+word_ids.shape[0]], word_ids):
36
- count += 1
37
- if count == word_count:
38
- if verbose:
39
- reconstructed_word = pipe.tokenizer_2.decode(input_ids[i:i+word_ids.shape[0]])
40
- assert reconstructed_word == word
41
- print(f"[Reverse] Found index {i} to {i+word_ids.shape[0]} for '{word}' in prompt '{prompt}'")
42
- print("Reconstructed word", reconstructed_word)
43
- return i, i + word_ids.shape[0]
44
- else:
45
- for i in range(input_ids.shape[0] - word_ids.shape[0] + 1):
46
- if torch.equal(input_ids[i:i+word_ids.shape[0]], word_ids):
47
- count += 1
48
- if count == word_count:
49
- if verbose:
50
- reconstructed_word = pipe.tokenizer_2.decode(input_ids[i:i+word_ids.shape[0]])
51
- assert reconstructed_word == word
52
- print(f"Found index {i} to {i+word_ids.shape[0]} for '{word}' in prompt '{prompt}'")
53
- print("Reconstructed word", reconstructed_word)
54
- return i, i + word_ids.shape[0]
55
- print(f"[Error] Could not find '{word}' in prompt '{prompt}' with word_count {word_count}")