AlexGraikos committed
Commit 34b2ac3 · verified · 1 Parent(s): 8ed3158

Upload pixcell_controlnet.py

Files changed (1)
  1. pixcell_controlnet.py +675 -0
pixcell_controlnet.py ADDED
@@ -0,0 +1,675 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, deprecate, is_torch_version, logging
from diffusers.models.activations import FP32SiLU
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from diffusers.models.controlnet import zero_module
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# PixCell UNI conditioning
def pixcell_get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=1.0,
    base_size=16,
    device: Optional[torch.device] = None,
    phase=0,
    output_type: str = "np",
):
    """
    Creates 2D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension.
        grid_size (`int`):
            The size of the grid height and width.
        cls_token (`bool`, defaults to `False`):
            Whether or not to add a classification token.
        extra_tokens (`int`, defaults to `0`):
            The number of extra tokens to add.
        interpolation_scale (`float`, defaults to `1.0`):
            The scale of the interpolation.

    Returns:
        pos_embed (`torch.Tensor`):
            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
            embed_dim]` if using cls_token
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = (
        torch.arange(grid_size[0], device=device, dtype=torch.float32)
        / (grid_size[0] / base_size)
        / interpolation_scale
    )
    grid_w = (
        torch.arange(grid_size[1], device=device, dtype=torch.float32)
        / (grid_size[1] / base_size)
        / interpolation_scale
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
    grid = torch.stack(grid, dim=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
    if cls_token and extra_tokens > 0:
        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
    return pos_embed


def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
    r"""
    This function generates 2D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension.
        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.

    Returns:
        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type)  # (H*W, D/2)
    emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type)  # (H*W, D/2)

    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
    """
    This function generates 1D positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`
        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1) + phase  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
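

# --- Illustrative sketch (not part of the original model code) ---------------
# A minimal shape check for the helpers above, assuming a 2x2 grid of UNI
# tokens and a 1024-dim embedding. The function name `_pos_embed_shape_example`
# and the sizes below are hypothetical and only document expected shapes.
def _pos_embed_shape_example():
    pos_embed = pixcell_get_2d_sincos_pos_embed(
        embed_dim=1024,
        grid_size=2,          # 2x2 grid -> 4 positions
        base_size=32,
        interpolation_scale=1,
        phase=0,
        output_type="pt",     # the "np" path is deprecated and raises
    )
    assert pos_embed.shape == (4, 1024)  # (grid_size * grid_size, embed_dim)
    return pos_embed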


class PixcellUNIProjection(nn.Module):
    """
    Projects UNI embeddings into the transformer's hidden dimension. Also registers an unconditional
    embedding used for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

        self.register_buffer("uncond_embedding", torch.randn(num_tokens, in_features) / in_features**0.5)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class UNIPosEmbed(nn.Module):
    """
    Adds 2D sinusoidal positional embeddings to a square grid of UNI conditions.

    Args:
        height (`int`, defaults to `1`): Number of UNI tokens along the vertical axis.
        width (`int`, defaults to `1`): Number of UNI tokens along the horizontal axis.
        base_size (`int`, defaults to `16`): Base grid size used to scale the positional frequencies.
        embed_dim (`int`, defaults to `768`): The dimension of the UNI embeddings.
        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding; only `"sincos"` is supported.
    """

    def __init__(
        self,
        height=1,
        width=1,
        base_size=16,
        embed_dim=768,
        interpolation_scale=1,
        pos_embed_type="sincos",
    ):
        super().__init__()

        num_embeds = height * width
        grid_size = int(num_embeds**0.5)

        if pos_embed_type == "sincos":
            y_pos_embed = pixcell_get_2d_sincos_pos_embed(
                embed_dim,
                grid_size,
                base_size=base_size,
                interpolation_scale=interpolation_scale,
                output_type="pt",
                phase=base_size // num_embeds,
            )
            self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
        else:
            raise ValueError("`pos_embed_type` not supported")

    def forward(self, uni_embeds):
        return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
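

# --- Illustrative sketch (not part of the original model code) ---------------
# How the two modules above compose, assuming a 2x2 grid of 1536-dim UNI
# embeddings projected into a 1152-dim transformer. All sizes and the function
# name `_uni_conditioning_example` are illustrative assumptions.
def _uni_conditioning_example():
    batch, num_tokens, uni_dim, inner_dim = 2, 4, 1536, 1152
    uni = torch.randn(batch, num_tokens, uni_dim)

    pos_embed = UNIPosEmbed(height=2, width=2, base_size=64, embed_dim=uni_dim)
    proj = PixcellUNIProjection(in_features=uni_dim, hidden_size=inner_dim, num_tokens=num_tokens)

    cond = proj(pos_embed(uni))  # (batch, num_tokens, inner_dim)
    # For classifier-free guidance, the registered `uncond_embedding` buffer can
    # presumably stand in for dropped conditions before the projection.
    uncond = proj(pos_embed(proj.uncond_embedding.expand(batch, -1, -1)))
    return cond, uncond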


@dataclass
class PixCellControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class PixCellControlNet(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer ControlNet model based on the PixArt family of models (https://arxiv.org/abs/2310.00426,
    https://arxiv.org/abs/2403.04692), modified for the pathology domain.

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        cross_attention_dim (int, optional):
            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 128):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_single"):
            Specifies the type of normalization used; only 'ada_norm_single' is supported.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
        caption_channels (int, optional, defaults to None):
            Number of channels to use for projecting the caption (UNI) embeddings.
        caption_num_tokens (int, optional, defaults to 1):
            Number of UNI tokens used as conditioning.
        n_controlnet_blocks (int, optional, defaults to 28):
            Number of transformer blocks (and zero-initialized output projections) kept for the ControlNet.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = 8,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        sample_size: int = 128,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        use_additional_conditions: Optional[bool] = None,
        caption_channels: Optional[int] = None,
        caption_num_tokens: int = 1,
        attention_type: Optional[str] = "default",
        n_controlnet_blocks: Optional[int] = 28,
    ):
        super().__init__()

        # Validate inputs.
        if norm_type != "ada_norm_single":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            if sample_size == 128:
                use_additional_conditions = True
            else:
                use_additional_conditions = False
        self.use_additional_conditions = use_additional_conditions

        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 64, 1)
        )
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    cross_attention_dim=self.config.cross_attention_dim,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    attention_type=self.config.attention_type,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Initialize the positional embedding for the conditions for >1 UNI embeddings
        if self.config.caption_num_tokens == 1:
            self.y_pos_embed = None
        else:
            # 1:1 aspect ratio
            self.uni_height = int(self.config.caption_num_tokens**0.5)
            self.uni_width = int(self.config.caption_num_tokens**0.5)

            self.y_pos_embed = UNIPosEmbed(
                height=self.uni_height,
                width=self.uni_width,
                base_size=self.config.sample_size // self.config.patch_size,
                embed_dim=self.config.caption_channels,
                interpolation_scale=2,  # Should this be fixed?
                pos_embed_type="sincos",  # This is fixed
            )

        # 3. Output blocks.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )
        self.caption_projection = None
        if self.config.caption_channels is not None:
            self.caption_projection = PixcellUNIProjection(
                in_features=self.config.caption_channels,
                hidden_size=self.inner_dim,
                num_tokens=self.config.caption_num_tokens,
            )

        # 4. ControlNet blocks
        # Condition patch embedding
        self.cond_pos_embed = zero_module(
            PatchEmbed(
                height=self.config.sample_size,
                width=self.config.sample_size,
                patch_size=self.config.patch_size,
                in_channels=self.config.in_channels,
                embed_dim=self.inner_dim,
                interpolation_scale=interpolation_scale,
            )
        )
        # Can use a subset of the transformer blocks for the ControlNet
        self.n_controlnet_blocks = n_controlnet_blocks
        if self.n_controlnet_blocks is not None:
            self.transformer_blocks = self.transformer_blocks[: self.n_controlnet_blocks]

        # ControlNet layers: one zero-initialized linear projection per kept transformer block
        self.controlnet_blocks = nn.ModuleList([])
        for i in range(len(self.transformer_blocks)):
            controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

            if self.n_controlnet_blocks is not None:
                if i + 1 == self.n_controlnet_blocks:
                    break

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.config.patch_size,
            hidden_states.shape[-1] // self.config.patch_size,
        )
        hidden_states = self.pos_embed(hidden_states)

        # Conditioning
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.caption_projection is not None:
            # Add positional embeddings to conditions if >1 UNI are given
            if self.y_pos_embed is not None:
                encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        block_outputs = ()

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            block_outputs = block_outputs + (hidden_states,)

        # 3. controlnet blocks
        controlnet_outputs = ()
        for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
            b_output = controlnet_block(t_output)
            controlnet_outputs = controlnet_outputs + (b_output,)

        controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
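

# --- Illustrative usage sketch (not part of the original model code) ---------
# A minimal smoke test with toy sizes showing the expected input and output
# shapes of the ControlNet. Every hyperparameter below is an illustrative
# assumption, not the released PixCell configuration.
if __name__ == "__main__":
    controlnet = PixCellControlNet(
        num_attention_heads=2,
        attention_head_dim=8,      # inner_dim = 2 * 8 = 16
        in_channels=4,
        sample_size=8,             # 8x8 latent with patch_size=2 -> 16 tokens
        patch_size=2,
        num_layers=2,
        n_controlnet_blocks=2,
        cross_attention_dim=16,    # matches inner_dim after the UNI projection
        caption_channels=32,       # toy UNI embedding size
        caption_num_tokens=1,
    )

    batch = 2
    latents = torch.randn(batch, 4, 8, 8)         # noisy latent
    control = torch.randn(batch, 4, 8, 8)         # control signal in latent space
    uni = torch.randn(batch, 1, 32)               # UNI conditioning tokens
    timesteps = torch.randint(0, 1000, (batch,))

    out = controlnet(
        hidden_states=latents,
        conditioning=control,
        encoder_hidden_states=uni,
        timestep=timesteps,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    )
    # One residual per kept transformer block; in standard ControlNet usage these
    # would be added to the corresponding blocks of the base PixCell transformer
    # (the base model and pipeline are not part of this file).
    for sample in out.controlnet_block_samples:
        print(sample.shape)  # torch.Size([2, 16, 16])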