yichenchenchen committed
Commit a0250fc · verified · 1 Parent(s): 40888ec

Upload 7 files

ai.png ADDED
unipicv2/configuration_connector.py ADDED
@@ -0,0 +1,27 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class ConnectorConfig(PretrainedConfig):
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
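
`ConnectorConfig` only declares the connector's transformer hyper-parameters and inherits all serialization behaviour from `PretrainedConfig`. A minimal usage sketch, not part of the upload, assuming the `unipicv2/` folder is importable as a package and using an illustrative output directory:

```py
from unipicv2.configuration_connector import ConnectorConfig  # import path assumes unipicv2/ is a package

# Defaults: 12 layers, 12 heads, hidden size 768, GELU-tanh activation
config = ConnectorConfig()

# Serialization comes from PretrainedConfig
config.save_pretrained("./connector_config")  # writes config.json (directory name is illustrative)
reloaded = ConnectorConfig.from_pretrained("./connector_config")
assert reloaded.num_hidden_layers == 12
```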
unipicv2/modeling_connector.py ADDED
@@ -0,0 +1,485 @@
+ import math
+ import warnings
+ from typing import Any, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+
+ from transformers.activations import ACT2FN
+ from transformers.utils import (
+     is_flash_attn_2_available,
+     is_flash_attn_greater_or_equal_2_10,
+     logging,
+ )
+ from .configuration_connector import ConnectorConfig
+
+
+ if is_flash_attn_2_available():
+     from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def init_weights(module):
+     """Initialize the weights"""
+     if isinstance(module, nn.Embedding):
+         default_flax_embed_init(module.weight)
+     elif isinstance(module, ConnectorAttention):
+         nn.init.xavier_uniform_(module.q_proj.weight)
+         nn.init.xavier_uniform_(module.k_proj.weight)
+         nn.init.xavier_uniform_(module.v_proj.weight)
+         nn.init.xavier_uniform_(module.out_proj.weight)
+         nn.init.zeros_(module.q_proj.bias)
+         nn.init.zeros_(module.k_proj.bias)
+         nn.init.zeros_(module.v_proj.bias)
+         nn.init.zeros_(module.out_proj.bias)
+     elif isinstance(module, ConnectorMLP):
+         nn.init.xavier_uniform_(module.fc1.weight)
+         nn.init.xavier_uniform_(module.fc2.weight)
+         nn.init.normal_(module.fc1.bias, std=1e-6)
+         nn.init.normal_(module.fc2.bias, std=1e-6)
+     elif isinstance(module, (nn.Linear, nn.Conv2d)):
+         lecun_normal_(module.weight)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.LayerNorm):
+         module.bias.data.zero_()
+         module.weight.data.fill_(1.0)
+
+
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.0))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+
+
+ def trunc_normal_tf_(
+     tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+ ) -> torch.Tensor:
+     """Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \\leq \text{mean} \\leq b`.
+
+     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+     and the result is subsequently scaled and shifted by the mean and std args.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         _trunc_normal_(tensor, 0, 1.0, a, b)
+         tensor.mul_(std).add_(mean)
+
+
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         with torch.no_grad():
+             tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         with torch.no_grad():
+             tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ def default_flax_embed_init(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+ class ConnectorAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """Input shape: Batch x Time x Channel"""
+
+         batch_size, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         k_v_seq_len = key_states.shape[-2]
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+         if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+                 )
+             attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, attn_weights
+
+
+ class ConnectorFlashAttention2(ConnectorAttention):
+     """
+     ConnectorAttention flash attention module. This module inherits from `ConnectorAttention` as the weights of the module stays
+     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+     flash attention and deal with padding tokens in case the input contains any of them.
+     """
+
+     is_causal = False
+
+     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+     # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.LongTensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         output_attentions = False
+
+         batch_size, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         # Flash attention requires the input to have the shape
+         # batch_size x seq_length x head_dim x hidden_dim
+         # therefore we just need to keep the original shape
+         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+         # to be able to avoid many of these transpose/reshape/view.
+         query_states = query_states.transpose(1, 2)
+         key_states = key_states.transpose(1, 2)
+         value_states = value_states.transpose(1, 2)
+
+         dropout_rate = self.dropout if self.training else 0.0
+
+         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+         # therefore the input hidden states gets silently casted in float32. Hence, we need
+         # cast them back in the correct dtype just to be sure everything works as expected.
+         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+         # in fp32.
+
+         input_dtype = query_states.dtype
+         if input_dtype == torch.float32:
+             if torch.is_autocast_enabled():
+                 target_dtype = torch.get_autocast_gpu_dtype()
+             # Handle the case where the model is quantized
+             elif hasattr(self.config, "_pre_quantization_dtype"):
+                 target_dtype = self.config._pre_quantization_dtype
+             else:
+                 target_dtype = self.q_proj.weight.dtype
+
+             logger.warning_once(
+                 f"The input hidden states seems to be silently casted in float32, this might be related to"
+                 f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                 f" {target_dtype}."
+             )
+
+             query_states = query_states.to(target_dtype)
+             key_states = key_states.to(target_dtype)
+             value_states = value_states.to(target_dtype)
+
+         attn_output = _flash_attention_forward(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             q_len,
+             dropout=dropout_rate,
+             is_causal=self.is_causal,
+             use_top_left_mask=self._flash_attn_uses_top_left_mask,
+         )
+
+         attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
+         attn_output = self.out_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights
+
+
+ class ConnectorSdpaAttention(ConnectorAttention):
+     """
+     Connector attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+     `ConnectorAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+     SDPA API.
+     """
+
+     is_causal = False
+
+     # Adapted from ConnectorAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         if output_attentions:
+             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+             logger.warning_once(
+                 "ConnectorModel is using ConnectorSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+             )
+             return super().forward(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 output_attentions=output_attentions,
+             )
+
+         batch_size, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+         # Reference: https://github.com/pytorch/pytorch/issues/112577.
+         if query_states.device.type == "cuda" and attention_mask is not None:
+             query_states = query_states.contiguous()
+             key_states = key_states.contiguous()
+             value_states = value_states.contiguous()
+
+         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+         is_causal = True if self.is_causal and q_len > 1 else False
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+             dropout_p=self.dropout if self.training else 0.0,
+             is_causal=is_causal,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, None
+
+
+ CONNECTOR_ATTENTION_CLASSES = {
+     "eager": ConnectorAttention,
+     "flash_attention_2": ConnectorFlashAttention2,
+     "sdpa": ConnectorSdpaAttention,
+ }
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Connector
+ class ConnectorMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ class ConnectorEncoderLayer(nn.Module):
+     def __init__(self, config: ConnectorConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation](config=config)
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = ConnectorMLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     # Ignore copy
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`):
+                 Input to the layer of shape `(batch, seq_len, embed_dim)`.
+             attention_mask (`torch.FloatTensor`):
+                 Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+
+         hidden_states = self.layer_norm1(hidden_states)
+         hidden_states, attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+
+ # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Connector
+ class ConnectorEncoder(nn.Module):
+     def __init__(self, config: ConnectorConfig):
+         super().__init__()
+         self.config = config
+         self.layers = nn.ModuleList([ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+         self.apply(init_weights)
+
+     def forward(self, inputs_embeds):
+         hidden_states = inputs_embeds
+         for encoder_layer in self.layers:
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     encoder_layer.__call__,
+                     hidden_states,
+                     None,
+                     False,
+                     use_reentrant=False
+                 )
+             else:
+                 layer_outputs = encoder_layer(
+                     hidden_states,
+                     None,
+                     output_attentions=False,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+         return hidden_states
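
`ConnectorConfig` and `ConnectorEncoder` together form a small pre-norm transformer stack that maps a sequence of embeddings to a sequence of the same shape (it has no embedding layer or pooling of its own). A minimal smoke test, assuming the `unipicv2/` folder is importable as a package; the batch and sequence sizes are arbitrary:

```py
import torch

from unipicv2.configuration_connector import ConnectorConfig  # import path is an assumption
from unipicv2.modeling_connector import ConnectorEncoder

config = ConnectorConfig()          # hidden_size=768, 12 layers, 12 heads
encoder = ConnectorEncoder(config)  # attention backend picked via config._attn_implementation ("eager" by default in recent Transformers)

inputs_embeds = torch.randn(2, 64, config.hidden_size)  # (batch, seq_len, hidden_size)
with torch.no_grad():
    hidden_states = encoder(inputs_embeds)

print(hidden_states.shape)  # torch.Size([2, 64, 768]); the sequence shape is preserved
```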
unipicv2/pipeline_stable_diffusion_3_kontext.py ADDED
@@ -0,0 +1,1142 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import (
7
+ CLIPTextModelWithProjection,
8
+ CLIPTokenizer,
9
+ SiglipImageProcessor,
10
+ SiglipVisionModel,
11
+ T5EncoderModel,
12
+ T5TokenizerFast,
13
+ )
14
+
15
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
16
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
17
+ from diffusers.models.autoencoders import AutoencoderKL
18
+ from .transformer_sd3_kontext import SD3Transformer2DKontextModel
19
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
20
+ from diffusers.utils import (
21
+ USE_PEFT_BACKEND,
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ scale_lora_layers,
26
+ unscale_lora_layers,
27
+ )
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
31
+
32
+
33
+ if is_torch_xla_available():
34
+ import torch_xla.core.xla_model as xm
35
+
36
+ XLA_AVAILABLE = True
37
+ else:
38
+ XLA_AVAILABLE = False
39
+
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import torch
47
+ >>> from diffusers import StableDiffusion3Pipeline
48
+
49
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
50
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
51
+ ... )
52
+ >>> pipe.to("cuda")
53
+ >>> prompt = "A cat holding a sign that says hello world"
54
+ >>> image = pipe(prompt).images[0]
55
+ >>> image.save("sd3.png")
56
+ ```
57
+ """
58
+
59
+
60
+ def pil_list_to_tensor(images):
61
+ """
62
+ Args:
63
+ images: list/tuple of PIL.Image with same H, W
64
+ Returns:
65
+ torch.Tensor: (B, C, H, W) in [-1, 1]
66
+ """
67
+ # Step 1: Convert each PIL to tensor in [0, 1]
68
+ to_tensor = transforms.ToTensor() # PIL -> float tensor in [0, 1]
69
+ tensors = [to_tensor(img) for img in images] # list of (C, H, W)
70
+
71
+ # Step 2: Stack into (B, C, H, W)
72
+ batch = torch.stack(tensors, dim=0) # (B, C, H, W)
73
+
74
+ # Step 3: Scale [0, 1] -> [-1, 1]
75
+ batch = batch * 2.0 - 1.0
76
+ return batch
77
+
78
+
79
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
80
+ def calculate_shift(
81
+ image_seq_len,
82
+ base_seq_len: int = 256,
83
+ max_seq_len: int = 4096,
84
+ base_shift: float = 0.5,
85
+ max_shift: float = 1.15,
86
+ ):
87
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
88
+ b = base_shift - m * base_seq_len
89
+ mu = image_seq_len * m + b
90
+ return mu
91
+
92
+
93
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
94
+ def retrieve_timesteps(
95
+ scheduler,
96
+ num_inference_steps: Optional[int] = None,
97
+ device: Optional[Union[str, torch.device]] = None,
98
+ timesteps: Optional[List[int]] = None,
99
+ sigmas: Optional[List[float]] = None,
100
+ **kwargs,
101
+ ):
102
+ r"""
103
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
104
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
105
+
106
+ Args:
107
+ scheduler (`SchedulerMixin`):
108
+ The scheduler to get timesteps from.
109
+ num_inference_steps (`int`):
110
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
111
+ must be `None`.
112
+ device (`str` or `torch.device`, *optional*):
113
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
114
+ timesteps (`List[int]`, *optional*):
115
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
116
+ `num_inference_steps` and `sigmas` must be `None`.
117
+ sigmas (`List[float]`, *optional*):
118
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
119
+ `num_inference_steps` and `timesteps` must be `None`.
120
+
121
+ Returns:
122
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
123
+ second element is the number of inference steps.
124
+ """
125
+ if timesteps is not None and sigmas is not None:
126
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
127
+ if timesteps is not None:
128
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accepts_timesteps:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" timestep schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ elif sigmas is not None:
138
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
139
+ if not accept_sigmas:
140
+ raise ValueError(
141
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
142
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
143
+ )
144
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ num_inference_steps = len(timesteps)
147
+ else:
148
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ return timesteps, num_inference_steps
151
+
152
+
153
+ class StableDiffusion3KontextPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
154
+ r"""
155
+ Args:
156
+ transformer ([`SD3Transformer2DModel`]):
157
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
158
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
159
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
160
+ vae ([`AutoencoderKL`]):
161
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
162
+ text_encoder ([`CLIPTextModelWithProjection`]):
163
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
164
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
165
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
166
+ as its dimension.
167
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
168
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
169
+ specifically the
170
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
171
+ variant.
172
+ text_encoder_3 ([`T5EncoderModel`]):
173
+ Frozen text-encoder. Stable Diffusion 3 uses
174
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
175
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
176
+ tokenizer (`CLIPTokenizer`):
177
+ Tokenizer of class
178
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
179
+ tokenizer_2 (`CLIPTokenizer`):
180
+ Second Tokenizer of class
181
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
182
+ tokenizer_3 (`T5TokenizerFast`):
183
+ Tokenizer of class
184
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
185
+ image_encoder (`SiglipVisionModel`, *optional*):
186
+ Pre-trained Vision Model for IP Adapter.
187
+ feature_extractor (`SiglipImageProcessor`, *optional*):
188
+ Image processor for IP Adapter.
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
192
+ _optional_components = ["image_encoder", "feature_extractor"]
193
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
194
+
195
+ def __init__(
196
+ self,
197
+ transformer: SD3Transformer2DKontextModel,
198
+ scheduler: FlowMatchEulerDiscreteScheduler,
199
+ vae: AutoencoderKL,
200
+ text_encoder: CLIPTextModelWithProjection,
201
+ tokenizer: CLIPTokenizer,
202
+ text_encoder_2: CLIPTextModelWithProjection,
203
+ tokenizer_2: CLIPTokenizer,
204
+ text_encoder_3: T5EncoderModel,
205
+ tokenizer_3: T5TokenizerFast,
206
+ image_encoder: SiglipVisionModel = None,
207
+ feature_extractor: SiglipImageProcessor = None,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ vae=vae,
213
+ text_encoder=text_encoder,
214
+ text_encoder_2=text_encoder_2,
215
+ text_encoder_3=text_encoder_3,
216
+ tokenizer=tokenizer,
217
+ tokenizer_2=tokenizer_2,
218
+ tokenizer_3=tokenizer_3,
219
+ transformer=transformer,
220
+ scheduler=scheduler,
221
+ image_encoder=image_encoder,
222
+ feature_extractor=feature_extractor,
223
+ )
224
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
225
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
226
+ self.tokenizer_max_length = (
227
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
228
+ )
229
+ self.default_sample_size = (
230
+ self.transformer.config.sample_size
231
+ if hasattr(self, "transformer") and self.transformer is not None
232
+ else 128
233
+ )
234
+ self.patch_size = (
235
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
236
+ )
237
+
238
+ def _get_t5_prompt_embeds(
239
+ self,
240
+ prompt: Union[str, List[str]] = None,
241
+ num_images_per_prompt: int = 1,
242
+ max_sequence_length: int = 256,
243
+ device: Optional[torch.device] = None,
244
+ dtype: Optional[torch.dtype] = None,
245
+ ):
246
+ device = device or self._execution_device
247
+ dtype = dtype or self.text_encoder.dtype
248
+
249
+ prompt = [prompt] if isinstance(prompt, str) else prompt
250
+ batch_size = len(prompt)
251
+
252
+ if self.text_encoder_3 is None:
253
+ return torch.zeros(
254
+ (
255
+ batch_size * num_images_per_prompt,
256
+ self.tokenizer_max_length,
257
+ self.transformer.config.joint_attention_dim,
258
+ ),
259
+ device=device,
260
+ dtype=dtype,
261
+ )
262
+
263
+ text_inputs = self.tokenizer_3(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=max_sequence_length,
267
+ truncation=True,
268
+ add_special_tokens=True,
269
+ return_tensors="pt",
270
+ )
271
+ text_input_ids = text_inputs.input_ids
272
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
273
+
274
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
275
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
276
+ logger.warning(
277
+ "The following part of your input was truncated because `max_sequence_length` is set to "
278
+ f" {max_sequence_length} tokens: {removed_text}"
279
+ )
280
+
281
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
282
+
283
+ dtype = self.text_encoder_3.dtype
284
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
285
+
286
+ _, seq_len, _ = prompt_embeds.shape
287
+
288
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
289
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
290
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
291
+
292
+ return prompt_embeds
293
+
294
+ def _get_clip_prompt_embeds(
295
+ self,
296
+ prompt: Union[str, List[str]],
297
+ num_images_per_prompt: int = 1,
298
+ device: Optional[torch.device] = None,
299
+ clip_skip: Optional[int] = None,
300
+ clip_model_index: int = 0,
301
+ ):
302
+ device = device or self._execution_device
303
+
304
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
305
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
306
+
307
+ tokenizer = clip_tokenizers[clip_model_index]
308
+ text_encoder = clip_text_encoders[clip_model_index]
309
+
310
+ prompt = [prompt] if isinstance(prompt, str) else prompt
311
+ batch_size = len(prompt)
312
+
313
+ text_inputs = tokenizer(
314
+ prompt,
315
+ padding="max_length",
316
+ max_length=self.tokenizer_max_length,
317
+ truncation=True,
318
+ return_tensors="pt",
319
+ )
320
+
321
+ text_input_ids = text_inputs.input_ids
322
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
323
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
324
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
325
+ logger.warning(
326
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
327
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
328
+ )
329
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
330
+ pooled_prompt_embeds = prompt_embeds[0]
331
+
332
+ if clip_skip is None:
333
+ prompt_embeds = prompt_embeds.hidden_states[-2]
334
+ else:
335
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
336
+
337
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
338
+
339
+ _, seq_len, _ = prompt_embeds.shape
340
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
341
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
342
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
343
+
344
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
345
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
346
+
347
+ return prompt_embeds, pooled_prompt_embeds
348
+
349
+ def encode_prompt(
350
+ self,
351
+ prompt: Union[str, List[str]],
352
+ prompt_2: Union[str, List[str]],
353
+ prompt_3: Union[str, List[str]],
354
+ device: Optional[torch.device] = None,
355
+ num_images_per_prompt: int = 1,
356
+ do_classifier_free_guidance: bool = True,
357
+ negative_prompt: Optional[Union[str, List[str]]] = None,
358
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
359
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
360
+ prompt_embeds: Optional[torch.FloatTensor] = None,
361
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
362
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
363
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
364
+ clip_skip: Optional[int] = None,
365
+ max_sequence_length: int = 256,
366
+ lora_scale: Optional[float] = None,
367
+ ):
368
+ r"""
369
+
370
+ Args:
371
+ prompt (`str` or `List[str]`, *optional*):
372
+ prompt to be encoded
373
+ prompt_2 (`str` or `List[str]`, *optional*):
374
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
375
+ used in all text-encoders
376
+ prompt_3 (`str` or `List[str]`, *optional*):
377
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
378
+ used in all text-encoders
379
+ device: (`torch.device`):
380
+ torch device
381
+ num_images_per_prompt (`int`):
382
+ number of images that should be generated per prompt
383
+ do_classifier_free_guidance (`bool`):
384
+ whether to use classifier free guidance or not
385
+ negative_prompt (`str` or `List[str]`, *optional*):
386
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
387
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
388
+ less than `1`).
389
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
390
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
391
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
392
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
393
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
394
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
395
+ prompt_embeds (`torch.FloatTensor`, *optional*):
396
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
397
+ provided, text embeddings will be generated from `prompt` input argument.
398
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
399
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
400
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
401
+ argument.
402
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
403
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
404
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
405
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
406
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
407
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
408
+ input argument.
409
+ clip_skip (`int`, *optional*):
410
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
411
+ the output of the pre-final layer will be used for computing the prompt embeddings.
412
+ lora_scale (`float`, *optional*):
413
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
414
+ """
415
+ device = device or self._execution_device
416
+
417
+ # set lora scale so that monkey patched LoRA
418
+ # function of text encoder can correctly access it
419
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
420
+ self._lora_scale = lora_scale
421
+
422
+ # dynamically adjust the LoRA scale
423
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
424
+ scale_lora_layers(self.text_encoder, lora_scale)
425
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
426
+ scale_lora_layers(self.text_encoder_2, lora_scale)
427
+
428
+ prompt = [prompt] if isinstance(prompt, str) else prompt
429
+ if prompt is not None:
430
+ batch_size = len(prompt)
431
+ else:
432
+ batch_size = prompt_embeds.shape[0]
433
+
434
+ if prompt_embeds is None:
435
+ prompt_2 = prompt_2 or prompt
436
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
437
+
438
+ prompt_3 = prompt_3 or prompt
439
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
440
+
441
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
442
+ prompt=prompt,
443
+ device=device,
444
+ num_images_per_prompt=num_images_per_prompt,
445
+ clip_skip=clip_skip,
446
+ clip_model_index=0,
447
+ )
448
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
449
+ prompt=prompt_2,
450
+ device=device,
451
+ num_images_per_prompt=num_images_per_prompt,
452
+ clip_skip=clip_skip,
453
+ clip_model_index=1,
454
+ )
455
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
456
+
457
+ t5_prompt_embed = self._get_t5_prompt_embeds(
458
+ prompt=prompt_3,
459
+ num_images_per_prompt=num_images_per_prompt,
460
+ max_sequence_length=max_sequence_length,
461
+ device=device,
462
+ )
463
+
464
+ clip_prompt_embeds = torch.nn.functional.pad(
465
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
466
+ )
467
+
468
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
469
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
470
+
471
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
472
+ negative_prompt = negative_prompt or ""
473
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
474
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
475
+
476
+ # normalize str to list
477
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
478
+ negative_prompt_2 = (
479
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
480
+ )
481
+ negative_prompt_3 = (
482
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
483
+ )
484
+
485
+ if prompt is not None and type(prompt) is not type(negative_prompt):
486
+ raise TypeError(
487
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
488
+ f" {type(prompt)}."
489
+ )
490
+ elif batch_size != len(negative_prompt):
491
+ raise ValueError(
492
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
493
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
494
+ " the batch size of `prompt`."
495
+ )
496
+
497
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
498
+ negative_prompt,
499
+ device=device,
500
+ num_images_per_prompt=num_images_per_prompt,
501
+ clip_skip=None,
502
+ clip_model_index=0,
503
+ )
504
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
505
+ negative_prompt_2,
506
+ device=device,
507
+ num_images_per_prompt=num_images_per_prompt,
508
+ clip_skip=None,
509
+ clip_model_index=1,
510
+ )
511
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
512
+
513
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
514
+ prompt=negative_prompt_3,
515
+ num_images_per_prompt=num_images_per_prompt,
516
+ max_sequence_length=max_sequence_length,
517
+ device=device,
518
+ )
519
+
520
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
521
+ negative_clip_prompt_embeds,
522
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
523
+ )
524
+
525
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
526
+ negative_pooled_prompt_embeds = torch.cat(
527
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
528
+ )
529
+
530
+ if self.text_encoder is not None:
531
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
532
+ # Retrieve the original scale by scaling back the LoRA layers
533
+ unscale_lora_layers(self.text_encoder, lora_scale)
534
+
535
+ if self.text_encoder_2 is not None:
536
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
537
+ # Retrieve the original scale by scaling back the LoRA layers
538
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
539
+
540
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
541
+
542
+ def check_inputs(
543
+ self,
544
+ prompt,
545
+ prompt_2,
546
+ prompt_3,
547
+ height,
548
+ width,
549
+ negative_prompt=None,
550
+ negative_prompt_2=None,
551
+ negative_prompt_3=None,
552
+ prompt_embeds=None,
553
+ negative_prompt_embeds=None,
554
+ pooled_prompt_embeds=None,
555
+ negative_pooled_prompt_embeds=None,
556
+ callback_on_step_end_tensor_inputs=None,
557
+ max_sequence_length=None,
558
+ ):
559
+ if (
560
+ height % (self.vae_scale_factor * self.patch_size) != 0
561
+ or width % (self.vae_scale_factor * self.patch_size) != 0
562
+ ):
563
+ raise ValueError(
564
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
565
+ f" You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
566
+ )
567
+
568
+ if callback_on_step_end_tensor_inputs is not None and not all(
569
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
570
+ ):
571
+ raise ValueError(
572
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
573
+ )
574
+
575
+ if prompt is not None and prompt_embeds is not None:
576
+ raise ValueError(
577
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
578
+ " only forward one of the two."
579
+ )
580
+ elif prompt_2 is not None and prompt_embeds is not None:
581
+ raise ValueError(
582
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
583
+ " only forward one of the two."
584
+ )
585
+ elif prompt_3 is not None and prompt_embeds is not None:
586
+ raise ValueError(
587
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
588
+ " only forward one of the two."
589
+ )
590
+ elif prompt is None and prompt_embeds is None:
591
+ raise ValueError(
592
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
593
+ )
594
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
595
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
596
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
597
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
598
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
599
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
600
+
601
+ if negative_prompt is not None and negative_prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
604
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
605
+ )
606
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
609
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
610
+ )
611
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
612
+ raise ValueError(
613
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
614
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
615
+ )
616
+
617
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
618
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
619
+ raise ValueError(
620
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
621
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
622
+ f" {negative_prompt_embeds.shape}."
623
+ )
624
+
625
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
626
+ raise ValueError(
627
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
628
+ )
629
+
630
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
631
+ raise ValueError(
632
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
633
+ )
634
+
635
+ if max_sequence_length is not None and max_sequence_length > 512:
636
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
637
+
638
+ def prepare_latents(
639
+ self,
640
+ batch_size,
641
+ num_channels_latents,
642
+ height,
643
+ width,
644
+ dtype,
645
+ device,
646
+ generator,
647
+ latents=None,
648
+ ):
649
+ if latents is not None:
650
+ return latents.to(device=device, dtype=dtype)
651
+
652
+ shape = (
653
+ batch_size,
654
+ num_channels_latents,
655
+ int(height) // self.vae_scale_factor,
656
+ int(width) // self.vae_scale_factor,
657
+ )
658
+
659
+ if isinstance(generator, list) and len(generator) != batch_size:
660
+ raise ValueError(
661
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
662
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
663
+ )
664
+
665
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
666
+
667
+ return latents
668
+
669
+ @property
670
+ def guidance_scale(self):
671
+ return self._guidance_scale
672
+
673
+ @property
674
+ def skip_guidance_layers(self):
675
+ return self._skip_guidance_layers
676
+
677
+ @property
678
+ def clip_skip(self):
679
+ return self._clip_skip
680
+
681
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
682
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
683
+ # corresponds to doing no classifier free guidance.
684
+ @property
685
+ def do_classifier_free_guidance(self):
686
+ return self._guidance_scale > 1
687
+
688
+ @property
689
+ def joint_attention_kwargs(self):
690
+ return self._joint_attention_kwargs
691
+
692
+ @property
693
+ def num_timesteps(self):
694
+ return self._num_timesteps
695
+
696
+ @property
697
+ def interrupt(self):
698
+ return self._interrupt
699
+
700
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
701
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
702
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
703
+
704
+ Args:
705
+ image (`PipelineImageInput`):
706
+ Input image to be encoded.
707
+ device: (`torch.device`):
708
+ Torch device.
709
+
710
+ Returns:
711
+ `torch.Tensor`: The encoded image feature representation.
712
+ """
713
+ if not isinstance(image, torch.Tensor):
714
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
715
+
716
+ image = image.to(device=device, dtype=self.dtype)
717
+
718
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
719
+
720
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
721
+ def prepare_ip_adapter_image_embeds(
722
+ self,
723
+ ip_adapter_image: Optional[PipelineImageInput] = None,
724
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
725
+ device: Optional[torch.device] = None,
726
+ num_images_per_prompt: int = 1,
727
+ do_classifier_free_guidance: bool = True,
728
+ ) -> torch.Tensor:
729
+ """Prepares image embeddings for use in the IP-Adapter.
730
+
731
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
732
+
733
+ Args:
734
+ ip_adapter_image (`PipelineImageInput`, *optional*):
735
+ The input image to extract features from for IP-Adapter.
736
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
737
+ Precomputed image embeddings.
738
+ device: (`torch.device`, *optional*):
739
+ Torch device.
740
+ num_images_per_prompt (`int`, defaults to 1):
741
+ Number of images that should be generated per prompt.
742
+ do_classifier_free_guidance (`bool`, defaults to True):
743
+ Whether to use classifier free guidance or not.
744
+ """
745
+ device = device or self._execution_device
746
+
747
+ if ip_adapter_image_embeds is not None:
748
+ if do_classifier_free_guidance:
749
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
750
+ else:
751
+ single_image_embeds = ip_adapter_image_embeds
752
+ elif ip_adapter_image is not None:
753
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
754
+ if do_classifier_free_guidance:
755
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
756
+ else:
757
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
758
+
759
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
760
+
761
+ if do_classifier_free_guidance:
762
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
763
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
764
+
765
+ return image_embeds.to(device=device)
766
+
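A small sketch of the batching convention above: under classifier-free guidance the zeroed negative embeddings are concatenated before the positive ones, so a later `.chunk(2)` recovers them in that order. The token count and feature dimension below are illustrative assumptions, not values read from this image encoder.

```python
import torch

num_images_per_prompt = 2
single_image_embeds = torch.randn(1, 729, 1152)  # assumed (batch, tokens, dim)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# Negative first, positive second — the order expected by chunk(2) downstream.
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
print(image_embeds.shape)  # torch.Size([4, 729, 1152])
negative_half, positive_half = image_embeds.chunk(2)
```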
767
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
768
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
769
+ logger.warning(
770
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
771
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
772
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
773
+ )
774
+
775
+ super().enable_sequential_cpu_offload(*args, **kwargs)
776
+
777
+ @torch.no_grad()
778
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
779
+ def __call__(
780
+ self,
781
+ prompt: Union[str, List[str]] = None,
782
+ prompt_2: Optional[Union[str, List[str]]] = None,
783
+ prompt_3: Optional[Union[str, List[str]]] = None,
784
+ height: Optional[int] = 512,
785
+ width: Optional[int] = 512,
786
+ num_inference_steps: int = 50,
787
+ sigmas: Optional[List[float]] = None,
788
+ guidance_scale: float = 3.5,
789
+ negative_prompt: Optional[Union[str, List[str]]] = None,
790
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
791
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
792
+ num_images_per_prompt: Optional[int] = 1,
793
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
794
+ latents: Optional[torch.FloatTensor] = None,
795
+ image: Optional[PipelineImageInput] = None,
796
+ prompt_embeds: Optional[torch.FloatTensor] = None,
797
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
798
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
799
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
800
+ output_type: Optional[str] = "pil",
801
+ return_dict: bool = True,
802
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
803
+ clip_skip: Optional[int] = None,
804
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
805
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
806
+ max_sequence_length: int = 256,
807
+ skip_guidance_layers: List[int] = None,
808
+ skip_layer_guidance_scale: float = 2.8,
809
+ skip_layer_guidance_stop: float = 0.2,
810
+ skip_layer_guidance_start: float = 0.01,
811
+ mu: Optional[float] = None,
812
+ ):
813
+ r"""
814
+ Function invoked when calling the pipeline for generation.
815
+
816
+ Args:
817
+ prompt (`str` or `List[str]`, *optional*):
818
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
819
+ instead.
820
+ prompt_2 (`str` or `List[str]`, *optional*):
821
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
822
+ will be used instead.
823
+ prompt_3 (`str` or `List[str]`, *optional*):
824
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
825
+ will be used instead.
826
+ height (`int`, *optional*, defaults to 512):
827
+ The height in pixels of the generated image.
828
+ width (`int`, *optional*, defaults to 512):
829
+ The width in pixels of the generated image.
830
+ num_inference_steps (`int`, *optional*, defaults to 50):
831
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
832
+ expense of slower inference.
833
+ sigmas (`List[float]`, *optional*):
834
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
835
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
836
+ will be used.
837
+ guidance_scale (`float`, *optional*, defaults to 3.5):
838
+ Guidance scale as defined in [Classifier-Free Diffusion
839
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
840
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
841
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
842
+ the text `prompt`, usually at the expense of lower image quality.
843
+ negative_prompt (`str` or `List[str]`, *optional*):
844
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
845
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
846
+ less than `1`).
847
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
848
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
849
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
850
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
851
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
852
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
853
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
854
+ The number of images to generate per prompt.
855
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
856
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
857
+ to make generation deterministic.
858
+ latents (`torch.FloatTensor`, *optional*):
859
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
860
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
861
+ tensor will be generated by sampling using the supplied random `generator`.
862
+ prompt_embeds (`torch.FloatTensor`, *optional*):
863
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
864
+ provided, text embeddings will be generated from `prompt` input argument.
865
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
866
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
867
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
868
+ argument.
869
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
870
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
871
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
872
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
873
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
874
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
875
+ input argument.
876
+ output_type (`str`, *optional*, defaults to `"pil"`):
877
+ The output format of the generated image. Choose between
878
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
879
+ return_dict (`bool`, *optional*, defaults to `True`):
880
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
881
+ a plain tuple.
882
+ joint_attention_kwargs (`dict`, *optional*):
883
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
884
+ `self.processor` in
885
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
886
+ callback_on_step_end (`Callable`, *optional*):
887
+ A function that is called at the end of each denoising step during inference. The function is called
888
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
889
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
890
+ `callback_on_step_end_tensor_inputs`.
891
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
892
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
893
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
894
+ `._callback_tensor_inputs` attribute of your pipeline class.
895
+ max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.
896
+ skip_guidance_layers (`List[int]`, *optional*):
897
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
898
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
899
+ Recommended value by Stability AI for Stable Diffusion 3.5 Medium is [7, 8, 9].
900
+ skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the guidance for the layers specified in
901
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
902
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
903
+ with a scale of `1`.
904
+ skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The step at which the guidance for the layers specified in
905
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
906
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
907
+ Stability AI for Stable Diffusion 3.5 Medium is 0.2.
908
+ skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The step at which the guidance for the layers specified in
909
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
910
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
911
+ Stability AI for Stable Diffusion 3.5 Medium is 0.01.
912
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
913
+
914
+ Examples:
915
+
916
+ Returns:
917
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
918
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
919
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
920
+ """
921
+
922
+ height = height or self.default_sample_size * self.vae_scale_factor
923
+ width = width or self.default_sample_size * self.vae_scale_factor
924
+
925
+ # 1. Check inputs. Raise error if not correct
926
+ self.check_inputs(
927
+ prompt,
928
+ prompt_2,
929
+ prompt_3,
930
+ height,
931
+ width,
932
+ negative_prompt=negative_prompt,
933
+ negative_prompt_2=negative_prompt_2,
934
+ negative_prompt_3=negative_prompt_3,
935
+ prompt_embeds=prompt_embeds,
936
+ negative_prompt_embeds=negative_prompt_embeds,
937
+ pooled_prompt_embeds=pooled_prompt_embeds,
938
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
939
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
940
+ max_sequence_length=max_sequence_length,
941
+ )
942
+
943
+ self._guidance_scale = guidance_scale
944
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._skip_guidance_layers = skip_guidance_layers
945
+ self._clip_skip = clip_skip
946
+ self._joint_attention_kwargs = joint_attention_kwargs
947
+ self._interrupt = False
948
+
949
+ # 2. Define call parameters
950
+ if prompt is not None and isinstance(prompt, str):
951
+ batch_size = 1
952
+ elif prompt is not None and isinstance(prompt, list):
953
+ batch_size = len(prompt)
954
+ else:
955
+ batch_size = prompt_embeds.shape[0]
956
+
957
+ device = self._execution_device
958
+
959
+ lora_scale = (
960
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
961
+ )
962
+ (
963
+ prompt_embeds,
964
+ negative_prompt_embeds,
965
+ pooled_prompt_embeds,
966
+ negative_pooled_prompt_embeds,
967
+ ) = self.encode_prompt(
968
+ prompt=prompt,
969
+ prompt_2=prompt_2,
970
+ prompt_3=prompt_3,
971
+ negative_prompt=negative_prompt,
972
+ negative_prompt_2=negative_prompt_2,
973
+ negative_prompt_3=negative_prompt_3,
974
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
975
+ prompt_embeds=prompt_embeds,
976
+ negative_prompt_embeds=negative_prompt_embeds,
977
+ pooled_prompt_embeds=pooled_prompt_embeds,
978
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
979
+ device=device,
980
+ clip_skip=self.clip_skip,
981
+ num_images_per_prompt=num_images_per_prompt,
982
+ max_sequence_length=max_sequence_length,
983
+ lora_scale=lora_scale,
984
+ )
985
+
986
+ if self.do_classifier_free_guidance:
987
+ if skip_guidance_layers is not None:
988
+ original_prompt_embeds = prompt_embeds
989
+ original_pooled_prompt_embeds = pooled_prompt_embeds
990
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
991
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
992
+
993
+ # 4. Prepare latent variables
994
+ num_channels_latents = self.transformer.config.in_channels
995
+ latents = self.prepare_latents(
996
+ batch_size * num_images_per_prompt,
997
+ num_channels_latents,
998
+ height,
999
+ width,
1000
+ prompt_embeds.dtype,
1001
+ device,
1002
+ generator,
1003
+ latents,
1004
+ )
1005
+
1006
+ # 5. Prepare timesteps
1007
+ scheduler_kwargs = {}
1008
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1009
+ _, _, height, width = latents.shape
1010
+ image_seq_len = (height // self.transformer.config.patch_size) * (
1011
+ width // self.transformer.config.patch_size
1012
+ )
1013
+ mu = calculate_shift(
1014
+ image_seq_len,
1015
+ self.scheduler.config.get("base_image_seq_len", 256),
1016
+ self.scheduler.config.get("max_image_seq_len", 4096),
1017
+ self.scheduler.config.get("base_shift", 0.5),
1018
+ self.scheduler.config.get("max_shift", 1.16),
1019
+ )
1020
+ scheduler_kwargs["mu"] = mu
1021
+ elif mu is not None:
1022
+ scheduler_kwargs["mu"] = mu
1023
+ timesteps, num_inference_steps = retrieve_timesteps(
1024
+ self.scheduler,
1025
+ num_inference_steps,
1026
+ device,
1027
+ sigmas=sigmas,
1028
+ **scheduler_kwargs,
1029
+ )
1030
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1031
+ self._num_timesteps = len(timesteps)
1032
+
1033
+ # 6. Prepare image embeddings
1034
+ if image is not None:
1035
+ if not isinstance(image, (list, tuple)):
1036
+ image = (image,)
1037
+ assert image[0].height == height and image[0].width == width, "reference image size must match the requested height/width"
1038
+ image = pil_list_to_tensor(image).to(device=self.transformer.device,
1039
+ dtype=self.transformer.dtype)
1040
+
1041
+ image_latents = self.vae.encode(image).latent_dist.sample()
1042
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
1043
+
1044
+ image_latents = image_latents[:, None].expand(-1, num_images_per_prompt, -1, -1, -1)
1045
+ image_latents = image_latents.flatten(0, 1)
1046
+ else:
1047
+ image_latents = None
1048
+
1049
+ # 7. Denoising loop
1050
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1051
+ for i, t in enumerate(timesteps):
1052
+ if self.interrupt:
1053
+ continue
1054
+
1055
+ # expand the latents if we are doing classifier free guidance
1056
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1057
+ if image_latents is not None:
1058
+ ref_latent_model_input = torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
1059
+ else:
1060
+ ref_latent_model_input = None
1061
+
1062
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1063
+ timestep = t.expand(latent_model_input.shape[0])
1064
+
1065
+ noise_pred = self.transformer(
1066
+ hidden_states=latent_model_input,
1067
+ ref_hidden_states=ref_latent_model_input,
1068
+ timestep=timestep,
1069
+ encoder_hidden_states=prompt_embeds,
1070
+ pooled_projections=pooled_prompt_embeds,
1071
+ joint_attention_kwargs=self.joint_attention_kwargs,
1072
+ return_dict=False,
1073
+ )[0]
1074
+
1075
+ # perform guidance
1076
+ if self.do_classifier_free_guidance:
1077
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1078
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1079
+ should_skip_layers = (
1080
+ True
1081
+ if i > num_inference_steps * skip_layer_guidance_start
1082
+ and i < num_inference_steps * skip_layer_guidance_stop
1083
+ else False
1084
+ )
1085
+ if skip_guidance_layers is not None and should_skip_layers:
1086
+ timestep = t.expand(latents.shape[0])
1087
+ latent_model_input = latents
1088
+ noise_pred_skip_layers = self.transformer(
1089
+ hidden_states=latent_model_input,
1090
+ timestep=timestep,
1091
+ encoder_hidden_states=original_prompt_embeds,
1092
+ pooled_projections=original_pooled_prompt_embeds,
1093
+ joint_attention_kwargs=self.joint_attention_kwargs,
1094
+ return_dict=False,
1095
+ skip_layers=skip_guidance_layers,
1096
+ )[0]
1097
+ noise_pred = (
1098
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
1099
+ )
1100
+
1101
+ # compute the previous noisy sample x_t -> x_t-1
1102
+ latents_dtype = latents.dtype
1103
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1104
+
1105
+ if latents.dtype != latents_dtype:
1106
+ if torch.backends.mps.is_available():
1107
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1108
+ latents = latents.to(latents_dtype)
1109
+
1110
+ if callback_on_step_end is not None:
1111
+ callback_kwargs = {}
1112
+ for k in callback_on_step_end_tensor_inputs:
1113
+ callback_kwargs[k] = locals()[k]
1114
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1115
+
1116
+ latents = callback_outputs.pop("latents", latents)
1117
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1118
+ pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
1119
+
1120
+ # call the callback, if provided
1121
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1122
+ progress_bar.update()
1123
+
1124
+ if XLA_AVAILABLE:
1125
+ xm.mark_step()
1126
+
1127
+ if output_type == "latent":
1128
+ image = latents
1129
+
1130
+ else:
1131
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1132
+
1133
+ image = self.vae.decode(latents, return_dict=False)[0]
1134
+ image = self.image_processor.postprocess(image, output_type=output_type)
1135
+
1136
+ # Offload all models
1137
+ self.maybe_free_model_hooks()
1138
+
1139
+ if not return_dict:
1140
+ return (image,)
1141
+
1142
+ return StableDiffusion3PipelineOutput(images=image)
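End-to-end, the pipeline is driven like a standard Stable Diffusion 3 pipeline, with the extra `image` argument supplying the reference picture that becomes `ref_hidden_states`. The sketch below is illustrative only: the pipeline class name, its import path, and the assumption that the base SD 3.5 Medium checkpoint plus this repo's transformer/conditioner weights are already assembled are all hypothetical.

```python
import torch
from PIL import Image

# Hypothetical class name and module path — adjust to the pipeline defined in this repo.
from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline

pipe = StableDiffusion3KontextPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

# The reference image must match the requested height/width (see the assert in step 6).
ref = Image.open("reference.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="make the background a snowy mountain range",
    image=ref,                # encoded to latents and passed as ref_hidden_states
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=3.5,
)
result.images[0].save("edited.png")
```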
unipicv2/stable_diffusion_3_conditioner.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ # from transformers.modeling_utils import PreTrainedModel
4
+ from diffusers.configuration_utils import register_to_config, ConfigMixin
5
+ from unipicv2.modeling_connector import ConnectorEncoder
6
+ from unipicv2.configuration_connector import ConnectorConfig
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+
10
+ class StableDiffusion3Conditioner(ModelMixin, ConfigMixin):
11
+ model_type: str = "sd3_conditioner" # stored into config for hub niceties
12
+
13
+ @register_to_config
14
+ def __init__(
15
+ self,
16
+ connector_config: dict, # dict passed to ConnectorConfig(**connector_config)
17
+ num_queries: int = 256,
18
+ llm_hidden_size: int = 3584,
19
+ pooled_projection_dim: int = 2048,
20
+ joint_attention_dim: int = 4096,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
25
+ self.projector_1 = nn.Linear(llm_hidden_size, self.connector.config.hidden_size)
26
+ self.projector_2 = nn.Linear(self.connector.config.hidden_size, pooled_projection_dim)
27
+ self.projector_3 = nn.Linear(self.connector.config.hidden_size, joint_attention_dim)
28
+ self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))
29
+
30
+ def _init_weights(self, module):
31
+ pass
32
+
33
+ def forward(self, x: torch.Tensor):
34
+ """
35
+ x: (batch, seq_len, llm_hidden_size)
36
+ Returns:
37
+ prompt_embeds: (batch, seq_len, joint_attention_dim)
38
+ pooled_prompt_embeds: (batch, pooled_projection_dim)
39
+ """
40
+ x = self.projector_1(x)
41
+ x = self.connector(x) # expects (B, L, hidden)
42
+ pooled_prompt_embeds = self.projector_2(x.mean(1))
43
+ prompt_embeds = self.projector_3(x)
44
+
45
+ return prompt_embeds, pooled_prompt_embeds
46
+
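A shape-only sketch of the conditioner's data flow, using the projection sizes from `__init__` above; the batch size, sequence length, and the small two-layer connector are arbitrary choices for illustration, and the connector is assumed to run with its default (non flash-attention) attention implementation.

```python
import torch
from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner

cond = StableDiffusion3Conditioner(
    connector_config=dict(hidden_size=1536, intermediate_size=8960,
                          num_hidden_layers=2, num_attention_heads=24),
    num_queries=256,
    llm_hidden_size=3584,
    pooled_projection_dim=2048,
    joint_attention_dim=4096,
)

x = torch.randn(2, 256, 3584)             # e.g. LLM hidden states at the meta-query positions
prompt_embeds, pooled = cond(x)
print(prompt_embeds.shape, pooled.shape)  # torch.Size([2, 256, 4096]) torch.Size([2, 2048])
```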
47
+
48
+
49
+ if __name__ == "__main__":
50
+ import torch
51
+ import argparse
52
+ import os
53
+
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument("--checkpoint", type=str, default=None)
56
+ parser.add_argument("--output", type=str, default=None)
57
+
58
+ args = parser.parse_args()
59
+
60
+ pretrained_model_name_or_path = "stabilityai/stable-diffusion-3.5-medium"
61
+
62
+ conditioner = StableDiffusion3Conditioner(
63
+ num_queries=256,
64
+ connector_config=dict(
65
+ hidden_size=1536,
66
+ intermediate_size=8960,
67
+ num_hidden_layers=24,
68
+ _attn_implementation='flash_attention_2',
69
+ num_attention_heads=24, ),
70
+ llm_hidden_size=3584,
71
+ pooled_projection_dim=2048,
72
+ joint_attention_dim=4096,
73
+ ).bfloat16()
74
+
75
+ checkpoint = torch.load(args.checkpoint)
76
+
77
+ info = conditioner.load_state_dict(checkpoint, strict=False)
78
+ print(f"load_state_dict (strict=False) result: {info}")  # report missing/unexpected keys instead of dropping into a debugger
79
+
80
+ os.makedirs(args.output, exist_ok=True)
81
+
82
+ conditioner.save_pretrained(args.output)
unipicv2/transformer_sd3_kontext.py ADDED
@@ -0,0 +1,455 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
8
+ from diffusers.models.attention import FeedForward, JointTransformerBlock
9
+ from diffusers.models.attention_processor import (
10
+ Attention,
11
+ AttentionProcessor,
12
+ FusedJointAttnProcessor2_0,
13
+ JointAttnProcessor2_0,
14
+ )
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
17
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
18
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
19
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
20
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
21
+
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ @maybe_allow_in_graph
27
+ class SD3SingleTransformerBlock(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_attention_heads: int,
32
+ attention_head_dim: int,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.norm1 = AdaLayerNormZero(dim)
37
+ self.attn = Attention(
38
+ query_dim=dim,
39
+ dim_head=attention_head_dim,
40
+ heads=num_attention_heads,
41
+ out_dim=dim,
42
+ bias=True,
43
+ processor=JointAttnProcessor2_0(),
44
+ eps=1e-6,
45
+ )
46
+
47
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
48
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
49
+
50
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
51
+ # 1. Attention
52
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
53
+ attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
54
+ attn_output = gate_msa.unsqueeze(1) * attn_output
55
+ hidden_states = hidden_states + attn_output
56
+
57
+ # 2. Feed Forward
58
+ norm_hidden_states = self.norm2(hidden_states)
59
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
60
+ ff_output = self.ff(norm_hidden_states)
61
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
62
+ hidden_states = hidden_states + ff_output
63
+
64
+ return hidden_states
65
+
66
+
67
+ class SD3Transformer2DKontextModel(
68
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
69
+ ):
70
+ """
71
+ The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
72
+
73
+ Parameters:
74
+ sample_size (`int`, defaults to `128`):
75
+ The width/height of the latents. This is fixed during training since it is used to learn a number of
76
+ position embeddings.
77
+ patch_size (`int`, defaults to `2`):
78
+ Patch size to turn the input data into small patches.
79
+ in_channels (`int`, defaults to `16`):
80
+ The number of latent channels in the input.
81
+ num_layers (`int`, defaults to `18`):
82
+ The number of layers of transformer blocks to use.
83
+ attention_head_dim (`int`, defaults to `64`):
84
+ The number of channels in each head.
85
+ num_attention_heads (`int`, defaults to `18`):
86
+ The number of heads to use for multi-head attention.
87
+ joint_attention_dim (`int`, defaults to `4096`):
88
+ The embedding dimension to use for joint text-image attention.
89
+ caption_projection_dim (`int`, defaults to `1152`):
90
+ The embedding dimension of caption embeddings.
91
+ pooled_projection_dim (`int`, defaults to `2048`):
92
+ The embedding dimension of pooled text projections.
93
+ out_channels (`int`, defaults to `16`):
94
+ The number of latent channels in the output.
95
+ pos_embed_max_size (`int`, defaults to `96`):
96
+ The maximum latent height/width of positional embeddings.
97
+ dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
98
+ The number of dual-stream transformer blocks to use.
99
+ qk_norm (`str`, *optional*, defaults to `None`):
100
+ The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
101
+ """
102
+
103
+ _supports_gradient_checkpointing = True
104
+ _no_split_modules = ["JointTransformerBlock"]
105
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
106
+
107
+ @register_to_config
108
+ def __init__(
109
+ self,
110
+ sample_size: int = 128,
111
+ patch_size: int = 2,
112
+ in_channels: int = 16,
113
+ num_layers: int = 18,
114
+ attention_head_dim: int = 64,
115
+ num_attention_heads: int = 18,
116
+ joint_attention_dim: int = 4096,
117
+ caption_projection_dim: int = 1152,
118
+ pooled_projection_dim: int = 2048,
119
+ out_channels: int = 16,
120
+ pos_embed_max_size: int = 96,
121
+ dual_attention_layers: Tuple[
122
+ int, ...
123
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
124
+ qk_norm: Optional[str] = None,
125
+ ):
126
+ super().__init__()
127
+ self.out_channels = out_channels if out_channels is not None else in_channels
128
+ self.inner_dim = num_attention_heads * attention_head_dim
129
+
130
+ self.pos_embed = PatchEmbed(
131
+ height=sample_size,
132
+ width=sample_size,
133
+ patch_size=patch_size,
134
+ in_channels=in_channels,
135
+ embed_dim=self.inner_dim,
136
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
137
+ )
138
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
139
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
140
+ )
141
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
142
+
143
+ self.transformer_blocks = nn.ModuleList(
144
+ [
145
+ JointTransformerBlock(
146
+ dim=self.inner_dim,
147
+ num_attention_heads=num_attention_heads,
148
+ attention_head_dim=attention_head_dim,
149
+ context_pre_only=i == num_layers - 1,
150
+ qk_norm=qk_norm,
151
+ use_dual_attention=True if i in dual_attention_layers else False,
152
+ )
153
+ for i in range(num_layers)
154
+ ]
155
+ )
156
+
157
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
158
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
159
+
160
+ self.gradient_checkpointing = False
161
+
162
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
163
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
164
+ """
165
+ Enables [feed forward
166
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
167
+
168
+ Parameters:
169
+ chunk_size (`int`, *optional*):
170
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
171
+ over each tensor of dim=`dim`.
172
+ dim (`int`, *optional*, defaults to `0`):
173
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
174
+ or dim=1 (sequence length).
175
+ """
176
+ if dim not in [0, 1]:
177
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
178
+
179
+ # By default chunk size is 1
180
+ chunk_size = chunk_size or 1
181
+
182
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
183
+ if hasattr(module, "set_chunk_feed_forward"):
184
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
185
+
186
+ for child in module.children():
187
+ fn_recursive_feed_forward(child, chunk_size, dim)
188
+
189
+ for module in self.children():
190
+ fn_recursive_feed_forward(module, chunk_size, dim)
191
+
192
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
193
+ def disable_forward_chunking(self):
194
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
195
+ if hasattr(module, "set_chunk_feed_forward"):
196
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
197
+
198
+ for child in module.children():
199
+ fn_recursive_feed_forward(child, chunk_size, dim)
200
+
201
+ for module in self.children():
202
+ fn_recursive_feed_forward(module, None, 0)
203
+
204
+ @property
205
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
206
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
207
+ r"""
208
+ Returns:
209
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
210
+ indexed by its weight name.
211
+ """
212
+ # set recursively
213
+ processors = {}
214
+
215
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
216
+ if hasattr(module, "get_processor"):
217
+ processors[f"{name}.processor"] = module.get_processor()
218
+
219
+ for sub_name, child in module.named_children():
220
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
221
+
222
+ return processors
223
+
224
+ for name, module in self.named_children():
225
+ fn_recursive_add_processors(name, module, processors)
226
+
227
+ return processors
228
+
229
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
230
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
231
+ r"""
232
+ Sets the attention processor to use to compute attention.
233
+
234
+ Parameters:
235
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
236
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
237
+ for **all** `Attention` layers.
238
+
239
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
240
+ processor. This is strongly recommended when setting trainable attention processors.
241
+
242
+ """
243
+ count = len(self.attn_processors.keys())
244
+
245
+ if isinstance(processor, dict) and len(processor) != count:
246
+ raise ValueError(
247
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
248
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
249
+ )
250
+
251
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
252
+ if hasattr(module, "set_processor"):
253
+ if not isinstance(processor, dict):
254
+ module.set_processor(processor)
255
+ else:
256
+ module.set_processor(processor.pop(f"{name}.processor"))
257
+
258
+ for sub_name, child in module.named_children():
259
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
260
+
261
+ for name, module in self.named_children():
262
+ fn_recursive_attn_processor(name, module, processor)
263
+
264
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
265
+ def fuse_qkv_projections(self):
266
+ """
267
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
268
+ are fused. For cross-attention modules, key and value projection matrices are fused.
269
+
270
+ <Tip warning={true}>
271
+
272
+ This API is 🧪 experimental.
273
+
274
+ </Tip>
275
+ """
276
+ self.original_attn_processors = None
277
+
278
+ for _, attn_processor in self.attn_processors.items():
279
+ if "Added" in str(attn_processor.__class__.__name__):
280
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
281
+
282
+ self.original_attn_processors = self.attn_processors
283
+
284
+ for module in self.modules():
285
+ if isinstance(module, Attention):
286
+ module.fuse_projections(fuse=True)
287
+
288
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
289
+
290
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
291
+ def unfuse_qkv_projections(self):
292
+ """Disables the fused QKV projection if enabled.
293
+
294
+ <Tip warning={true}>
295
+
296
+ This API is 🧪 experimental.
297
+
298
+ </Tip>
299
+
300
+ """
301
+ if self.original_attn_processors is not None:
302
+ self.set_attn_processor(self.original_attn_processors)
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ encoder_hidden_states: torch.Tensor = None,
308
+ ref_hidden_states: torch.Tensor = None,
309
+ pooled_projections: torch.Tensor = None,
310
+ timestep: torch.LongTensor = None,
311
+ block_controlnet_hidden_states: List = None,
312
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
313
+ return_dict: bool = True,
314
+ skip_layers: Optional[List[int]] = None,
315
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
316
+ """
317
+ The [`SD3Transformer2DModel`] forward method.
318
+
319
+ Args:
320
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
321
+ Input `hidden_states`.
322
+ ref_hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
323
+ Input `ref_hidden_states`.
324
+ encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
325
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
326
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
327
+ Embeddings projected from the embeddings of input conditions.
328
+ timestep (`torch.LongTensor`):
329
+ Used to indicate denoising step.
330
+ block_controlnet_hidden_states (`list` of `torch.Tensor`):
331
+ A list of tensors that if specified are added to the residuals of transformer blocks.
332
+ joint_attention_kwargs (`dict`, *optional*):
333
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
334
+ `self.processor` in
335
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
336
+ return_dict (`bool`, *optional*, defaults to `True`):
337
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
338
+ tuple.
339
+ skip_layers (`list` of `int`, *optional*):
340
+ A list of layer indices to skip during the forward pass.
341
+
342
+ Returns:
343
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
344
+ `tuple` where the first element is the sample tensor.
345
+ """
346
+ if joint_attention_kwargs is not None:
347
+ joint_attention_kwargs = joint_attention_kwargs.copy()
348
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
349
+ else:
350
+ lora_scale = 1.0
351
+
352
+ if USE_PEFT_BACKEND:
353
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
354
+ scale_lora_layers(self, lora_scale)
355
+ else:
356
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
357
+ logger.warning(
358
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
359
+ )
360
+
361
+ height, width = hidden_states.shape[-2:]
362
+
363
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
364
+ if ref_hidden_states is not None:
365
+ ref_hidden_states = self.pos_embed(ref_hidden_states)
366
+ assert ref_hidden_states.shape == hidden_states.shape
367
+ hidden_states = torch.cat([ref_hidden_states, hidden_states], dim=1)
368
+
369
+ temb = self.time_text_embed(timestep, pooled_projections)
370
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
371
+
372
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
373
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
374
+ ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
375
+
376
+ joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
377
+
378
+ for index_block, block in enumerate(self.transformer_blocks):
379
+ # Skip specified layers
380
+ is_skip = True if skip_layers is not None and index_block in skip_layers else False
381
+
382
+ if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
383
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
384
+ block,
385
+ hidden_states,
386
+ encoder_hidden_states,
387
+ temb,
388
+ joint_attention_kwargs,
389
+ )
390
+ elif not is_skip:
391
+ encoder_hidden_states, hidden_states = block(
392
+ hidden_states=hidden_states,
393
+ encoder_hidden_states=encoder_hidden_states,
394
+ temb=temb,
395
+ joint_attention_kwargs=joint_attention_kwargs,
396
+ )
397
+
398
+ # controlnet residual
399
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
400
+ interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
401
+ hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
402
+
403
+ patch_size = self.config.patch_size
404
+ height = height // patch_size
405
+ width = width // patch_size
406
+ hidden_states = hidden_states[:, -height*width:, :]
407
+
408
+ hidden_states = self.norm_out(hidden_states, temb)
409
+ hidden_states = self.proj_out(hidden_states)
410
+
411
+ # unpatchify
412
+ hidden_states = hidden_states.reshape(
413
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
414
+ )
415
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
416
+ output = hidden_states.reshape(
417
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
418
+ )
419
+
420
+ if USE_PEFT_BACKEND:
421
+ # remove `lora_scale` from each PEFT layer
422
+ unscale_lora_layers(self, lora_scale)
423
+
424
+ if not return_dict:
425
+ return (output,)
426
+
427
+ return Transformer2DModelOutput(sample=output)
428
+
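The Kontext-specific part of this forward pass is the token bookkeeping: reference latents are patch-embedded with the same `pos_embed`, concatenated in front of the noisy latents along the sequence dimension, and then dropped again before unpatchify by keeping only the last `height * width` tokens. A shape-only sketch (patch size, latent resolution, and model width are illustrative):

```python
import torch

patch_size, latent_hw, dim = 2, 64, 1536           # illustrative values
tokens_per_image = (latent_hw // patch_size) ** 2  # 32 * 32 = 1024

hidden_states = torch.randn(1, tokens_per_image, dim)      # noisy latents after pos_embed
ref_hidden_states = torch.randn(1, tokens_per_image, dim)  # reference latents after pos_embed

# Reference tokens first, target tokens last (matching the torch.cat above).
joint = torch.cat([ref_hidden_states, hidden_states], dim=1)
print(joint.shape)  # torch.Size([1, 2048, 1536])

# After the transformer blocks, only the target tokens feed the output head.
height = width = latent_hw // patch_size
target_only = joint[:, -height * width:, :]
print(target_only.shape)  # torch.Size([1, 1024, 1536])
```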
429
+
430
+ if __name__ == "__main__":
431
+ import torch
432
+ import argparse
433
+ import os
434
+
435
+ parser = argparse.ArgumentParser()
436
+ parser.add_argument("--checkpoint", type=str, default=None)
437
+ parser.add_argument("--output", type=str, default=None)
438
+
439
+ args = parser.parse_args()
440
+
441
+ pretrained_model_name_or_path = "stabilityai/stable-diffusion-3.5-medium"
442
+
443
+ transformer = SD3Transformer2DKontextModel.from_pretrained(
444
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
445
+ subfolder="transformer",
446
+ torch_dtype=torch.bfloat16)
447
+
448
+ checkpoint = torch.load(args.checkpoint)
449
+ checkpoint = {k[len('transformer.'):]: v for k, v in checkpoint.items() if k.startswith('transformer.')}
450
+
451
+ transformer.load_state_dict(checkpoint)
452
+
453
+ os.makedirs(args.output, exist_ok=True)
454
+
455
+ transformer.save_pretrained(args.output)
user.png ADDED