alexnasa committed on
Commit 01e63df · verified · 1 Parent(s): 6348747

Upload 4 files

models/controlnet.py ADDED
@@ -0,0 +1,832 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalControlnetMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
25
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from .unet_2d_blocks import (
28
+ CrossAttnDownBlock2D,
29
+ DownBlock2D,
30
+ UNetMidBlock2DCrossAttn,
31
+ get_down_block,
32
+ )
33
+ from .unet_2d_condition import UNet2DConditionModel
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @dataclass
40
+ class ControlNetOutput(BaseOutput):
41
+ """
42
+ The output of [`ControlNetModel`].
43
+
44
+ Args:
45
+ down_block_res_samples (`tuple[torch.Tensor]`):
46
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
47
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
48
+ used to condition the original UNet's downsampling activations.
49
+ mid_block_res_sample (`torch.Tensor`):
50
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
51
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
52
+ Output can be used to condition the original UNet's middle block activation.
53
+ """
54
+
55
+ down_block_res_samples: Tuple[torch.Tensor]
56
+ mid_block_res_sample: torch.Tensor
57
+
58
+
59
+ class ControlNetConditioningEmbedding(nn.Module):
60
+ """
61
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
62
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
63
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
64
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
65
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
66
+ model) to encode image-space conditions ... into feature maps ..."
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ conditioning_embedding_channels: int,
72
+ conditioning_channels: int = 3,
73
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
74
+ ):
75
+ super().__init__()
76
+
77
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
78
+
79
+ self.blocks = nn.ModuleList([])
80
+
81
+ for i in range(len(block_out_channels) - 1):
82
+ channel_in = block_out_channels[i]
83
+ channel_out = block_out_channels[i + 1]
84
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
85
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
86
+
87
+ self.conv_out = zero_module(
88
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
89
+ )
90
+
91
+ def forward(self, conditioning):
92
+ embedding = self.conv_in(conditioning)
93
+ embedding = F.silu(embedding)
94
+
95
+ for block in self.blocks:
96
+ embedding = block(embedding)
97
+ embedding = F.silu(embedding)
98
+
99
+ embedding = self.conv_out(embedding)
100
+
101
+ return embedding
102
+
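+
+ # A minimal shape-check sketch for the embedding above, assuming a 512x512 RGB
+ # conditioning image and the default block_out_channels:
+ #
+ #   cond_embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)
+ #   out = cond_embed(torch.randn(1, 3, 512, 512))
+ #   # the three stride-2 convolutions reduce 512 -> 256 -> 128 -> 64
+ #   assert out.shape == (1, 320, 64, 64)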
103
+
104
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
105
+ """
106
+ A ControlNet model.
107
+
108
+ Args:
109
+ in_channels (`int`, defaults to 4):
110
+ The number of channels in the input sample.
111
+ flip_sin_to_cos (`bool`, defaults to `True`):
112
+ Whether to flip the sin to cos in the time embedding.
113
+ freq_shift (`int`, defaults to 0):
114
+ The frequency shift to apply to the time embedding.
115
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
116
+ The tuple of downsample blocks to use.
117
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
118
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
119
+ The tuple of output channels for each block.
120
+ layers_per_block (`int`, defaults to 2):
121
+ The number of layers per block.
122
+ downsample_padding (`int`, defaults to 1):
123
+ The padding to use for the downsampling convolution.
124
+ mid_block_scale_factor (`float`, defaults to 1):
125
+ The scale factor to use for the mid block.
126
+ act_fn (`str`, defaults to "silu"):
127
+ The activation function to use.
128
+ norm_num_groups (`int`, *optional*, defaults to 32):
129
+ The number of groups to use for the normalization. If None, normalization and activation layers are skipped
130
+ in post-processing.
131
+ norm_eps (`float`, defaults to 1e-5):
132
+ The epsilon to use for the normalization.
133
+ cross_attention_dim (`int`, defaults to 1280):
134
+ The dimension of the cross attention features.
135
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
136
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
137
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
138
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
139
+ encoder_hid_dim (`int`, *optional*, defaults to None):
140
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
141
+ dimension to `cross_attention_dim`.
142
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
143
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
144
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
145
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
146
+ The dimension of the attention heads.
147
+ use_linear_projection (`bool`, defaults to `False`):
148
+ class_embed_type (`str`, *optional*, defaults to `None`):
149
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
150
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
151
+ addition_embed_type (`str`, *optional*, defaults to `None`):
152
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
153
+ "text". "text" will use the `TextTimeEmbedding` layer.
154
+ num_class_embeds (`int`, *optional*, defaults to `None`):
155
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
156
+ class conditioning with `class_embed_type` equal to `None`.
157
+ upcast_attention (`bool`, defaults to `False`):
158
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
159
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
160
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
161
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
162
+ `class_embed_type="projection"`.
163
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
164
+ The channel order of the conditioning image. It will be converted to `rgb` if it is `bgr`.
165
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
166
+ The tuple of output channels for each block in the `conditioning_embedding` layer.
167
+ global_pool_conditions (`bool`, defaults to `False`):
168
+ """
169
+
170
+ _supports_gradient_checkpointing = True
171
+
172
+ @register_to_config
173
+ def __init__(
174
+ self,
175
+ in_channels: int = 4,
176
+ conditioning_channels: int = 3,
177
+ flip_sin_to_cos: bool = True,
178
+ freq_shift: int = 0,
179
+ down_block_types: Tuple[str] = (
180
+ "CrossAttnDownBlock2D",
181
+ "CrossAttnDownBlock2D",
182
+ "CrossAttnDownBlock2D",
183
+ "DownBlock2D",
184
+ ),
185
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
186
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
187
+ layers_per_block: int = 2,
188
+ downsample_padding: int = 1,
189
+ mid_block_scale_factor: float = 1,
190
+ act_fn: str = "silu",
191
+ norm_num_groups: Optional[int] = 32,
192
+ norm_eps: float = 1e-5,
193
+ cross_attention_dim: int = 1280,
194
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
195
+ encoder_hid_dim: Optional[int] = None,
196
+ encoder_hid_dim_type: Optional[str] = None,
197
+ attention_head_dim: Union[int, Tuple[int]] = 8,
198
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ projection_class_embeddings_input_dim: Optional[int] = None,
207
+ controlnet_conditioning_channel_order: str = "rgb",
208
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
209
+ global_pool_conditions: bool = False,
210
+ addition_embed_type_num_heads=64,
211
+ use_image_cross_attention=False,
212
+ ):
213
+ super().__init__()
214
+
215
+ # If `num_attention_heads` is not defined (which is the case for most models)
216
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
217
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
218
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
219
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
220
+ # which is why we correct for the naming here.
221
+ num_attention_heads = num_attention_heads or attention_head_dim
222
+
223
+ # Check inputs
224
+ if len(block_out_channels) != len(down_block_types):
225
+ raise ValueError(
226
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
227
+ )
228
+
229
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
232
+ )
233
+
234
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if isinstance(transformer_layers_per_block, int):
240
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
241
+
242
+ # input
243
+ conv_in_kernel = 3
244
+ conv_in_padding = (conv_in_kernel - 1) // 2
245
+ self.conv_in = nn.Conv2d(
246
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
247
+ )
248
+
249
+ # time
250
+ time_embed_dim = block_out_channels[0] * 4
251
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
252
+ timestep_input_dim = block_out_channels[0]
253
+ self.time_embedding = TimestepEmbedding(
254
+ timestep_input_dim,
255
+ time_embed_dim,
256
+ act_fn=act_fn,
257
+ )
258
+
259
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
260
+ encoder_hid_dim_type = "text_proj"
261
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
262
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
263
+
264
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
265
+ raise ValueError(
266
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
267
+ )
268
+
269
+ if encoder_hid_dim_type == "text_proj":
270
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
271
+ elif encoder_hid_dim_type == "text_image_proj":
272
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
273
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use
274
+ # case, when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
275
+ self.encoder_hid_proj = TextImageProjection(
276
+ text_embed_dim=encoder_hid_dim,
277
+ image_embed_dim=cross_attention_dim,
278
+ cross_attention_dim=cross_attention_dim,
279
+ )
280
+
281
+ elif encoder_hid_dim_type is not None:
282
+ raise ValueError(
283
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
284
+ )
285
+ else:
286
+ self.encoder_hid_proj = None
287
+
288
+ # class embedding
289
+ if class_embed_type is None and num_class_embeds is not None:
290
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
291
+ elif class_embed_type == "timestep":
292
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
293
+ elif class_embed_type == "identity":
294
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
295
+ elif class_embed_type == "projection":
296
+ if projection_class_embeddings_input_dim is None:
297
+ raise ValueError(
298
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
299
+ )
300
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
301
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
302
+ # 2. it projects from an arbitrary input dimension.
303
+ #
304
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
305
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
306
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
307
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
308
+ else:
309
+ self.class_embedding = None
310
+
311
+ if addition_embed_type == "text":
312
+ if encoder_hid_dim is not None:
313
+ text_time_embedding_from_dim = encoder_hid_dim
314
+ else:
315
+ text_time_embedding_from_dim = cross_attention_dim
316
+
317
+ self.add_embedding = TextTimeEmbedding(
318
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
319
+ )
320
+ elif addition_embed_type == "text_image":
321
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
322
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use
323
+ # case, when `addition_embed_type == "text_image"` (Kandinsky 2.1).
324
+ self.add_embedding = TextImageTimeEmbedding(
325
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
326
+ )
327
+ elif addition_embed_type == "text_time":
328
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
329
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
330
+
331
+ elif addition_embed_type is not None:
332
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'.")
333
+
334
+ # control net conditioning embedding
335
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
336
+ conditioning_embedding_channels=block_out_channels[0],
337
+ block_out_channels=conditioning_embedding_out_channels,
338
+ conditioning_channels=conditioning_channels,
339
+ )
340
+
341
+ self.down_blocks = nn.ModuleList([])
342
+ self.controlnet_down_blocks = nn.ModuleList([])
343
+
344
+ if isinstance(only_cross_attention, bool):
345
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
346
+
347
+ if isinstance(attention_head_dim, int):
348
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
349
+
350
+ if isinstance(num_attention_heads, int):
351
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
352
+
353
+ # down
354
+ output_channel = block_out_channels[0]
355
+
356
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
357
+ controlnet_block = zero_module(controlnet_block)
358
+ self.controlnet_down_blocks.append(controlnet_block)
359
+
360
+ for i, down_block_type in enumerate(down_block_types):
361
+ input_channel = output_channel
362
+ output_channel = block_out_channels[i]
363
+ is_final_block = i == len(block_out_channels) - 1
364
+
365
+ down_block = get_down_block(
366
+ down_block_type,
367
+ num_layers=layers_per_block,
368
+ transformer_layers_per_block=transformer_layers_per_block[i],
369
+ in_channels=input_channel,
370
+ out_channels=output_channel,
371
+ temb_channels=time_embed_dim,
372
+ add_downsample=not is_final_block,
373
+ resnet_eps=norm_eps,
374
+ resnet_act_fn=act_fn,
375
+ resnet_groups=norm_num_groups,
376
+ cross_attention_dim=cross_attention_dim,
377
+ num_attention_heads=num_attention_heads[i],
378
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
379
+ downsample_padding=downsample_padding,
380
+ use_linear_projection=use_linear_projection,
381
+ only_cross_attention=only_cross_attention[i],
382
+ upcast_attention=upcast_attention,
383
+ resnet_time_scale_shift=resnet_time_scale_shift,
384
+ use_image_cross_attention=use_image_cross_attention,
385
+ )
386
+ self.down_blocks.append(down_block)
387
+
388
+ for _ in range(layers_per_block):
389
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
390
+ controlnet_block = zero_module(controlnet_block)
391
+ self.controlnet_down_blocks.append(controlnet_block)
392
+
393
+ if not is_final_block:
394
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
395
+ controlnet_block = zero_module(controlnet_block)
396
+ self.controlnet_down_blocks.append(controlnet_block)
397
+
398
+ # mid
399
+ mid_block_channel = block_out_channels[-1]
400
+
401
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
402
+ controlnet_block = zero_module(controlnet_block)
403
+ self.controlnet_mid_block = controlnet_block
404
+
405
+ self.mid_block = UNetMidBlock2DCrossAttn(
406
+ transformer_layers_per_block=transformer_layers_per_block[-1],
407
+ in_channels=mid_block_channel,
408
+ temb_channels=time_embed_dim,
409
+ resnet_eps=norm_eps,
410
+ resnet_act_fn=act_fn,
411
+ output_scale_factor=mid_block_scale_factor,
412
+ resnet_time_scale_shift=resnet_time_scale_shift,
413
+ cross_attention_dim=cross_attention_dim,
414
+ num_attention_heads=num_attention_heads[-1],
415
+ resnet_groups=norm_num_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ use_image_cross_attention=use_image_cross_attention,
419
+ )
420
+
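+
+ # Note: every `controlnet_block` above is wrapped in `zero_module`, so a freshly
+ # initialized ControlNet emits all-zero residuals and leaves the base UNet's
+ # behaviour unchanged at the start of training. A minimal sketch of that property:
+ #
+ #   blk = zero_module(nn.Conv2d(320, 320, kernel_size=1))
+ #   assert blk(torch.ones(1, 320, 8, 8)).abs().sum().item() == 0.0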
421
+ @classmethod
422
+ def from_unet(
423
+ cls,
424
+ unet: UNet2DConditionModel,
425
+ controlnet_conditioning_channel_order: str = "rgb",
426
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
427
+ load_weights_from_unet: bool = True,
428
+ use_image_cross_attention: bool = False,
429
+ ):
430
+ r"""
431
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
432
+
433
+ Parameters:
434
+ unet (`UNet2DConditionModel`):
435
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
436
+ where applicable.
437
+ """
438
+ transformer_layers_per_block = (
439
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
440
+ )
441
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
442
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
443
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
444
+ addition_time_embed_dim = (
445
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
446
+ )
447
+
448
+ controlnet = cls(
449
+ encoder_hid_dim=encoder_hid_dim,
450
+ encoder_hid_dim_type=encoder_hid_dim_type,
451
+ addition_embed_type=addition_embed_type,
452
+ addition_time_embed_dim=addition_time_embed_dim,
453
+ transformer_layers_per_block=transformer_layers_per_block,
454
+ in_channels=unet.config.in_channels,
455
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
456
+ freq_shift=unet.config.freq_shift,
457
+ down_block_types=unet.config.down_block_types,
458
+ only_cross_attention=unet.config.only_cross_attention,
459
+ block_out_channels=unet.config.block_out_channels,
460
+ layers_per_block=unet.config.layers_per_block,
461
+ downsample_padding=unet.config.downsample_padding,
462
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
463
+ act_fn=unet.config.act_fn,
464
+ norm_num_groups=unet.config.norm_num_groups,
465
+ norm_eps=unet.config.norm_eps,
466
+ cross_attention_dim=unet.config.cross_attention_dim,
467
+ attention_head_dim=unet.config.attention_head_dim,
468
+ num_attention_heads=unet.config.num_attention_heads,
469
+ use_linear_projection=unet.config.use_linear_projection,
470
+ class_embed_type=unet.config.class_embed_type,
471
+ num_class_embeds=unet.config.num_class_embeds,
472
+ upcast_attention=unet.config.upcast_attention,
473
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
474
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
475
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
476
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
477
+ use_image_cross_attention=use_image_cross_attention,
478
+ )
479
+
480
+ if load_weights_from_unet:
481
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
482
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
483
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
484
+
485
+ if controlnet.class_embedding:
486
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
487
+
488
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
489
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
490
+
491
+ return controlnet
492
+
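+
+ # A hedged usage sketch for `from_unet` (the repository id is illustrative only;
+ # any Stable-Diffusion-style UNet checkpoint with a compatible config should work):
+ #
+ #   unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ #   controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)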
493
+ @property
494
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
495
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
496
+ r"""
497
+ Returns:
498
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
499
+ indexed by its weight name.
500
+ """
501
+ # set recursively
502
+ processors = {}
503
+
504
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
505
+ if hasattr(module, "get_processor"):
506
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
507
+
508
+ for sub_name, child in module.named_children():
509
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
510
+
511
+ return processors
512
+
513
+ for name, module in self.named_children():
514
+ fn_recursive_add_processors(name, module, processors)
515
+
516
+ return processors
517
+
518
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
519
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
520
+ r"""
521
+ Sets the attention processor to use to compute attention.
522
+
523
+ Parameters:
524
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
525
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
526
+ for **all** `Attention` layers.
527
+
528
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
529
+ processor. This is strongly recommended when setting trainable attention processors.
530
+
531
+ """
532
+ count = len(self.attn_processors.keys())
533
+
534
+ if isinstance(processor, dict) and len(processor) != count:
535
+ raise ValueError(
536
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
537
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
538
+ )
539
+
540
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
541
+ if hasattr(module, "set_processor"):
542
+ if not isinstance(processor, dict):
543
+ module.set_processor(processor)
544
+ else:
545
+ module.set_processor(processor.pop(f"{name}.processor"))
546
+
547
+ for sub_name, child in module.named_children():
548
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
549
+
550
+ for name, module in self.named_children():
551
+ fn_recursive_attn_processor(name, module, processor)
552
+
553
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
554
+ def set_default_attn_processor(self):
555
+ """
556
+ Disables custom attention processors and sets the default attention implementation.
557
+ """
558
+ self.set_attn_processor(AttnProcessor())
559
+
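+
+ # A small sketch of swapping attention processors (AttnProcessor2_0, the PyTorch 2
+ # scaled-dot-product-attention processor from diffusers, is assumed to be available
+ # in the installed diffusers version):
+ #
+ #   from diffusers.models.attention_processor import AttnProcessor2_0
+ #   controlnet.set_attn_processor(AttnProcessor2_0())
+ #   controlnet.set_default_attn_processor()  # revert to the default processor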
560
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
561
+ def set_attention_slice(self, slice_size):
562
+ r"""
563
+ Enable sliced attention computation.
564
+
565
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
566
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
567
+
568
+ Args:
569
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
570
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
571
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
572
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
573
+ must be a multiple of `slice_size`.
574
+ """
575
+ sliceable_head_dims = []
576
+
577
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
578
+ if hasattr(module, "set_attention_slice"):
579
+ sliceable_head_dims.append(module.sliceable_head_dim)
580
+
581
+ for child in module.children():
582
+ fn_recursive_retrieve_sliceable_dims(child)
583
+
584
+ # retrieve number of attention layers
585
+ for module in self.children():
586
+ fn_recursive_retrieve_sliceable_dims(module)
587
+
588
+ num_sliceable_layers = len(sliceable_head_dims)
589
+
590
+ if slice_size == "auto":
591
+ # half the attention head size is usually a good trade-off between
592
+ # speed and memory
593
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
594
+ elif slice_size == "max":
595
+ # make smallest slice possible
596
+ slice_size = num_sliceable_layers * [1]
597
+
598
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
599
+
600
+ if len(slice_size) != len(sliceable_head_dims):
601
+ raise ValueError(
602
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
603
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
604
+ )
605
+
606
+ for i in range(len(slice_size)):
607
+ size = slice_size[i]
608
+ dim = sliceable_head_dims[i]
609
+ if size is not None and size > dim:
610
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
611
+
612
+ # Recursively walk through all the children.
613
+ # Any children which exposes the set_attention_slice method
614
+ # gets the message
615
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
616
+ if hasattr(module, "set_attention_slice"):
617
+ module.set_attention_slice(slice_size.pop())
618
+
619
+ for child in module.children():
620
+ fn_recursive_set_attention_slice(child, slice_size)
621
+
622
+ reversed_slice_size = list(reversed(slice_size))
623
+ for module in self.children():
624
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
625
+
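+
+ # Usage sketch for attention slicing, with the values described in the docstring above:
+ #
+ #   controlnet.set_attention_slice("auto")  # halve each sliceable head dimension
+ #   controlnet.set_attention_slice("max")   # one slice at a time, lowest memory use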
626
+ def _set_gradient_checkpointing(self, module, value=False):
627
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
628
+ module.gradient_checkpointing = value
629
+
630
+ def forward(
631
+ self,
632
+ sample: torch.FloatTensor,
633
+ timestep: Union[torch.Tensor, float, int],
634
+ encoder_hidden_states: torch.Tensor,
635
+ controlnet_cond: torch.FloatTensor,
636
+ conditioning_scale: float = 1.0,
637
+ class_labels: Optional[torch.Tensor] = None,
638
+ timestep_cond: Optional[torch.Tensor] = None,
639
+ attention_mask: Optional[torch.Tensor] = None,
640
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
641
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
642
+ guess_mode: bool = False,
643
+ return_dict: bool = True,
644
+ image_encoder_hidden_states: torch.Tensor = None,
645
+ vae_encode_condition_hidden_states: torch.Tensor = None,
646
+ ) -> Union[ControlNetOutput, Tuple]:
647
+ """
648
+ The [`ControlNetModel`] forward method.
649
+
650
+ Args:
651
+ sample (`torch.FloatTensor`):
652
+ The noisy input tensor.
653
+ timestep (`Union[torch.Tensor, float, int]`):
654
+ The number of timesteps to denoise an input.
655
+ encoder_hidden_states (`torch.Tensor`):
656
+ The encoder hidden states.
657
+ controlnet_cond (`torch.FloatTensor`):
658
+ The conditioning input tensor (e.g. an image) of shape `(batch_size, conditioning_channels, height, width)`.
659
+ conditioning_scale (`float`, defaults to `1.0`):
660
+ The scale factor for ControlNet outputs.
661
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
662
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
663
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
664
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
665
+ added_cond_kwargs (`dict`):
666
+ Additional conditions for the Stable Diffusion XL UNet.
667
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
668
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
669
+ guess_mode (`bool`, defaults to `False`):
670
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
671
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
672
+ return_dict (`bool`, defaults to `True`):
673
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
674
+
675
+ Returns:
676
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
677
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
678
+ returned where the first element is the sample tensor.
679
+ """
680
+ # check channel order
681
+ channel_order = self.config.controlnet_conditioning_channel_order
682
+
683
+ if channel_order == "rgb":
684
+ # in rgb order by default
685
+ ...
686
+ elif channel_order == "bgr":
687
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
688
+ else:
689
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
690
+
691
+ # prepare attention_mask
692
+ if attention_mask is not None:
693
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
694
+ attention_mask = attention_mask.unsqueeze(1)
695
+
696
+ # 1. time
697
+ timesteps = timestep
698
+ if not torch.is_tensor(timesteps):
699
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
700
+ # This would be a good case for the `match` statement (Python 3.10+)
701
+ is_mps = sample.device.type == "mps"
702
+ if isinstance(timestep, float):
703
+ dtype = torch.float32 if is_mps else torch.float64
704
+ else:
705
+ dtype = torch.int32 if is_mps else torch.int64
706
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
707
+ elif len(timesteps.shape) == 0:
708
+ timesteps = timesteps[None].to(sample.device)
709
+
710
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
711
+ timesteps = timesteps.expand(sample.shape[0])
712
+
713
+ t_emb = self.time_proj(timesteps)
714
+
715
+ # timesteps does not contain any weights and will always return f32 tensors
716
+ # but time_embedding might actually be running in fp16. so we need to cast here.
717
+ # there might be better ways to encapsulate this.
718
+ t_emb = t_emb.to(dtype=sample.dtype)
719
+
720
+ emb = self.time_embedding(t_emb, timestep_cond)
721
+ aug_emb = None
722
+
723
+ if self.class_embedding is not None:
724
+ if class_labels is None:
725
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
726
+
727
+ if self.config.class_embed_type == "timestep":
728
+ class_labels = self.time_proj(class_labels)
729
+
730
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
731
+ emb = emb + class_emb
732
+
733
+ if self.config.addition_embed_type is not None:
734
+ if self.config.addition_embed_type == "text":
735
+ aug_emb = self.add_embedding(encoder_hidden_states)
736
+
737
+ elif self.config.addition_embed_type == "text_time":
738
+ if "text_embeds" not in added_cond_kwargs:
739
+ raise ValueError(
740
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
741
+ )
742
+ text_embeds = added_cond_kwargs.get("text_embeds")
743
+ if "time_ids" not in added_cond_kwargs:
744
+ raise ValueError(
745
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
746
+ )
747
+ time_ids = added_cond_kwargs.get("time_ids")
748
+ time_embeds = self.add_time_proj(time_ids.flatten())
749
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
750
+
751
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
752
+ add_embeds = add_embeds.to(emb.dtype)
753
+ aug_emb = self.add_embedding(add_embeds)
754
+
755
+ emb = emb + aug_emb if aug_emb is not None else emb
756
+
757
+ # 2. pre-process
758
+ sample = self.conv_in(sample)
759
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
760
+ sample = sample + controlnet_cond
761
+
762
+ # 3. down
763
+ down_block_res_samples = (sample,)
764
+ for downsample_block in self.down_blocks:
765
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
766
+ sample, res_samples = downsample_block(
767
+ hidden_states=sample,
768
+ temb=emb,
769
+ encoder_hidden_states=encoder_hidden_states,
770
+ attention_mask=attention_mask,
771
+ cross_attention_kwargs=cross_attention_kwargs,
772
+ image_encoder_hidden_states=image_encoder_hidden_states,
773
+ )
774
+ else:
775
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
776
+
777
+ down_block_res_samples += res_samples
778
+
779
+ # 4. mid
780
+ if self.mid_block is not None:
781
+ sample = self.mid_block(
782
+ sample,
783
+ emb,
784
+ encoder_hidden_states=encoder_hidden_states,
785
+ attention_mask=attention_mask,
786
+ cross_attention_kwargs=cross_attention_kwargs,
787
+ image_encoder_hidden_states=image_encoder_hidden_states,
788
+ )
789
+
790
+ # 5. Control net blocks
791
+
792
+ controlnet_down_block_res_samples = ()
793
+
794
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
795
+ down_block_res_sample = controlnet_block(down_block_res_sample)
796
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
797
+
798
+ down_block_res_samples = controlnet_down_block_res_samples
799
+
800
+ mid_block_res_sample = self.controlnet_mid_block(sample)
801
+
802
+ # 6. scaling
803
+ if guess_mode and not self.config.global_pool_conditions:
804
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
805
+
806
+ scales = scales * conditioning_scale
807
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
808
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
809
+ else:
810
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
811
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
812
+
813
+ if self.config.global_pool_conditions:
814
+ down_block_res_samples = [
815
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
816
+ ]
817
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
818
+
819
+ if not return_dict:
820
+ return (down_block_res_samples, mid_block_res_sample)
821
+
822
+ return ControlNetOutput(
823
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
824
+ )
825
+
826
+
827
+ def zero_module(module):
828
+ for p in module.parameters():
829
+ nn.init.zeros_(p)
830
+ return module
831
+
832
+
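+
+ # End-to-end sketch of a single denoising step with this ControlNet (a hedged
+ # example: `latents`, `t`, `prompt_embeds` and `cond_image` are assumed to exist
+ # on the same device/dtype, and the residual kwargs are the ones the standard
+ # diffusers UNet2DConditionModel.forward exposes):
+ #
+ #   down_res, mid_res = controlnet(
+ #       latents, t, encoder_hidden_states=prompt_embeds,
+ #       controlnet_cond=cond_image, conditioning_scale=1.0, return_dict=False,
+ #   )
+ #   noise_pred = unet(
+ #       latents, t, encoder_hidden_states=prompt_embeds,
+ #       down_block_additional_residuals=down_res,
+ #       mid_block_additional_residual=mid_res,
+ #   ).sample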
models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
models/unet_2d_condition.py ADDED
@@ -0,0 +1,1071 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
26
+ from diffusers.models.embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ PositionNet,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2DCrossAttn,
41
+ UNetMidBlock2DSimpleCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ import os, json
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ @dataclass
53
+ class UNet2DConditionOutput(BaseOutput):
54
+ """
55
+ The output of [`UNet2DConditionModel`].
56
+
57
+ Args:
58
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
59
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
60
+ """
61
+
62
+ sample: torch.FloatTensor = None
63
+
64
+
65
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
66
+ r"""
67
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
68
+ shaped output.
69
+
70
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
71
+ for all models (such as downloading or saving).
72
+
73
+ Parameters:
74
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
75
+ Height and width of input/output sample.
76
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
77
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
78
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
79
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
80
+ Whether to flip the sin to cos in the time embedding.
81
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
82
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
83
+ The tuple of downsample blocks to use.
84
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
85
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
86
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
87
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
88
+ The tuple of upsample blocks to use.
89
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
90
+ Whether to include self-attention in the basic transformer blocks, see
91
+ [`~models.attention.BasicTransformerBlock`].
92
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
93
+ The tuple of output channels for each block.
94
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
95
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
96
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers is skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ encoder_hid_dim (`int`, *optional*, defaults to None):
108
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
109
+ dimension to `cross_attention_dim`.
110
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
111
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
112
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
113
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
114
+ num_attention_heads (`int`, *optional*):
115
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
116
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
117
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
118
+ class_embed_type (`str`, *optional*, defaults to `None`):
119
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
120
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
121
+ addition_embed_type (`str`, *optional*, defaults to `None`):
122
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
123
+ "text". "text" will use the `TextTimeEmbedding` layer.
124
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
125
+ Dimension for the timestep embeddings.
126
+ num_class_embeds (`int`, *optional*, defaults to `None`):
127
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
128
+ class conditioning with `class_embed_type` equal to `None`.
129
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
130
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
131
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
132
+ An optional override for the dimension of the projected time embedding.
133
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
134
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
135
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
136
+ timestep_post_act (`str`, *optional*, defaults to `None`):
137
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
138
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
139
+ The dimension of `cond_proj` layer in the timestep embedding.
140
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
141
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
142
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
143
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
144
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
145
+ embeddings with the class embeddings.
146
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
147
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
148
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
149
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
150
+ otherwise.
151
+ """
152
+
153
+ _supports_gradient_checkpointing = True
154
+
155
+ @register_to_config
156
+ def __init__(
157
+ self,
158
+ sample_size: Optional[int] = None,
159
+ in_channels: int = 4,
160
+ out_channels: int = 4,
161
+ center_input_sample: bool = False,
162
+ flip_sin_to_cos: bool = True,
163
+ freq_shift: int = 0,
164
+ down_block_types: Tuple[str] = (
165
+ "CrossAttnDownBlock2D",
166
+ "CrossAttnDownBlock2D",
167
+ "CrossAttnDownBlock2D",
168
+ "DownBlock2D",
169
+ ),
170
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
171
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
172
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
173
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
174
+ layers_per_block: Union[int, Tuple[int]] = 2,
175
+ downsample_padding: int = 1,
176
+ mid_block_scale_factor: float = 1,
177
+ act_fn: str = "silu",
178
+ norm_num_groups: Optional[int] = 32,
179
+ norm_eps: float = 1e-5,
180
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
181
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
182
+ encoder_hid_dim: Optional[int] = None,
183
+ encoder_hid_dim_type: Optional[str] = None,
184
+ attention_head_dim: Union[int, Tuple[int]] = 8,
185
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
186
+ dual_cross_attention: bool = False,
187
+ use_linear_projection: bool = False,
188
+ class_embed_type: Optional[str] = None,
189
+ addition_embed_type: Optional[str] = None,
190
+ addition_time_embed_dim: Optional[int] = None,
191
+ num_class_embeds: Optional[int] = None,
192
+ upcast_attention: bool = False,
193
+ resnet_time_scale_shift: str = "default",
194
+ resnet_skip_time_act: bool = False,
195
+ resnet_out_scale_factor: int = 1.0,
196
+ time_embedding_type: str = "positional",
197
+ time_embedding_dim: Optional[int] = None,
198
+ time_embedding_act_fn: Optional[str] = None,
199
+ timestep_post_act: Optional[str] = None,
200
+ time_cond_proj_dim: Optional[int] = None,
201
+ conv_in_kernel: int = 3,
202
+ conv_out_kernel: int = 3,
203
+ projection_class_embeddings_input_dim: Optional[int] = None,
204
+ attention_type: str = "default",
205
+ class_embeddings_concat: bool = False,
206
+ mid_block_only_cross_attention: Optional[bool] = None,
207
+ cross_attention_norm: Optional[str] = None,
208
+ addition_embed_type_num_heads=64,
209
+ use_image_cross_attention=False,
210
+ ):
211
+ super().__init__()
212
+
213
+ self.sample_size = sample_size
214
+
215
+ if num_attention_heads is not None:
216
+ raise ValueError(
217
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
218
+ )
219
+
220
+ # If `num_attention_heads` is not defined (which is the case for most models)
221
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
222
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
223
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
224
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
225
+ # which is why we correct for the naming here.
226
+ num_attention_heads = num_attention_heads or attention_head_dim
227
+
228
+ # Check inputs
229
+ if len(down_block_types) != len(up_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
232
+ )
233
+
234
+ if len(block_out_channels) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
250
+ raise ValueError(
251
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
252
+ )
253
+
254
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
255
+ raise ValueError(
256
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
257
+ )
258
+
259
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
260
+ raise ValueError(
261
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ # input
265
+ conv_in_padding = (conv_in_kernel - 1) // 2
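+ # e.g. the default conv_in_kernel of 3 gives padding 1, so the spatial resolution is preserved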
266
+ self.conv_in = nn.Conv2d(
267
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
268
+ )
269
+
270
+ # time
271
+ if time_embedding_type == "fourier":
272
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
273
+ if time_embed_dim % 2 != 0:
274
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
275
+ self.time_proj = GaussianFourierProjection(
276
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
277
+ )
278
+ timestep_input_dim = time_embed_dim
279
+ elif time_embedding_type == "positional":
280
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
281
+
282
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
283
+ timestep_input_dim = block_out_channels[0]
284
+ else:
285
+ raise ValueError(
286
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
287
+ )
288
+
289
+ self.time_embedding = TimestepEmbedding(
290
+ timestep_input_dim,
291
+ time_embed_dim,
292
+ act_fn=act_fn,
293
+ post_act_fn=timestep_post_act,
294
+ cond_proj_dim=time_cond_proj_dim,
295
+ )
296
+
297
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
298
+ encoder_hid_dim_type = "text_proj"
299
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
300
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
301
+
302
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
303
+ raise ValueError(
304
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
305
+ )
306
+
307
+ if encoder_hid_dim_type == "text_proj":
308
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
309
+ elif encoder_hid_dim_type == "text_image_proj":
310
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
311
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
312
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
313
+ self.encoder_hid_proj = TextImageProjection(
314
+ text_embed_dim=encoder_hid_dim,
315
+ image_embed_dim=cross_attention_dim,
316
+ cross_attention_dim=cross_attention_dim,
317
+ )
318
+ elif encoder_hid_dim_type == "image_proj":
319
+ # Kandinsky 2.2
320
+ self.encoder_hid_proj = ImageProjection(
321
+ image_embed_dim=encoder_hid_dim,
322
+ cross_attention_dim=cross_attention_dim,
323
+ )
324
+ elif encoder_hid_dim_type is not None:
325
+ raise ValueError(
326
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
327
+ )
328
+ else:
329
+ self.encoder_hid_proj = None
330
+
331
+ # class embedding
332
+ if class_embed_type is None and num_class_embeds is not None:
333
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
334
+ elif class_embed_type == "timestep":
335
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
336
+ elif class_embed_type == "identity":
337
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
338
+ elif class_embed_type == "projection":
339
+ if projection_class_embeddings_input_dim is None:
340
+ raise ValueError(
341
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
342
+ )
343
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
344
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
345
+ # 2. it projects from an arbitrary input dimension.
346
+ #
347
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
348
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
349
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
350
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
351
+ elif class_embed_type == "simple_projection":
352
+ if projection_class_embeddings_input_dim is None:
353
+ raise ValueError(
354
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
355
+ )
356
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
357
+ else:
358
+ self.class_embedding = None
359
+
360
+ if addition_embed_type == "text":
361
+ if encoder_hid_dim is not None:
362
+ text_time_embedding_from_dim = encoder_hid_dim
363
+ else:
364
+ text_time_embedding_from_dim = cross_attention_dim
365
+
366
+ self.add_embedding = TextTimeEmbedding(
367
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
368
+ )
369
+ elif addition_embed_type == "text_image":
370
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
371
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
372
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
373
+ self.add_embedding = TextImageTimeEmbedding(
374
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
375
+ )
376
+ elif addition_embed_type == "text_time":
377
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
378
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
379
+ elif addition_embed_type == "image":
380
+ # Kandinsky 2.2
381
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
382
+ elif addition_embed_type == "image_hint":
383
+ # Kandinsky 2.2 ControlNet
384
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
385
+ elif addition_embed_type is not None:
386
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
387
+
388
+ if time_embedding_act_fn is None:
389
+ self.time_embed_act = None
390
+ else:
391
+ self.time_embed_act = get_activation(time_embedding_act_fn)
392
+
393
+ self.down_blocks = nn.ModuleList([])
394
+ self.up_blocks = nn.ModuleList([])
395
+
396
+ if isinstance(only_cross_attention, bool):
397
+ if mid_block_only_cross_attention is None:
398
+ mid_block_only_cross_attention = only_cross_attention
399
+
400
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
401
+
402
+ if mid_block_only_cross_attention is None:
403
+ mid_block_only_cross_attention = False
404
+
405
+ if isinstance(num_attention_heads, int):
406
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
407
+
408
+ if isinstance(attention_head_dim, int):
409
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
410
+
411
+ if isinstance(cross_attention_dim, int):
412
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
413
+
414
+ if isinstance(layers_per_block, int):
415
+ layers_per_block = [layers_per_block] * len(down_block_types)
416
+
417
+ if isinstance(transformer_layers_per_block, int):
418
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
419
+
420
+ if class_embeddings_concat:
421
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
422
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
423
+ # regular time embeddings
424
+ blocks_time_embed_dim = time_embed_dim * 2
425
+ else:
426
+ blocks_time_embed_dim = time_embed_dim
427
+
428
+ # down
429
+ output_channel = block_out_channels[0]
430
+ for i, down_block_type in enumerate(down_block_types):
431
+ input_channel = output_channel
432
+ output_channel = block_out_channels[i]
433
+ is_final_block = i == len(block_out_channels) - 1
434
+
435
+ down_block = get_down_block(
436
+ down_block_type,
437
+ num_layers=layers_per_block[i],
438
+ transformer_layers_per_block=transformer_layers_per_block[i],
439
+ in_channels=input_channel,
440
+ out_channels=output_channel,
441
+ temb_channels=blocks_time_embed_dim,
442
+ add_downsample=not is_final_block,
443
+ resnet_eps=norm_eps,
444
+ resnet_act_fn=act_fn,
445
+ resnet_groups=norm_num_groups,
446
+ cross_attention_dim=cross_attention_dim[i],
447
+ num_attention_heads=num_attention_heads[i],
448
+ downsample_padding=downsample_padding,
449
+ dual_cross_attention=dual_cross_attention,
450
+ use_linear_projection=use_linear_projection,
451
+ only_cross_attention=only_cross_attention[i],
452
+ upcast_attention=upcast_attention,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ attention_type=attention_type,
455
+ resnet_skip_time_act=resnet_skip_time_act,
456
+ resnet_out_scale_factor=resnet_out_scale_factor,
457
+ cross_attention_norm=cross_attention_norm,
458
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
459
+ use_image_cross_attention=use_image_cross_attention,
460
+ )
461
+ self.down_blocks.append(down_block)
462
+
463
+ # mid
464
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
465
+ self.mid_block = UNetMidBlock2DCrossAttn(
466
+ transformer_layers_per_block=transformer_layers_per_block[-1],
467
+ in_channels=block_out_channels[-1],
468
+ temb_channels=blocks_time_embed_dim,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ output_scale_factor=mid_block_scale_factor,
472
+ resnet_time_scale_shift=resnet_time_scale_shift,
473
+ cross_attention_dim=cross_attention_dim[-1],
474
+ num_attention_heads=num_attention_heads[-1],
475
+ resnet_groups=norm_num_groups,
476
+ dual_cross_attention=dual_cross_attention,
477
+ use_linear_projection=use_linear_projection,
478
+ upcast_attention=upcast_attention,
479
+ attention_type=attention_type,
480
+ use_image_cross_attention=use_image_cross_attention,
481
+ )
482
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
483
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
484
+ in_channels=block_out_channels[-1],
485
+ temb_channels=blocks_time_embed_dim,
486
+ resnet_eps=norm_eps,
487
+ resnet_act_fn=act_fn,
488
+ output_scale_factor=mid_block_scale_factor,
489
+ cross_attention_dim=cross_attention_dim[-1],
490
+ attention_head_dim=attention_head_dim[-1],
491
+ resnet_groups=norm_num_groups,
492
+ resnet_time_scale_shift=resnet_time_scale_shift,
493
+ skip_time_act=resnet_skip_time_act,
494
+ only_cross_attention=mid_block_only_cross_attention,
495
+ cross_attention_norm=cross_attention_norm,
496
+ )
497
+ elif mid_block_type is None:
498
+ self.mid_block = None
499
+ else:
500
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
501
+
502
+ # count how many layers upsample the images
503
+ self.num_upsamplers = 0
504
+
505
+ # up
506
+ reversed_block_out_channels = list(reversed(block_out_channels))
507
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
508
+ reversed_layers_per_block = list(reversed(layers_per_block))
509
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
510
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
511
+ only_cross_attention = list(reversed(only_cross_attention))
512
+
513
+ output_channel = reversed_block_out_channels[0]
514
+ for i, up_block_type in enumerate(up_block_types):
515
+ is_final_block = i == len(block_out_channels) - 1
516
+
517
+ prev_output_channel = output_channel
518
+ output_channel = reversed_block_out_channels[i]
519
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
520
+
521
+ # add upsample block for all BUT final layer
522
+ if not is_final_block:
523
+ add_upsample = True
524
+ self.num_upsamplers += 1
525
+ else:
526
+ add_upsample = False
527
+
528
+ up_block = get_up_block(
529
+ up_block_type,
530
+ num_layers=reversed_layers_per_block[i] + 1,
531
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
532
+ in_channels=input_channel,
533
+ out_channels=output_channel,
534
+ prev_output_channel=prev_output_channel,
535
+ temb_channels=blocks_time_embed_dim,
536
+ add_upsample=add_upsample,
537
+ resnet_eps=norm_eps,
538
+ resnet_act_fn=act_fn,
539
+ resnet_groups=norm_num_groups,
540
+ cross_attention_dim=reversed_cross_attention_dim[i],
541
+ num_attention_heads=reversed_num_attention_heads[i],
542
+ dual_cross_attention=dual_cross_attention,
543
+ use_linear_projection=use_linear_projection,
544
+ only_cross_attention=only_cross_attention[i],
545
+ upcast_attention=upcast_attention,
546
+ resnet_time_scale_shift=resnet_time_scale_shift,
547
+ attention_type=attention_type,
548
+ resnet_skip_time_act=resnet_skip_time_act,
549
+ resnet_out_scale_factor=resnet_out_scale_factor,
550
+ cross_attention_norm=cross_attention_norm,
551
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
552
+ use_image_cross_attention=use_image_cross_attention,
553
+ )
554
+ self.up_blocks.append(up_block)
555
+ prev_output_channel = output_channel
556
+
557
+ # out
558
+ if norm_num_groups is not None:
559
+ self.conv_norm_out = nn.GroupNorm(
560
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
561
+ )
562
+
563
+ self.conv_act = get_activation(act_fn)
564
+
565
+ else:
566
+ self.conv_norm_out = None
567
+ self.conv_act = None
568
+
569
+ conv_out_padding = (conv_out_kernel - 1) // 2
570
+ self.conv_out = nn.Conv2d(
571
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
572
+ )
573
+
574
+ if attention_type == "gated":
575
+ positive_len = 768
576
+ if isinstance(cross_attention_dim, int):
577
+ positive_len = cross_attention_dim
578
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
579
+ positive_len = cross_attention_dim[0]
580
+ self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
581
+
582
+ @property
583
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
584
+ r"""
585
+ Returns:
586
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
587
+ indexed by its weight name.
588
+ """
589
+ # set recursively
590
+ processors = {}
591
+
592
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
593
+ if hasattr(module, "get_processor"):
594
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
595
+
596
+ for sub_name, child in module.named_children():
597
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
598
+
599
+ return processors
600
+
601
+ for name, module in self.named_children():
602
+ fn_recursive_add_processors(name, module, processors)
603
+
604
+ return processors
605
+
606
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
607
+ r"""
608
+ Sets the attention processor to use to compute attention.
609
+
610
+ Parameters:
611
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
612
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
613
+ for **all** `Attention` layers.
614
+
615
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
616
+ processor. This is strongly recommended when setting trainable attention processors.
617
+
618
+ """
619
+ count = len(self.attn_processors.keys())
620
+
621
+ if isinstance(processor, dict) and len(processor) != count:
622
+ raise ValueError(
623
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
624
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
625
+ )
626
+
627
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
628
+ if hasattr(module, "set_processor"):
629
+ if not isinstance(processor, dict):
630
+ module.set_processor(processor)
631
+ else:
632
+ module.set_processor(processor.pop(f"{name}.processor"))
633
+
634
+ for sub_name, child in module.named_children():
635
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
636
+
637
+ for name, module in self.named_children():
638
+ fn_recursive_attn_processor(name, module, processor)
639
+
640
+ def set_default_attn_processor(self):
641
+ """
642
+ Disables custom attention processors and sets the default attention implementation.
643
+ """
644
+ self.set_attn_processor(AttnProcessor())
645
+
646
+ def set_attention_slice(self, slice_size):
647
+ r"""
648
+ Enable sliced attention computation.
649
+
650
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
651
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
652
+
653
+ Args:
654
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
655
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
656
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
657
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
658
+ must be a multiple of `slice_size`.
659
+ """
660
+ sliceable_head_dims = []
661
+
662
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
663
+ if hasattr(module, "set_attention_slice"):
664
+ sliceable_head_dims.append(module.sliceable_head_dim)
665
+
666
+ for child in module.children():
667
+ fn_recursive_retrieve_sliceable_dims(child)
668
+
669
+ # retrieve number of attention layers
670
+ for module in self.children():
671
+ fn_recursive_retrieve_sliceable_dims(module)
672
+
673
+ num_sliceable_layers = len(sliceable_head_dims)
674
+
675
+ if slice_size == "auto":
676
+ # half the attention head size is usually a good trade-off between
677
+ # speed and memory
678
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
679
+ elif slice_size == "max":
680
+ # make smallest slice possible
681
+ slice_size = num_sliceable_layers * [1]
682
+
683
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
684
+
685
+ if len(slice_size) != len(sliceable_head_dims):
686
+ raise ValueError(
687
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
688
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
689
+ )
690
+
691
+ for i in range(len(slice_size)):
692
+ size = slice_size[i]
693
+ dim = sliceable_head_dims[i]
694
+ if size is not None and size > dim:
695
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
696
+
697
+ # Recursively walk through all the children.
698
+ # Any children which exposes the set_attention_slice method
699
+ # gets the message
700
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
701
+ if hasattr(module, "set_attention_slice"):
702
+ module.set_attention_slice(slice_size.pop())
703
+
704
+ for child in module.children():
705
+ fn_recursive_set_attention_slice(child, slice_size)
706
+
707
+ reversed_slice_size = list(reversed(slice_size))
708
+ for module in self.children():
709
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
710
+
711
+ def _set_gradient_checkpointing(self, module, value=False):
712
+ if hasattr(module, "gradient_checkpointing"):
713
+ module.gradient_checkpointing = value
714
+
715
+ def forward(
716
+ self,
717
+ sample: torch.FloatTensor,
718
+ timestep: Union[torch.Tensor, float, int],
719
+ encoder_hidden_states: torch.Tensor,
720
+ class_labels: Optional[torch.Tensor] = None,
721
+ timestep_cond: Optional[torch.Tensor] = None,
722
+ attention_mask: Optional[torch.Tensor] = None,
723
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
724
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
725
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
726
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
727
+ encoder_attention_mask: Optional[torch.Tensor] = None,
728
+ return_dict: bool = True,
729
+ image_encoder_hidden_states: Optional[torch.Tensor] = None,
730
+ ) -> Union[UNet2DConditionOutput, Tuple]:
731
+ r"""
732
+ The [`UNet2DConditionModel`] forward method.
733
+
734
+ Args:
735
+ sample (`torch.FloatTensor`):
736
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
737
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
738
+ encoder_hidden_states (`torch.FloatTensor`):
739
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
740
+ encoder_attention_mask (`torch.Tensor`):
741
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
742
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
743
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
744
+ return_dict (`bool`, *optional*, defaults to `True`):
745
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
746
+ tuple.
747
+ cross_attention_kwargs (`dict`, *optional*):
748
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
749
+ added_cond_kwargs: (`dict`, *optional*):
750
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
751
+ are passed along to the UNet blocks.
752
+
753
+ Returns:
754
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
755
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
756
+ a `tuple` is returned where the first element is the sample tensor.
757
+ """
758
+ # By default samples have to be at least a multiple of the overall upsampling factor.
759
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
760
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
761
+ # on the fly if necessary.
762
+ default_overall_up_factor = 2**self.num_upsamplers
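+ # e.g. with 3 upsampling layers, input height/width should be multiples of 2**3 = 8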
763
+
764
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
765
+ forward_upsample_size = False
766
+ upsample_size = None
767
+
768
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
769
+ logger.info("Forward upsample size to force interpolation output size.")
770
+ forward_upsample_size = True
771
+
772
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
773
+ # expects mask of shape:
774
+ # [batch, key_tokens]
775
+ # adds singleton query_tokens dimension:
776
+ # [batch, 1, key_tokens]
777
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
778
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
779
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
780
+ if attention_mask is not None:
781
+ # assume that mask is expressed as:
782
+ # (1 = keep, 0 = discard)
783
+ # convert mask into a bias that can be added to attention scores:
784
+ # (keep = +0, discard = -10000.0)
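+ # e.g. a keep/discard mask [1, 0] becomes the additive bias [0.0, -10000.0]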
785
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
786
+ attention_mask = attention_mask.unsqueeze(1)
787
+
788
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
789
+ if encoder_attention_mask is not None:
790
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
791
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
792
+
793
+ # 0. center input if necessary
794
+ if self.config.center_input_sample:
795
+ sample = 2 * sample - 1.0
796
+
797
+ # 1. time
798
+ timesteps = timestep
799
+ if not torch.is_tensor(timesteps):
800
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
801
+ # This would be a good case for the `match` statement (Python 3.10+)
802
+ is_mps = sample.device.type == "mps"
803
+ if isinstance(timestep, float):
804
+ dtype = torch.float32 if is_mps else torch.float64
805
+ else:
806
+ dtype = torch.int32 if is_mps else torch.int64
807
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
808
+ elif len(timesteps.shape) == 0:
809
+ timesteps = timesteps[None].to(sample.device)
810
+
811
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
812
+ timesteps = timesteps.expand(sample.shape[0])
813
+
814
+ t_emb = self.time_proj(timesteps)
815
+
816
+ # `Timesteps` does not contain any weights and will always return f32 tensors
817
+ # but time_embedding might actually be running in fp16. so we need to cast here.
818
+ # there might be better ways to encapsulate this.
819
+ t_emb = t_emb.to(dtype=sample.dtype)
820
+
821
+ emb = self.time_embedding(t_emb, timestep_cond)
822
+ aug_emb = None
823
+
824
+ if self.class_embedding is not None:
825
+ if class_labels is None:
826
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
827
+
828
+ if self.config.class_embed_type == "timestep":
829
+ class_labels = self.time_proj(class_labels)
830
+
831
+ # `Timesteps` does not contain any weights and will always return f32 tensors
832
+ # there might be better ways to encapsulate this.
833
+ class_labels = class_labels.to(dtype=sample.dtype)
834
+
835
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
836
+
837
+ if self.config.class_embeddings_concat:
838
+ emb = torch.cat([emb, class_emb], dim=-1)
839
+ else:
840
+ emb = emb + class_emb
841
+
842
+ if self.config.addition_embed_type == "text":
843
+ aug_emb = self.add_embedding(encoder_hidden_states)
844
+ elif self.config.addition_embed_type == "text_image":
845
+ # Kandinsky 2.1 - style
846
+ if "image_embeds" not in added_cond_kwargs:
847
+ raise ValueError(
848
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
849
+ )
850
+
851
+ image_embs = added_cond_kwargs.get("image_embeds")
852
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
853
+ aug_emb = self.add_embedding(text_embs, image_embs)
854
+ elif self.config.addition_embed_type == "text_time":
855
+ # SDXL - style
856
+ if "text_embeds" not in added_cond_kwargs:
857
+ raise ValueError(
858
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
859
+ )
860
+ text_embeds = added_cond_kwargs.get("text_embeds")
861
+ if "time_ids" not in added_cond_kwargs:
862
+ raise ValueError(
863
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
864
+ )
865
+ time_ids = added_cond_kwargs.get("time_ids")
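+ # for SDXL-style conditioning, `time_ids` typically packs (original_size, crop_top_left, target_size) per sample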
866
+ time_embeds = self.add_time_proj(time_ids.flatten())
867
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
868
+
869
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
870
+ add_embeds = add_embeds.to(emb.dtype)
871
+ aug_emb = self.add_embedding(add_embeds)
872
+ elif self.config.addition_embed_type == "image":
873
+ # Kandinsky 2.2 - style
874
+ if "image_embeds" not in added_cond_kwargs:
875
+ raise ValueError(
876
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
877
+ )
878
+ image_embs = added_cond_kwargs.get("image_embeds")
879
+ aug_emb = self.add_embedding(image_embs)
880
+ elif self.config.addition_embed_type == "image_hint":
881
+ # Kandinsky 2.2 - style
882
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
883
+ raise ValueError(
884
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
885
+ )
886
+ image_embs = added_cond_kwargs.get("image_embeds")
887
+ hint = added_cond_kwargs.get("hint")
888
+ aug_emb, hint = self.add_embedding(image_embs, hint)
889
+ sample = torch.cat([sample, hint], dim=1)
890
+
891
+ emb = emb + aug_emb if aug_emb is not None else emb
892
+
893
+ if self.time_embed_act is not None:
894
+ emb = self.time_embed_act(emb)
895
+
896
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
897
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
898
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
899
+ # Kandinsky 2.1 - style
900
+ if "image_embeds" not in added_cond_kwargs:
901
+ raise ValueError(
902
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
903
+ )
904
+
905
+ image_embeds = added_cond_kwargs.get("image_embeds")
906
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
907
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
908
+ # Kandinsky 2.2 - style
909
+ if "image_embeds" not in added_cond_kwargs:
910
+ raise ValueError(
911
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
912
+ )
913
+ image_embeds = added_cond_kwargs.get("image_embeds")
914
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
915
+ # 2. pre-process
916
+ sample = self.conv_in(sample)
917
+
918
+ # 2.5 GLIGEN position net
919
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
920
+ cross_attention_kwargs = cross_attention_kwargs.copy()
921
+ gligen_args = cross_attention_kwargs.pop("gligen")
922
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
923
+
924
+ # 3. down
925
+
926
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
927
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
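+ # a ControlNet supplies both down-block and mid-block residuals; a T2I-Adapter supplies only down-block residuals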
928
+
929
+ down_block_res_samples = (sample,)
930
+ for downsample_block in self.down_blocks:
931
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
932
+ # For t2i-adapter CrossAttnDownBlock2D
933
+ additional_residuals = {}
934
+ if is_adapter and len(down_block_additional_residuals) > 0:
935
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
936
+
937
+ sample, res_samples = downsample_block(
938
+ hidden_states=sample,
939
+ temb=emb,
940
+ encoder_hidden_states=encoder_hidden_states,
941
+ attention_mask=attention_mask,
942
+ cross_attention_kwargs=cross_attention_kwargs,
943
+ encoder_attention_mask=encoder_attention_mask,
944
+ image_encoder_hidden_states=image_encoder_hidden_states,
945
+ **additional_residuals,
946
+ )
947
+ else:
948
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
949
+
950
+ if is_adapter and len(down_block_additional_residuals) > 0:
951
+ sample += down_block_additional_residuals.pop(0)
952
+
953
+ down_block_res_samples += res_samples
954
+
955
+ if is_controlnet:
956
+ new_down_block_res_samples = ()
957
+
958
+ for down_block_res_sample, down_block_additional_residual in zip(
959
+ down_block_res_samples, down_block_additional_residuals
960
+ ):
961
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
962
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
963
+
964
+ down_block_res_samples = new_down_block_res_samples
965
+
966
+ # 4. mid
967
+ if self.mid_block is not None:
968
+ sample = self.mid_block(
969
+ sample,
970
+ emb,
971
+ encoder_hidden_states=encoder_hidden_states,
972
+ attention_mask=attention_mask,
973
+ cross_attention_kwargs=cross_attention_kwargs,
974
+ encoder_attention_mask=encoder_attention_mask,
975
+ image_encoder_hidden_states=image_encoder_hidden_states,
976
+ )
977
+ # To support T2I-Adapter-XL
978
+ if (
979
+ is_adapter
980
+ and len(down_block_additional_residuals) > 0
981
+ and sample.shape == down_block_additional_residuals[0].shape
982
+ ):
983
+ sample += down_block_additional_residuals.pop(0)
984
+
985
+ if is_controlnet:
986
+ sample = sample + mid_block_additional_residual
987
+
988
+ # 5. up
989
+ for i, upsample_block in enumerate(self.up_blocks):
990
+ is_final_block = i == len(self.up_blocks) - 1
991
+
992
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
993
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
994
+
995
+ # if we have not reached the final block and need to forward the
996
+ # upsample size, we do it here
997
+ if not is_final_block and forward_upsample_size:
998
+ upsample_size = down_block_res_samples[-1].shape[2:]
999
+
1000
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1001
+ sample = upsample_block(
1002
+ hidden_states=sample,
1003
+ temb=emb,
1004
+ res_hidden_states_tuple=res_samples,
1005
+ encoder_hidden_states=encoder_hidden_states,
1006
+ cross_attention_kwargs=cross_attention_kwargs,
1007
+ upsample_size=upsample_size,
1008
+ attention_mask=attention_mask,
1009
+ encoder_attention_mask=encoder_attention_mask,
1010
+ image_encoder_hidden_states=image_encoder_hidden_states,
1011
+ )
1012
+ else:
1013
+ sample = upsample_block(
1014
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1015
+ )
1016
+
1017
+ # 6. post-process
1018
+ if self.conv_norm_out:
1019
+ sample = self.conv_norm_out(sample)
1020
+ sample = self.conv_act(sample)
1021
+ sample = self.conv_out(sample)
1022
+
1023
+ if not return_dict:
1024
+ return (sample,)
1025
+
1026
+ return UNet2DConditionOutput(sample=sample)
1027
+
1028
+ @classmethod
1029
+ def from_pretrained_orig(cls, pretrained_model_path, seesr_model_path, subfolder=None, use_image_cross_attention=False, **kwargs):
1030
+ if subfolder is not None:
1031
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1032
+ seesr_model_path = os.path.join(seesr_model_path, subfolder)
1033
+
1034
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1035
+ if not os.path.isfile(config_file):
1036
+ raise RuntimeError(f"{config_file} does not exist")
1037
+ with open(config_file, "r") as f:
1038
+ config = json.load(f)
1039
+
1040
+ config['use_image_cross_attention'] = use_image_cross_attention
1041
+
1042
+ from diffusers.utils import WEIGHTS_NAME
1043
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
1044
+
1045
+
1046
+ model = cls.from_config(config)
1047
+
1048
+ ## for .bin file
1049
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1050
+ # if not os.path.isfile(model_file):
1051
+ # raise RuntimeError(f"{model_file} does not exist")
1052
+ # state_dict = torch.load(model_file, map_location="cpu")
1053
+ # model.load_state_dict(state_dict, strict=False)
1054
+
1055
+ ## for .safetensors file
1056
+ import safetensors.torch
1057
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
1058
+ model_file_seesr = os.path.join(seesr_model_path, SAFETENSORS_WEIGHTS_NAME)
1059
+ if not os.path.isfile(model_file):
1060
+ raise RuntimeError(f"{model_file} does not exist")
1061
+ if not os.path.isfile(model_file_seesr):
1062
+ raise RuntimeError(f"{model_file_seesr} does not exist")
1063
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
1064
+ state_dict_seesr = safetensors.torch.load_file(model_file_seesr, device="cpu")
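+ # only the image cross-attention ('image_attentions') weights are taken from the SeeSR checkpoint below;
+ # every other weight stays as loaded from the base model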
1065
+ # for k, v in model_seesr.state_dict().items():
1066
+ for k, v in state_dict_seesr.items():
1067
+ if 'image_attentions' in k:
1068
+ state_dict.update({k: v})
1069
+ model.load_state_dict(state_dict, strict=False)
1070
+
1071
+ return model
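+
+ ## Minimal usage sketch (illustrative only: the paths are placeholders, and it assumes this classmethod
+ ## lives on the custom UNet2DConditionModel defined in this file):
+ # unet = UNet2DConditionModel.from_pretrained_orig(
+ #     "/path/to/base-sd-model", "/path/to/seesr-checkpoint", subfolder="unet", use_image_cross_attention=True
+ # )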
pipelines/pipeline_seesr.py ADDED
@@ -0,0 +1,1225 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ import os
18
+ import warnings
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torchvision.utils import save_image
26
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.loaders import TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
31
+ from diffusers.schedulers import KarrasDiffusionSchedulers
32
+ from diffusers.utils import (
33
+ PIL_INTERPOLATION,
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ logging,
37
+ replace_example_docstring,
38
+ )
39
+
40
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
41
+
42
+ from diffusers.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
44
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
45
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
46
+
47
+
48
+ from utils.vaehook import VAEHook, perfcount
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> # !pip install opencv-python transformers accelerate
58
+ >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
59
+ >>> from diffusers.utils import load_image
60
+ >>> import numpy as np
61
+ >>> import torch
62
+
63
+ >>> import cv2
64
+ >>> from PIL import Image
65
+
66
+ >>> # download an image
67
+ >>> image = load_image(
68
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
69
+ ... )
70
+ >>> image = np.array(image)
71
+
72
+ >>> # get canny image
73
+ >>> image = cv2.Canny(image, 100, 200)
74
+ >>> image = image[:, :, None]
75
+ >>> image = np.concatenate([image, image, image], axis=2)
76
+ >>> canny_image = Image.fromarray(image)
77
+
78
+ >>> # load control net and stable diffusion v1-5
79
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
80
+ >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
81
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
82
+ ... )
83
+
84
+ >>> # speed up diffusion process with faster scheduler and memory optimization
85
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
86
+ >>> # remove following line if xformers is not installed
87
+ >>> pipe.enable_xformers_memory_efficient_attention()
88
+
89
+ >>> pipe.enable_model_cpu_offload()
90
+
91
+ >>> # generate image
92
+ >>> generator = torch.manual_seed(0)
93
+ >>> image = pipe(
94
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
95
+ ... ).images[0]
96
+ ```
97
+ """
98
+
99
+
100
+ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
101
+ r"""
102
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
103
+
104
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
105
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
106
+
107
+ In addition, the pipeline inherits the following loading methods:
108
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
109
+
110
+ Args:
111
+ vae ([`AutoencoderKL`]):
112
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
113
+ text_encoder ([`CLIPTextModel`]):
114
+ Frozen text-encoder. Stable Diffusion uses the text portion of
115
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
116
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
117
+ tokenizer (`CLIPTokenizer`):
118
+ Tokenizer of class
119
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
120
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
121
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
122
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
123
+ as a list, the outputs from each ControlNet are added together to create one combined additional
124
+ conditioning.
125
+ scheduler ([`SchedulerMixin`]):
126
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
127
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
128
+ safety_checker ([`StableDiffusionSafetyChecker`]):
129
+ Classification module that estimates whether generated images could be considered offensive or harmful.
130
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
131
+ feature_extractor ([`CLIPImageProcessor`]):
132
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
133
+ """
134
+ _optional_components = ["safety_checker", "feature_extractor"]
135
+
136
+ def __init__(
137
+ self,
138
+ vae: AutoencoderKL,
139
+ text_encoder: CLIPTextModel,
140
+ tokenizer: CLIPTokenizer,
141
+ unet: UNet2DConditionModel,
142
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
143
+ scheduler: KarrasDiffusionSchedulers,
144
+ safety_checker: StableDiffusionSafetyChecker,
145
+ feature_extractor: CLIPImageProcessor,
146
+ requires_safety_checker: bool = True,
147
+ ):
148
+ super().__init__()
149
+
150
+ if safety_checker is None and requires_safety_checker:
151
+ logger.warning(
152
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
153
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
154
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
155
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
156
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
157
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
158
+ )
159
+
160
+ if safety_checker is not None and feature_extractor is None:
161
+ raise ValueError(
162
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
163
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
164
+ )
165
+
166
+ if isinstance(controlnet, (list, tuple)):
167
+ controlnet = MultiControlNetModel(controlnet)
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ tokenizer=tokenizer,
173
+ unet=unet,
174
+ controlnet=controlnet,
175
+ scheduler=scheduler,
176
+ safety_checker=safety_checker,
177
+ feature_extractor=feature_extractor,
178
+ )
179
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
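+ # e.g. the standard SD VAE has 4 entries in block_out_channels, giving a scale factor of 2**3 = 8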
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
181
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
182
+
183
+ def _init_tiled_vae(self,
184
+ encoder_tile_size = 256,
185
+ decoder_tile_size = 256,
186
+ fast_decoder = False,
187
+ fast_encoder = False,
188
+ color_fix = False,
189
+ vae_to_gpu = True):
190
+ # save original forward (only once)
191
+ if not hasattr(self.vae.encoder, 'original_forward'):
192
+ setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
193
+ if not hasattr(self.vae.decoder, 'original_forward'):
194
+ setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)
195
+
196
+ encoder = self.vae.encoder
197
+ decoder = self.vae.decoder
198
+
199
+ self.vae.encoder.forward = VAEHook(
200
+ encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
201
+ self.vae.decoder.forward = VAEHook(
202
+ decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
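+
+ # Hypothetical usage (tile sizes are illustrative, not repo defaults):
+ #     pipeline._init_tiled_vae(encoder_tile_size=1024, decoder_tile_size=224)
+ # called once before running the pipeline on large inputs so VAE encode/decode run tile by tile via VAEHook.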
203
+
204
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
205
+ def enable_vae_slicing(self):
206
+ r"""
207
+ Enable sliced VAE decoding.
208
+
209
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
210
+ steps. This is useful to save some memory and allow larger batch sizes.
211
+ """
212
+ self.vae.enable_slicing()
213
+
214
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
215
+ def disable_vae_slicing(self):
216
+ r"""
217
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
218
+ computing decoding in one step.
219
+ """
220
+ self.vae.disable_slicing()
221
+
222
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
223
+ def enable_vae_tiling(self):
224
+ r"""
225
+ Enable tiled VAE decoding.
226
+
227
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
228
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
229
+ """
230
+ self.vae.enable_tiling()
231
+
232
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
233
+ def disable_vae_tiling(self):
234
+ r"""
235
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
236
+ computing decoding in one step.
237
+ """
238
+ self.vae.disable_tiling()
239
+
240
+ def enable_sequential_cpu_offload(self, gpu_id=0):
241
+ r"""
242
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
243
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
244
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
245
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
246
+ `enable_model_cpu_offload`, but performance is lower.
247
+ """
248
+ if is_accelerate_available():
249
+ from accelerate import cpu_offload
250
+ else:
251
+ raise ImportError("Please install accelerate via `pip install accelerate`")
252
+
253
+ device = torch.device(f"cuda:{gpu_id}")
254
+
255
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
256
+ cpu_offload(cpu_offloaded_model, device)
257
+
258
+ if self.safety_checker is not None:
259
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
260
+
261
+ def enable_model_cpu_offload(self, gpu_id=0):
262
+ r"""
263
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
264
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
265
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
266
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
267
+ """
268
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
269
+ from accelerate import cpu_offload_with_hook
270
+ else:
271
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
272
+
273
+ device = torch.device(f"cuda:{gpu_id}")
274
+
275
+ hook = None
276
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
277
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
278
+
279
+ if self.safety_checker is not None:
280
+ # the safety checker can offload the vae again
281
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
282
+
283
+ # the controlnet hook has to be manually offloaded as it alternates with the unet
284
+ cpu_offload_with_hook(self.controlnet, device)
285
+
286
+ # We'll offload the last model manually.
287
+ self.final_offload_hook = hook
288
+
289
+ @property
290
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
291
+ def _execution_device(self):
292
+ r"""
293
+ Returns the device on which the pipeline's models will be executed. After calling
294
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
295
+ hooks.
296
+ """
297
+ if not hasattr(self.unet, "_hf_hook"):
298
+ return self.device
299
+ for module in self.unet.modules():
300
+ if (
301
+ hasattr(module, "_hf_hook")
302
+ and hasattr(module._hf_hook, "execution_device")
303
+ and module._hf_hook.execution_device is not None
304
+ ):
305
+ return torch.device(module._hf_hook.execution_device)
306
+ return self.device
307
+
308
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
309
+ def _encode_prompt(
310
+ self,
311
+ prompt,
312
+ device,
313
+ num_images_per_prompt,
314
+ do_classifier_free_guidance,
315
+ negative_prompt=None,
316
+ prompt_embeds: Optional[torch.FloatTensor] = None,
317
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ ram_encoder_hidden_states: Optional[torch.FloatTensor] = None,
319
+ ):
320
+ r"""
321
+ Encodes the prompt into text encoder hidden states.
322
+
323
+ Args:
324
+ prompt (`str` or `List[str]`, *optional*):
325
+ prompt to be encoded
326
+ device: (`torch.device`):
327
+ torch device
328
+ num_images_per_prompt (`int`):
329
+ number of images that should be generated per prompt
330
+ do_classifier_free_guidance (`bool`):
331
+ whether to use classifier free guidance or not
332
+ negative_prompt (`str` or `List[str]`, *optional*):
333
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
334
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
335
+ less than `1`).
336
+ prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
338
+ provided, text embeddings will be generated from `prompt` input argument.
339
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
340
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
341
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
342
+ argument.
343
+ """
344
+ if prompt is not None and isinstance(prompt, str):
345
+ batch_size = 1
346
+ elif prompt is not None and isinstance(prompt, list):
347
+ batch_size = len(prompt)
348
+ else:
349
+ batch_size = prompt_embeds.shape[0]
350
+
351
+ if prompt_embeds is None:
352
+ # textual inversion: process multi-vector tokens if necessary
353
+ if isinstance(self, TextualInversionLoaderMixin):
354
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
355
+
356
+ text_inputs = self.tokenizer(
357
+ prompt,
358
+ padding="max_length",
359
+ max_length=self.tokenizer.model_max_length,
360
+ truncation=True,
361
+ return_tensors="pt",
362
+ )
363
+ text_input_ids = text_inputs.input_ids
364
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
365
+
366
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
367
+ text_input_ids, untruncated_ids
368
+ ):
369
+ removed_text = self.tokenizer.batch_decode(
370
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
371
+ )
372
+ logger.warning(
373
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
374
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
375
+ )
376
+
377
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
378
+ attention_mask = text_inputs.attention_mask.to(device)
379
+ else:
380
+ attention_mask = None
381
+
382
+ prompt_embeds = self.text_encoder(
383
+ text_input_ids.to(device),
384
+ attention_mask=attention_mask,
385
+ )
386
+ prompt_embeds = prompt_embeds[0]
387
+
388
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
389
+
390
+ bs_embed, seq_len, _ = prompt_embeds.shape
391
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
392
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
393
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
394
+
395
+ # get unconditional embeddings for classifier free guidance
396
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
397
+ uncond_tokens: List[str]
398
+ if negative_prompt is None:
399
+ uncond_tokens = [""] * batch_size
400
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
401
+ raise TypeError(
402
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
403
+ f" {type(prompt)}."
404
+ )
405
+ elif isinstance(negative_prompt, str):
406
+ uncond_tokens = [negative_prompt]
407
+ elif batch_size != len(negative_prompt):
408
+ raise ValueError(
409
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
410
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
411
+ " the batch size of `prompt`."
412
+ )
413
+ else:
414
+ uncond_tokens = negative_prompt
415
+
416
+ # textual inversion: process multi-vector tokens if necessary
417
+ if isinstance(self, TextualInversionLoaderMixin):
418
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
419
+
420
+ max_length = prompt_embeds.shape[1]
421
+ uncond_input = self.tokenizer(
422
+ uncond_tokens,
423
+ padding="max_length",
424
+ max_length=max_length,
425
+ truncation=True,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
430
+ attention_mask = uncond_input.attention_mask.to(device)
431
+ else:
432
+ attention_mask = None
433
+
434
+ negative_prompt_embeds = self.text_encoder(
435
+ uncond_input.input_ids.to(device),
436
+ attention_mask=attention_mask,
437
+ )
438
+ negative_prompt_embeds = negative_prompt_embeds[0]
439
+
440
+ if do_classifier_free_guidance:
441
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
442
+ seq_len = negative_prompt_embeds.shape[1]
443
+
444
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
447
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
448
+
449
+ # For classifier free guidance, we need to do two forward passes.
450
+ # Here we concatenate the unconditional and text embeddings into a single batch
451
+ # to avoid doing two forward passes
452
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
453
+ ram_encoder_hidden_states = torch.cat([ram_encoder_hidden_states, ram_encoder_hidden_states])
454
+
455
+ return prompt_embeds, ram_encoder_hidden_states
456
+
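# Self-contained sketch (shapes are stand-ins) of the tensor layout `_encode_prompt`
# returns for classifier-free guidance: embeddings are duplicated per generated image
# with the mps-friendly repeat/view trick, then the negative and positive halves are
# stacked on the batch dimension so a single UNet forward serves both branches.
import torch

bs, seq_len, dim, num_images_per_prompt = 1, 77, 768, 2
prompt_embeds = torch.randn(bs, seq_len, dim)
negative_prompt_embeds = torch.randn(bs, seq_len, dim)

prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, dim)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, dim)

cfg_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])   # (2 * bs * num_images, 77, 768)
print(cfg_embeds.shape)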
457
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
458
+ def run_safety_checker(self, image, device, dtype):
459
+ if self.safety_checker is None:
460
+ has_nsfw_concept = None
461
+ else:
462
+ if torch.is_tensor(image):
463
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
464
+ else:
465
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
466
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
467
+ image, has_nsfw_concept = self.safety_checker(
468
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
469
+ )
470
+ return image, has_nsfw_concept
471
+
472
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
473
+ def decode_latents(self, latents):
474
+ warnings.warn(
475
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
476
+ " use VaeImageProcessor instead",
477
+ FutureWarning,
478
+ )
479
+ latents = 1 / self.vae.config.scaling_factor * latents
480
+ image = self.vae.decode(latents, return_dict=False)[0]
481
+ image = (image / 2 + 0.5).clamp(0, 1)
482
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
483
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
484
+ return image
485
+
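# The deprecation warning above points to `VaeImageProcessor`; a minimal sketch of that
# replacement path, assuming diffusers is installed (the decoded tensor is a stand-in
# for `vae.decode(...)` output in [-1, 1]).
import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
decoded = torch.rand(1, 3, 512, 512) * 2 - 1
pil_images = image_processor.postprocess(decoded, output_type="pil")
print(pil_images[0].size)   # -> (512, 512)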
486
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
487
+ def prepare_extra_step_kwargs(self, generator, eta):
488
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
489
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
490
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
491
+ # and should be between [0, 1]
492
+
493
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
494
+ extra_step_kwargs = {}
495
+ if accepts_eta:
496
+ extra_step_kwargs["eta"] = eta
497
+
498
+ # check if the scheduler accepts generator
499
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
500
+ if accepts_generator:
501
+ extra_step_kwargs["generator"] = generator
502
+ #extra_step_kwargs["generator"] = generator
503
+ return extra_step_kwargs
504
+
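# Self-contained sketch of the introspection pattern above: only forward `eta` and
# `generator` to schedulers whose `step()` actually accepts them (`step` here is a toy).
import inspect

def step(model_output, timestep, sample, eta=0.0, generator=None):
    return sample

extra_step_kwargs = {}
if "eta" in inspect.signature(step).parameters:
    extra_step_kwargs["eta"] = 0.0
if "generator" in inspect.signature(step).parameters:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)   # -> {'eta': 0.0, 'generator': None}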
505
+ def check_inputs(
506
+ self,
507
+ prompt,
508
+ image,
509
+ height,
510
+ width,
511
+ callback_steps,
512
+ negative_prompt=None,
513
+ prompt_embeds=None,
514
+ negative_prompt_embeds=None,
515
+ controlnet_conditioning_scale=1.0,
516
+ ):
517
+ if height % 8 != 0 or width % 8 != 0:
518
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
519
+
520
+ if (callback_steps is None) or (
521
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
525
+ f" {type(callback_steps)}."
526
+ )
527
+
528
+ if prompt is not None and prompt_embeds is not None:
529
+ raise ValueError(
530
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
531
+ " only forward one of the two."
532
+ )
533
+ elif prompt is None and prompt_embeds is None:
534
+ raise ValueError(
535
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
536
+ )
537
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
538
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
539
+
540
+ if negative_prompt is not None and negative_prompt_embeds is not None:
541
+ raise ValueError(
542
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
543
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
544
+ )
545
+
546
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
547
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
548
+ raise ValueError(
549
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
550
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
551
+ f" {negative_prompt_embeds.shape}."
552
+ )
553
+
554
+ # `prompt` needs more sophisticated handling when there are multiple
555
+ # conditionings.
556
+ if isinstance(self.controlnet, MultiControlNetModel):
557
+ if isinstance(prompt, list):
558
+ logger.warning(
559
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
560
+ " prompts. The conditionings will be fixed across the prompts."
561
+ )
562
+
563
+ # Check `image`
564
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
565
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
566
+ )
567
+ if (
568
+ isinstance(self.controlnet, ControlNetModel)
569
+ or is_compiled
570
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
571
+ ):
572
+ self.check_image(image, prompt, prompt_embeds)
573
+ elif (
574
+ isinstance(self.controlnet, MultiControlNetModel)
575
+ or is_compiled
576
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
577
+ ):
578
+ if not isinstance(image, list):
579
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
580
+
581
+ # When `image` is a nested list:
582
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
583
+ elif any(isinstance(i, list) for i in image):
584
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
585
+ elif len(image) != len(self.controlnet.nets):
586
+ raise ValueError(
587
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
588
+ )
589
+
590
+ for image_ in image:
591
+ self.check_image(image_, prompt, prompt_embeds)
592
+ else:
593
+ assert False
594
+
595
+ # Check `controlnet_conditioning_scale`
596
+ if (
597
+ isinstance(self.controlnet, ControlNetModel)
598
+ or is_compiled
599
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
600
+ ):
601
+ if not isinstance(controlnet_conditioning_scale, float):
602
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
603
+ elif (
604
+ isinstance(self.controlnet, MultiControlNetModel)
605
+ or is_compiled
606
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
607
+ ):
608
+ if isinstance(controlnet_conditioning_scale, list):
609
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
610
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
611
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
612
+ self.controlnet.nets
613
+ ):
614
+ raise ValueError(
615
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
616
+ " the same length as the number of controlnets"
617
+ )
618
+ else:
619
+ assert False
620
+
621
+ def check_image(self, image, prompt, prompt_embeds):
622
+ image_is_pil = isinstance(image, PIL.Image.Image)
623
+ image_is_tensor = isinstance(image, torch.Tensor)
624
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
625
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
626
+
627
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
628
+ raise TypeError(
629
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
630
+ )
631
+
632
+ if image_is_pil:
633
+ image_batch_size = 1
634
+ elif image_is_tensor:
635
+ image_batch_size = image.shape[0]
636
+ elif image_is_pil_list:
637
+ image_batch_size = len(image)
638
+ elif image_is_tensor_list:
639
+ image_batch_size = len(image)
640
+
641
+ if prompt is not None and isinstance(prompt, str):
642
+ prompt_batch_size = 1
643
+ elif prompt is not None and isinstance(prompt, list):
644
+ prompt_batch_size = len(prompt)
645
+ elif prompt_embeds is not None:
646
+ prompt_batch_size = prompt_embeds.shape[0]
647
+
648
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
649
+ raise ValueError(
650
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
651
+ )
652
+
653
+ def prepare_image(
654
+ self,
655
+ image,
656
+ width,
657
+ height,
658
+ batch_size,
659
+ num_images_per_prompt,
660
+ device,
661
+ dtype,
662
+ do_classifier_free_guidance=False,
663
+ guess_mode=False,
664
+ ):
665
+ if not isinstance(image, torch.Tensor):
666
+ if isinstance(image, PIL.Image.Image):
667
+ image = [image]
668
+
669
+ if isinstance(image[0], PIL.Image.Image):
670
+ images = []
671
+
672
+ for image_ in image:
673
+ image_ = image_.convert("RGB")
674
+ #image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
675
+ image_ = np.array(image_)
676
+ image_ = image_[None, :]
677
+ images.append(image_)
678
+
679
+ image = images
680
+
681
+ image = np.concatenate(image, axis=0)
682
+ image = np.array(image).astype(np.float32) / 255.0
683
+ image = image.transpose(0, 3, 1, 2)
684
+ image = torch.from_numpy(image)#.flip(1)
685
+ elif isinstance(image[0], torch.Tensor):
686
+ image = torch.cat(image, dim=0)
687
+
688
+ image_batch_size = image.shape[0]
689
+
690
+ if image_batch_size == 1:
691
+ repeat_by = batch_size
692
+ else:
693
+ # image batch size is the same as prompt batch size
694
+ repeat_by = num_images_per_prompt
695
+
696
+ image = image.repeat_interleave(repeat_by, dim=0)
697
+
698
+ image = image.to(device=device, dtype=dtype)
699
+
700
+ if do_classifier_free_guidance and not guess_mode:
701
+ image = torch.cat([image] * 2)
702
+
703
+ return image
704
+
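# Self-contained sketch of the conditioning-image preprocessing above: PIL -> float32
# tensor in [0, 1], NCHW layout, broadcast to the batch size and doubled for CFG.
import numpy as np
import torch
from PIL import Image

pil = Image.new("RGB", (512, 512), color=(128, 64, 32))      # stand-in conditioning image
arr = np.array(pil).astype(np.float32) / 255.0                # HWC in [0, 1]
cond = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)    # 1x3x512x512

batch_size, do_classifier_free_guidance = 2, True
cond = cond.repeat_interleave(batch_size, dim=0)              # one image broadcast to the whole batch
if do_classifier_free_guidance:
    cond = torch.cat([cond] * 2)                              # copies for the uncond and cond branches
print(cond.shape)                                             # -> torch.Size([4, 3, 512, 512])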
705
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
706
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
707
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
708
+ if isinstance(generator, list) and len(generator) != batch_size:
709
+ raise ValueError(
710
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
711
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
712
+ )
713
+
714
+ if latents is None:
715
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
716
+ #latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
717
+ #offset_noise = torch.randn(batch_size, num_channels_latents, 1, 1, device=device)
718
+ #latents = latents + 0.1 * offset_noise
719
+ else:
720
+ latents = latents.to(device)
721
+
722
+ # scale the initial noise by the standard deviation required by the scheduler
723
+ latents = latents * self.scheduler.init_noise_sigma
724
+ return latents
725
+
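# Self-contained sketch of the latent preparation above (the scale factor and
# init_noise_sigma values are stand-ins): latents live at 1/8 of the pixel resolution
# and the initial Gaussian noise is scaled by the scheduler's starting sigma.
import torch

batch_size, num_channels_latents, height, width, vae_scale_factor = 1, 4, 512, 512, 8
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)

generator = torch.Generator().manual_seed(0)   # reproducible noise, mirroring the `generator` argument
latents = torch.randn(shape, generator=generator)
latents = latents * 1.0                        # scheduler.init_noise_sigma (1.0 for DDPM/DDIM-style schedulers)
print(latents.shape)                           # -> torch.Size([1, 4, 64, 64])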
726
+ def _default_height_width(self, height, width, image):
727
+ # NOTE: It is possible that a list of images have different
728
+ # dimensions for each image, so just checking the first image
729
+ # is not _exactly_ correct, but it is simple.
730
+ while isinstance(image, list):
731
+ image = image[0]
732
+
733
+ if height is None:
734
+ if isinstance(image, PIL.Image.Image):
735
+ height = image.height
736
+ elif isinstance(image, torch.Tensor):
737
+ height = image.shape[2]
738
+
739
+ height = (height // 8) * 8 # round down to nearest multiple of 8
740
+
741
+ if width is None:
742
+ if isinstance(image, PIL.Image.Image):
743
+ width = image.width
744
+ elif isinstance(image, torch.Tensor):
745
+ width = image.shape[3]
746
+
747
+ width = (width // 8) * 8 # round down to nearest multiple of 8
748
+
749
+ return height, width
750
+
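# Tiny sketch of the defaulting logic above: sizes inferred from the conditioning image
# are rounded down to the nearest multiple of 8 so they map cleanly onto the latent grid.
def round_down_to_multiple_of_8(x: int) -> int:
    return (x // 8) * 8

print(round_down_to_multiple_of_8(517))   # -> 512
print(round_down_to_multiple_of_8(512))   # -> 512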
751
+ # override DiffusionPipeline
752
+ def save_pretrained(
753
+ self,
754
+ save_directory: Union[str, os.PathLike],
755
+ safe_serialization: bool = False,
756
+ variant: Optional[str] = None,
757
+ ):
758
+ if isinstance(self.controlnet, ControlNetModel):
759
+ super().save_pretrained(save_directory, safe_serialization, variant)
760
+ else:
761
+ raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
762
+
763
+ def _gaussian_weights(self, tile_width, tile_height, nbatches):
764
+ """Generates a gaussian mask of weights for tile contributions"""
765
+ from numpy import pi, exp, sqrt
766
+ import numpy as np
767
+
768
+ latent_width = tile_width
769
+ latent_height = tile_height
770
+
771
+ var = 0.01
772
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
773
+ x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
774
+ midpoint = latent_height / 2
775
+ y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
776
+
777
+ weights = np.outer(y_probs, x_probs)
778
+ return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
779
+
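# Self-contained sketch of the Gaussian tile weighting above (here with both midpoints
# centred): each tile's prediction is weighted by a 2-D Gaussian that peaks at the tile
# centre, so values near tile borders contribute less before the overlap averaging.
import numpy as np
import torch

def gaussian_weights(tile_width, tile_height):
    var = 0.01
    x = np.arange(tile_width)
    y = np.arange(tile_height)
    x_probs = np.exp(-((x - (tile_width - 1) / 2) ** 2) / (tile_width ** 2) / (2 * var)) / np.sqrt(2 * np.pi * var)
    y_probs = np.exp(-((y - (tile_height - 1) / 2) ** 2) / (tile_height ** 2) / (2 * var)) / np.sqrt(2 * np.pi * var)
    return torch.tensor(np.outer(y_probs, x_probs), dtype=torch.float32)

w = gaussian_weights(8, 8)
print(bool(w[4, 4] > w[0, 0]))   # -> True: the centre outweighs the corners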
780
+ @perfcount
781
+ @torch.no_grad()
782
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
783
+ def __call__(
784
+ self,
785
+ prompt: Union[str, List[str]] = None,
786
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
787
+ height: Optional[int] = None,
788
+ width: Optional[int] = None,
789
+ num_inference_steps: int = 50,
790
+ guidance_scale: float = 7.5,
791
+ negative_prompt: Optional[Union[str, List[str]]] = None,
792
+ num_images_per_prompt: Optional[int] = 1,
793
+ eta: float = 0.0,
794
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
795
+ latents: Optional[torch.FloatTensor] = None,
796
+ prompt_embeds: Optional[torch.FloatTensor] = None,
797
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
798
+ output_type: Optional[str] = "pil",
799
+ return_dict: bool = True,
800
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
801
+ callback_steps: int = 1,
802
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
803
+ conditioning_scale: Union[float, List[float]] = 1.0,
804
+ guess_mode: bool = False,
805
+ image_sr = None,
806
+ start_steps = 999,
807
+ start_point = 'noise',
808
+ ram_encoder_hidden_states=None,
809
+ latent_tiled_size=320,
810
+ latent_tiled_overlap=4,
811
+ args=None
812
+ ):
813
+ r"""
814
+ Function invoked when calling the pipeline for generation.
815
+
816
+ Args:
817
+ prompt (`str` or `List[str]`, *optional*):
818
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
819
+ instead.
820
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
821
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
822
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
823
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
824
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
825
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
826
+ specified in init, images must be passed as a list such that each element of the list can be correctly
827
+ batched for input to a single controlnet.
828
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
829
+ The height in pixels of the generated image.
830
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
831
+ The width in pixels of the generated image.
832
+ num_inference_steps (`int`, *optional*, defaults to 50):
833
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
834
+ expense of slower inference.
835
+ guidance_scale (`float`, *optional*, defaults to 7.5):
836
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
837
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
838
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
839
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
840
+ usually at the expense of lower image quality.
841
+ negative_prompt (`str` or `List[str]`, *optional*):
842
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
843
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
844
+ less than `1`).
845
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
846
+ The number of images to generate per prompt.
847
+ eta (`float`, *optional*, defaults to 0.0):
848
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
849
+ [`schedulers.DDIMScheduler`], will be ignored for others.
850
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
851
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
852
+ to make generation deterministic.
853
+ latents (`torch.FloatTensor`, *optional*):
854
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
855
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
856
+ tensor will be generated by sampling using the supplied random `generator`.
857
+ prompt_embeds (`torch.FloatTensor`, *optional*):
858
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
859
+ provided, text embeddings will be generated from `prompt` input argument.
860
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
861
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
862
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
863
+ argument.
864
+ output_type (`str`, *optional*, defaults to `"pil"`):
865
+ The output format of the generated image. Choose between
866
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
867
+ return_dict (`bool`, *optional*, defaults to `True`):
868
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
869
+ plain tuple.
870
+ callback (`Callable`, *optional*):
871
+ A function that will be called every `callback_steps` steps during inference. The function will be
872
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
873
+ callback_steps (`int`, *optional*, defaults to 1):
874
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
875
+ called at every step.
876
+ cross_attention_kwargs (`dict`, *optional*):
877
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
878
+ `self.processor` in
879
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
880
+ conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
881
+ The outputs of the controlnet are multiplied by `conditioning_scale` before they are added
882
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
883
+ corresponding scale as a list.
884
+ guess_mode (`bool`, *optional*, defaults to `False`):
885
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
886
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
887
+
888
+ Examples:
889
+
890
+ Returns:
891
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
892
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
893
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
894
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
895
+ (nsfw) content, according to the `safety_checker`.
896
+ """
897
+ # 0. Default height and width to unet
898
+ height, width = self._default_height_width(height, width, image)
899
+
900
+ # 1. Check inputs. Raise error if not correct
901
+ """
902
+ self.check_inputs(
903
+ prompt,
904
+ image,
905
+ height,
906
+ width,
907
+ callback_steps,
908
+ negative_prompt,
909
+ prompt_embeds,
910
+ negative_prompt_embeds,
911
+ conditioning_scale,
912
+ )
913
+ """
914
+
915
+ # 2. Define call parameters
916
+ if prompt is not None and isinstance(prompt, str):
917
+ batch_size = 1
918
+ elif prompt is not None and isinstance(prompt, list):
919
+ batch_size = len(prompt)
920
+ else:
921
+ batch_size = prompt_embeds.shape[0]
922
+
923
+ device = self._execution_device
924
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
925
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
926
+ # corresponds to doing no classifier free guidance.
927
+ do_classifier_free_guidance = guidance_scale > 1.0
928
+
929
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
930
+ """
931
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(conditioning_scale, float):
932
+ conditioning_scale = [conditioning_scale] * len(controlnet.nets)
933
+
934
+ global_pool_conditions = (
935
+ controlnet.config.global_pool_conditions
936
+ if isinstance(controlnet, ControlNetModel)
937
+ else controlnet.nets[0].config.global_pool_conditions
938
+ )
939
+
940
+ guess_mode = guess_mode or global_pool_conditions
941
+ """
942
+
943
+ # 3. Encode input prompt
944
+ prompt_embeds, ram_encoder_hidden_states = self._encode_prompt(
945
+ prompt,
946
+ device,
947
+ num_images_per_prompt,
948
+ do_classifier_free_guidance,
949
+ negative_prompt,
950
+ prompt_embeds=prompt_embeds,
951
+ negative_prompt_embeds=negative_prompt_embeds,
952
+ ram_encoder_hidden_states=ram_encoder_hidden_states
953
+ )
954
+
955
+ # 4. Prepare image
956
+ image = self.prepare_image(
957
+ image=image,
958
+ width=width,
959
+ height=height,
960
+ batch_size=batch_size * num_images_per_prompt,
961
+ num_images_per_prompt=num_images_per_prompt,
962
+ device=device,
963
+ dtype=controlnet.dtype,
964
+ do_classifier_free_guidance=do_classifier_free_guidance,
965
+ guess_mode=guess_mode,
966
+ )
967
+
968
+ # 5. Prepare timesteps
969
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
970
+ timesteps = self.scheduler.timesteps
971
+
972
+ # 6. Prepare latent variables
973
+ num_channels_latents = self.unet.config.in_channels
974
+ latents = self.prepare_latents(
975
+ batch_size * num_images_per_prompt,
976
+ num_channels_latents,
977
+ height,
978
+ width,
979
+ prompt_embeds.dtype,
980
+ device,
981
+ generator,
982
+ latents,
983
+ )
984
+
985
+ # 6. Prepare the start point
986
+ if start_point == 'noise':
987
+ latents = latents
988
+ elif start_point == 'lr': # LRE Strategy
989
+ latents_condition_image = self.vae.encode(image*2-1).latent_dist.sample()
990
+ latents_condition_image = latents_condition_image * self.vae.config.scaling_factor
991
+ start_steps_tensor = torch.randint(start_steps, start_steps+1, (latents.shape[0],), device=latents.device)
992
+ start_steps_tensor = start_steps_tensor.long()
993
+ latents = self.scheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)
994
+
995
+
996
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
997
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
998
+
999
+ # 8. Denoising loop
1000
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1001
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1002
+
1003
+ _, _, h, w = latents.size()
1004
+ tile_size, tile_overlap = (latent_tiled_size, latent_tiled_overlap) if args is not None else (256, 8)
1005
+ if h*w<=tile_size*tile_size:
1006
+ print(f"[Tiled Latent]: the input size is tiny and unnecessary to tile.")
1007
+ else:
1008
+ print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
1009
+
1010
+ for i, t in enumerate(timesteps):
1011
+ # skip timesteps larger than start_steps (denoising begins at the chosen start point)
1012
+ if t > start_steps:
1013
+ print(f'skipping timestep {t}.')
1014
+ continue
1015
+
1016
+ # expand the latents if we are doing classifier free guidance
1017
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1018
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1019
+
1020
+ # controlnet(s) inference
1021
+ if guess_mode and do_classifier_free_guidance:
1022
+ # Infer ControlNet only for the conditional batch.
1023
+ controlnet_latent_model_input = latents
1024
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1025
+
1026
+ else:
1027
+ controlnet_latent_model_input = latent_model_input
1028
+ controlnet_prompt_embeds = prompt_embeds
1029
+
1030
+ if h*w<=tile_size*tile_size: # the latent fits in a single tile, no tiling needed
1031
+ down_block_res_samples, mid_block_res_sample = [None]*10, None
1032
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1033
+ controlnet_latent_model_input,
1034
+ t,
1035
+ encoder_hidden_states=controlnet_prompt_embeds,
1036
+ controlnet_cond=image,
1037
+ conditioning_scale=conditioning_scale,
1038
+ guess_mode=guess_mode,
1039
+ return_dict=False,
1040
+ image_encoder_hidden_states = ram_encoder_hidden_states,
1041
+ )
1042
+
1043
+
1044
+ if guess_mode and do_classifier_free_guidance:
1045
+ # Inferred ControlNet only for the conditional batch.
1046
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1047
+ # add 0 to the unconditional batch to keep it unchanged.
1048
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1049
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1050
+
1051
+ # predict the noise residual
1052
+ noise_pred = self.unet(
1053
+ latent_model_input,
1054
+ t,
1055
+ encoder_hidden_states=prompt_embeds,
1056
+ cross_attention_kwargs=cross_attention_kwargs,
1057
+ down_block_additional_residuals=down_block_res_samples,
1058
+ mid_block_additional_residual=mid_block_res_sample,
1059
+ return_dict=False,
1060
+ image_encoder_hidden_states = ram_encoder_hidden_states,
1061
+ )[0]
1062
+ else:
1063
+ tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
1064
+ tile_size = min(tile_size, min(h, w))
1065
+ tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
1066
+
1067
+ grid_rows = 0
1068
+ cur_x = 0
1069
+ while cur_x < latent_model_input.size(-1):
1070
+ cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
1071
+ grid_rows += 1
1072
+
1073
+ grid_cols = 0
1074
+ cur_y = 0
1075
+ while cur_y < latent_model_input.size(-2):
1076
+ cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
1077
+ grid_cols += 1
1078
+
1079
+ input_list = []
1080
+ cond_list = []
1081
+ img_list = []
1082
+ noise_preds = []
1083
+ for row in range(grid_rows):
1084
+ noise_preds_row = []
1085
+ for col in range(grid_cols):
1086
+ if col < grid_cols-1 or row < grid_rows-1:
1087
+ # extract tile from input image
1088
+ ofs_x = max(row * tile_size-tile_overlap * row, 0)
1089
+ ofs_y = max(col * tile_size-tile_overlap * col, 0)
1090
+ # input tile area on total image
1091
+ if row == grid_rows-1:
1092
+ ofs_x = w - tile_size
1093
+ if col == grid_cols-1:
1094
+ ofs_y = h - tile_size
1095
+
1096
+ input_start_x = ofs_x
1097
+ input_end_x = ofs_x + tile_size
1098
+ input_start_y = ofs_y
1099
+ input_end_y = ofs_y + tile_size
1100
+
1101
+ # input tile dimensions
1102
+ input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
1103
+ input_list.append(input_tile)
1104
+ cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
1105
+ cond_list.append(cond_tile)
1106
+ img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
1107
+ img_list.append(img_tile)
1108
+
1109
+ if len(input_list) == batch_size or col == grid_cols-1:
1110
+ input_list_t = torch.cat(input_list, dim=0)
1111
+ cond_list_t = torch.cat(cond_list, dim=0)
1112
+ img_list_t = torch.cat(img_list, dim=0)
1113
+ #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
1114
+
1115
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1116
+ cond_list_t,
1117
+ t,
1118
+ encoder_hidden_states=controlnet_prompt_embeds,
1119
+ controlnet_cond=img_list_t,
1120
+ conditioning_scale=conditioning_scale,
1121
+ guess_mode=guess_mode,
1122
+ return_dict=False,
1123
+ image_encoder_hidden_states = ram_encoder_hidden_states,
1124
+ )
1125
+
1126
+ if guess_mode and do_classifier_free_guidance:
1127
+ # Inferred ControlNet only for the conditional batch.
1128
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1129
+ # add 0 to the unconditional batch to keep it unchanged.
1130
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1131
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1132
+
1133
+ # predict the noise residual
1134
+ model_out = self.unet(
1135
+ input_list_t,
1136
+ t,
1137
+ encoder_hidden_states=prompt_embeds,
1138
+ cross_attention_kwargs=cross_attention_kwargs,
1139
+ down_block_additional_residuals=down_block_res_samples,
1140
+ mid_block_additional_residual=mid_block_res_sample,
1141
+ return_dict=False,
1142
+ image_encoder_hidden_states = ram_encoder_hidden_states,
1143
+ )[0]
1144
+
1145
+ #for sample_i in range(model_out.size(0)):
1146
+ # noise_preds_row.append(model_out[sample_i].unsqueeze(0))
1147
+ input_list = []
1148
+ cond_list = []
1149
+ img_list = []
1150
+
1151
+ noise_preds.append(model_out)
1152
+
1153
+ # Stitch noise predictions for all tiles
1154
+ noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
1155
+ contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
1156
+ # Add each tile contribution to overall latents
1157
+ for row in range(grid_rows):
1158
+ for col in range(grid_cols):
1159
+ if col < grid_cols-1 or row < grid_rows-1:
1160
+ # extract tile from input image
1161
+ ofs_x = max(row * tile_size-tile_overlap * row, 0)
1162
+ ofs_y = max(col * tile_size-tile_overlap * col, 0)
1163
+ # input tile area on total image
1164
+ if row == grid_rows-1:
1165
+ ofs_x = w - tile_size
1166
+ if col == grid_cols-1:
1167
+ ofs_y = h - tile_size
1168
+
1169
+ input_start_x = ofs_x
1170
+ input_end_x = ofs_x + tile_size
1171
+ input_start_y = ofs_y
1172
+ input_end_y = ofs_y + tile_size
1173
+
1174
+ noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
1175
+ contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
1176
+ # Average overlapping areas with more than 1 contributor
1177
+ noise_pred /= contributors
1178
+
1179
+
1180
+ # perform guidance
1181
+ if do_classifier_free_guidance:
1182
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1183
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1184
+
1185
+
1186
+
1187
+ # compute the previous noisy sample x_t -> x_t-1
1188
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1189
+
1190
+ # call the callback, if provided
1191
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1192
+ progress_bar.update()
1193
+ if callback is not None and i % callback_steps == 0:
1194
+ callback(i, t, latents)
1195
+
1196
+ # If we do sequential model offloading, let's offload unet and controlnet
1197
+ # manually for max memory savings
1198
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1199
+ self.unet.to("cpu")
1200
+ self.controlnet.to("cpu")
1201
+ torch.cuda.empty_cache()
1202
+
1203
+ has_nsfw_concept = None
1204
+ if not output_type == "latent":
1205
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
1206
+ #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1207
+ else:
1208
+ image = latents
1209
+ has_nsfw_concept = None
1210
+
1211
+ if has_nsfw_concept is None:
1212
+ do_denormalize = [True] * image.shape[0]
1213
+ else:
1214
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1215
+
1216
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1217
+
1218
+ # Offload last model to CPU
1219
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1220
+ self.final_offload_hook.offload()
1221
+
1222
+ if not return_dict:
1223
+ return (image, has_nsfw_concept)
1224
+
1225
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
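# Self-contained sketch of the tiled-latent blending performed in `__call__` above,
# simplified (clean offset clamping, identity "model" as a stand-in for the UNet):
# overlapping tiles are weighted by a Gaussian, accumulated, and divided by the
# per-pixel contributor weights; with an identity model the stitched result reproduces
# the input, which checks the bookkeeping.
import torch

def gaussian_tile_weights(tile):
    coords = torch.arange(tile, dtype=torch.float32)
    g = torch.exp(-((coords - (tile - 1) / 2) ** 2) / (tile ** 2) / (2 * 0.01))
    return torch.outer(g, g)

h = w = 96
tile_size, tile_overlap = 64, 8
latent = torch.randn(1, 4, h, w)

def num_tiles(extent):
    return -(-(extent - tile_overlap) // (tile_size - tile_overlap))   # ceil division

out = torch.zeros_like(latent)
contributors = torch.zeros_like(latent)
weights = gaussian_tile_weights(tile_size)

for row in range(num_tiles(h)):
    for col in range(num_tiles(w)):
        y = min(row * (tile_size - tile_overlap), h - tile_size)
        x = min(col * (tile_size - tile_overlap), w - tile_size)
        tile_pred = latent[:, :, y:y + tile_size, x:x + tile_size]   # stand-in per-tile noise prediction
        out[:, :, y:y + tile_size, x:x + tile_size] += tile_pred * weights
        contributors[:, :, y:y + tile_size, x:x + tile_size] += weights
out /= contributors
print(torch.allclose(out, latent, atol=1e-5))   # -> True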