AlexGraikos committed (verified) · Commit 86d49c5 · 1 Parent(s): 1d4f751

Create pixcell_controlnet_transformer_2d.py

Files changed (1):
  1. pixcell_controlnet_transformer_2d.py +684 -0
pixcell_controlnet_transformer_2d.py ADDED
@@ -0,0 +1,684 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate, is_torch_version, logging
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.models.activations import FP32SiLU


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# PixCell UNI conditioning
def pixcell_get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=1.0,
    base_size=16,
    device: Optional[torch.device] = None,
    phase=0,
    output_type: str = "np",
):
    """
    Creates 2D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension.
        grid_size (`int`):
            The size of the grid height and width.
        cls_token (`bool`, defaults to `False`):
            Whether or not to add a classification token.
        extra_tokens (`int`, defaults to `0`):
            The number of extra tokens to add.
        interpolation_scale (`float`, defaults to `1.0`):
            The scale of the interpolation.
        base_size (`int`, defaults to `16`):
            The base grid size used to normalize the positions.
        device (`torch.device`, *optional*):
            The device on which to create the embeddings.
        phase (`int`, defaults to `0`):
            A constant offset added to the positions before encoding.
        output_type (`str`, defaults to `"np"`):
            Only `"pt"` is supported; `"np"` is deprecated and raises an error.

    Returns:
        pos_embed (`torch.Tensor`):
            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
            embed_dim]` if using cls_token
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = (
        torch.arange(grid_size[0], device=device, dtype=torch.float32)
        / (grid_size[0] / base_size)
        / interpolation_scale
    )
    grid_w = (
        torch.arange(grid_size[1], device=device, dtype=torch.float32)
        / (grid_size[1] / base_size)
        / interpolation_scale
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
    grid = torch.stack(grid, dim=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
    if cls_token and extra_tokens > 0:
        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
    return pos_embed
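
# Illustrative usage (not part of the original commit): for a 2x2 grid with a
# 16-dim embedding, the call below would return a (4, 16) tensor, i.e. one
# sinusoidal positional embedding per grid position.
#
#   pos = pixcell_get_2d_sincos_pos_embed(
#       embed_dim=16, grid_size=2, base_size=16, output_type="pt"
#   )
#   assert pos.shape == (4, 16)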


def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
    r"""
    This function generates 2D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension.
        grid (`torch.Tensor`): Grid of positions with shape `(2, 1, H, W)`.

    Returns:
        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type)  # (H*W, D/2)
    emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type)  # (H*W, D/2)

    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
    """
    This function generates 1D positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`
        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
        phase (`int`, defaults to `0`): A constant offset added to the positions before encoding.

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1) + phase  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
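
# Illustrative check (not part of the original commit): each position p is mapped to
# [sin(p * w_0), ..., sin(p * w_{D/2-1}), cos(p * w_0), ..., cos(p * w_{D/2-1})]
# with w_k = 1 / 10000^(2k / D). For example, with D=4 the first column is
# sin(p * 1.0) and the third column is cos(p * 1.0), since w_0 = 1.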


class PixcellUNIProjection(nn.Module):
    """
    Projects UNI embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

        # Unconditional embedding used for classifier-free guidance, stored as a buffer.
        self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
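
# Illustrative usage (not part of the original commit; the 1024-dim UNI feature size
# and 1152-dim hidden size below are assumptions, not values taken from this file):
#
#   proj = PixcellUNIProjection(in_features=1024, hidden_size=1152, num_tokens=1)
#   uni = torch.randn(2, 1, 1024)                            # (batch, num_tokens, in_features)
#   cond = proj(uni)                                         # (2, 1, 1152)
#   uncond = proj(proj.uncond_embedding.expand(2, -1, -1))   # CFG "null" condition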


class UNIPosEmbed(nn.Module):
    """
    Adds 2D sinusoidal positional embeddings to the UNI conditions.

    Args:
        height (`int`, defaults to `1`): The height of the grid of UNI tokens.
        width (`int`, defaults to `1`): The width of the grid of UNI tokens.
        base_size (`int`, defaults to `16`): The base grid size used to normalize the positions.
        embed_dim (`int`, defaults to `768`): The dimension of the UNI embeddings.
        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding. Only `"sincos"` is supported.
    """

    def __init__(
        self,
        height=1,
        width=1,
        base_size=16,
        embed_dim=768,
        interpolation_scale=1,
        pos_embed_type="sincos",
    ):
        super().__init__()

        num_embeds = height * width
        grid_size = int(num_embeds**0.5)

        if pos_embed_type == "sincos":
            y_pos_embed = pixcell_get_2d_sincos_pos_embed(
                embed_dim,
                grid_size,
                base_size=base_size,
                interpolation_scale=interpolation_scale,
                output_type="pt",
                phase=base_size // num_embeds,
            )
            self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
        else:
            raise ValueError("`pos_embed_type` not supported")

    def forward(self, uni_embeds):
        return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
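
# Illustrative usage (not part of the original commit; a 2x2 grid of 1024-dim UNI
# tokens is an assumption used only for the shape check):
#
#   pe = UNIPosEmbed(height=2, width=2, base_size=16, embed_dim=1024)
#   uni = torch.randn(1, 4, 1024)      # (batch, height * width, embed_dim)
#   assert pe(uni).shape == (1, 4, 1024)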


class PixCellTransformer2DModelControlNet(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in the PixArt family of models (https://arxiv.org/abs/2310.00426,
    https://arxiv.org/abs/2403.04692). Modified for the pathology domain and extended to accept per-block
    ControlNet outputs in the forward pass.

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        cross_attention_dim (int, optional):
            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 128):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_single"):
            Specifies the type of normalization used; only 'ada_norm_single' is supported.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
        caption_channels (int, optional, defaults to None):
            Number of channels to use for projecting the caption (UNI) embeddings.
        caption_num_tokens (int, optional, defaults to 1):
            Number of UNI tokens used as conditioning. If greater than 1, positional embeddings are added to them.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = 8,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        sample_size: int = 128,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        use_additional_conditions: Optional[bool] = None,
        caption_channels: Optional[int] = None,
        caption_num_tokens: int = 1,
        attention_type: Optional[str] = "default",
    ):
        super().__init__()

        # Validate inputs.
        if norm_type != "ada_norm_single":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            if sample_size == 128:
                use_additional_conditions = True
            else:
                use_additional_conditions = False
        self.use_additional_conditions = use_additional_conditions

        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 64, 1)
        )
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    cross_attention_dim=self.config.cross_attention_dim,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    attention_type=self.config.attention_type,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Initialize the positional embedding for the conditions when more than one UNI embedding is given.
        if self.config.caption_num_tokens == 1:
            self.y_pos_embed = None
        else:
            # 1:1 aspect ratio
            self.uni_height = int(self.config.caption_num_tokens**0.5)
            self.uni_width = int(self.config.caption_num_tokens**0.5)

            self.y_pos_embed = UNIPosEmbed(
                height=self.uni_height,
                width=self.uni_width,
                base_size=self.config.sample_size // self.config.patch_size,
                embed_dim=self.config.caption_channels,
                interpolation_scale=2,  # Should this be fixed?
                pos_embed_type="sincos",  # This is fixed
            )

        # 3. Output blocks.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )
        self.caption_projection = None
        if self.config.caption_channels is not None:
            self.caption_projection = PixcellUNIProjection(
                in_features=self.config.caption_channels,
                hidden_size=self.inner_dim,
                num_tokens=self.config.caption_num_tokens,
            )

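    # Illustrative configuration (an assumption for 256x256 images at 8x VAE downsampling,
    # not values taken from this file):
    #
    #   model = PixCellTransformer2DModelControlNet(
    #       sample_size=32, patch_size=2, in_channels=4, out_channels=8,
    #       caption_channels=1024, caption_num_tokens=1,
    #   )
    #
    # With these values the latent is split into (32 / 2) ** 2 = 256 patches, each projected
    # to inner_dim = num_attention_heads * attention_head_dim = 16 * 72 = 1152 channels.
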
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in the default
        model.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        controlnet_outputs: Optional[torch.Tensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PixCellTransformer2DModelControlNet`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            controlnet_outputs (`torch.Tensor`, *optional*):
                Per-block residuals produced by a ControlNet. The output of block `i` is added to the hidden states
                after transformer block `i`.
            added_cond_kwargs (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask (`torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.config.patch_size,
            hidden_states.shape[-1] // self.config.patch_size,
        )
        hidden_states = self.pos_embed(hidden_states)

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.caption_projection is not None:
            # Add positional embeddings to conditions if more than one UNI embedding is given
            if self.y_pos_embed is not None:
                encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        for block_idx, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            # Add ControlNet outputs
            if controlnet_outputs is not None:
                if block_idx < len(controlnet_outputs):
                    hidden_states = hidden_states + controlnet_outputs[block_idx]

        # 3. Output
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

        # unpatchify
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
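

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original commit). The tiny
# configuration below (2 layers, 32x32 latents, 1024-dim UNI embeddings) is an
# assumption chosen only to keep the example fast; real checkpoints define
# their own config. It shows the expected tensor shapes, including the
# per-block `controlnet_outputs` residuals.
if __name__ == "__main__":
    model = PixCellTransformer2DModelControlNet(
        num_attention_heads=2,
        attention_head_dim=8,
        in_channels=4,
        out_channels=8,
        num_layers=2,
        cross_attention_dim=16,
        sample_size=32,
        patch_size=2,
        caption_channels=1024,
        caption_num_tokens=1,
    )
    latents = torch.randn(1, 4, 32, 32)
    uni = torch.randn(1, 1, 1024)
    timestep = torch.tensor([999])

    # One residual per transformer block, each matching the (batch, tokens, inner_dim) hidden states.
    num_tokens = (32 // 2) ** 2
    inner_dim = 2 * 8
    controls = [torch.zeros(1, num_tokens, inner_dim) for _ in range(2)]

    with torch.no_grad():
        out = model(
            latents,
            encoder_hidden_states=uni,
            timestep=timestep,
            controlnet_outputs=controls,
            added_cond_kwargs={"resolution": None, "aspect_ratio": None},
        ).sample
    print(out.shape)  # expected: torch.Size([1, 8, 32, 32])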