Eyalgut committed
Commit be084ae · verified · 1 Parent(s): 9549cef

Upload 3 files

Files changed (3)
  1. bria_utils.py +71 -0
  2. pipeline_bria.py +562 -0
  3. transformer_bria.py +335 -0
bria_utils.py ADDED
@@ -0,0 +1,71 @@
from typing import Union, Optional, List
import torch
from diffusers.utils import logging
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding="max_length",  # intentionally disabled: short prompts are zero-padded on the embeddings below
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


# In order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    sigmas = timesteps / num_train_timesteps

    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    new_sigmas = sigmas[inds]
    return new_sigmas


def is_ng_none(negative_prompt):
    return (
        negative_prompt is None
        or negative_prompt == ""
        or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
        or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
    )
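A quick, hedged usage sketch for these helpers (illustrative only, not part of the commit). The `google/t5-v1_1-xxl` checkpoint is the variant referenced in the pipeline docstring below and is only an assumption here; any matching T5 encoder/tokenizer pair works the same way.

# Hedged usage sketch for bria_utils.py (illustrative only).
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")      # assumed checkpoint
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")    # assumed checkpoint

with torch.no_grad():
    embeds = get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt="A cat holding a sign that says hello world",
        num_images_per_prompt=1,
        max_sequence_length=128,
    )
print(embeds.shape)  # (1, 128, hidden_dim): short prompts are zero-padded up to max_sequence_length

sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)
print(sigmas[0], sigmas[-1])  # descending training sigmas in (0, 1], resampled to 30 inference steps

print(is_ng_none(""), is_ng_none("blurry"))  # True, False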
pipeline_bria.py ADDED
@@ -0,0 +1,562 @@
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import VaeImageProcessor
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from transformer_bria import BriaTransformer2DModel
from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusion3Pipeline

        >>> pipe = StableDiffusion3Pipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> image = pipe(prompt).images[0]
        >>> image.save("sd3.png")
        ```
"""

T5_PRECISION = torch.float16

"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- we use zero padding for prompts
- no guidance embedding since this is not a distilled version
"""
class BriaPipeline(FluxPipeline):
    r"""
    Args:
        transformer ([`BriaTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder, specifically the
            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant of
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel).
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
    """

    # model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
    # _optional_components = []
    # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
        self,
        transformer: BriaTransformer2DModel,
        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
    ):
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

        # TODO - why different than official flux (-1)
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 64  # due to patchify => 128,128 => res of 1k,1k

        # T5 is sensitive to precision so we use the precision used for precompute and cast as needed
        self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
        for block in self.text_encoder.encoder.block:
            block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 128,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier-free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = get_t5_prompt_embeds(
                self.tokenizer,
                self.text_encoder,
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            ).to(dtype=self.transformer.dtype)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            if not is_ng_none(negative_prompt):
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds = get_t5_prompt_embeds(
                    self.tokenizer,
                    self.text_encoder,
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                ).to(dtype=self.transformer.dtype)
            else:
                negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

        return prompt_embeds, negative_prompt_embeds, text_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 30,
        timesteps: List[int] = None,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
        clip_value: Optional[float] = None,
        normalize: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 30):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to 128): Maximum sequence length to use with the `prompt`.
            clip_value (`float`, *optional*):
                If set, the noise prediction is clamped element-wise to `[-clip_value, clip_value]`.
            normalize (`bool`, *optional*, defaults to `False`):
                If `True`, rescales the guided noise prediction toward the standard deviation of the text-conditioned
                branch (a CFG-rescale style correction).

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`:
                [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
                tuple, the first element is a list with the generated images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            prompt_embeds=prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )

        # 3. Encode prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        # Sample from training sigmas
        if isinstance(self.scheduler, (DDIMScheduler, EulerAncestralDiscreteScheduler)):
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
        else:
            sigmas = get_original_sigmas(
                num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
            )
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
            )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Support different diffusers versions
        if len(latent_image_ids.shape) == 2:
            text_ids = text_ids.squeeze()

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # This predicts "v" for flow-matching or eps for diffusion
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    cfg_noise_pred_text = noise_pred_text.std()
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if normalize:
                        # CFG-rescale style blend of the std-rescaled prediction and the raw CFG prediction
                        noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred

                if clip_value:
                    assert clip_value > 0
                    noise_pred = noise_pred.clip(-clip_value, clip_value)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # if latents.std().item() > 2:
                #     print('Warning')

                # print(t.item(), latents.mean().item(), latents.std().item())

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    def to(self, *args, **kwargs):
        DiffusionPipeline.to(self, *args, **kwargs)
        # T5 is sensitive to precision so we use the precision used for precompute and cast as needed
        self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
        for block in self.text_encoder.encoder.block:
            block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)

        return self
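For context, a hedged end-to-end sketch of how this pipeline is wired together (illustrative only, not part of the commit). `<bria-checkpoint>` is a placeholder, not a real model id, and the subfolder layout is an assumption; it presumes a repository shipping matching `transformer`, `vae`, `scheduler`, `text_encoder`, and `tokenizer` weights.

# Hedged usage sketch for BriaPipeline (illustrative only; "<bria-checkpoint>" is a placeholder).
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformer_bria import BriaTransformer2DModel
from pipeline_bria import BriaPipeline

ckpt = "<bria-checkpoint>"  # placeholder path/id
pipe = BriaPipeline(
    transformer=BriaTransformer2DModel.from_pretrained(ckpt, subfolder="transformer", torch_dtype=torch.bfloat16),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt, subfolder="scheduler"),
    vae=AutoencoderKL.from_pretrained(ckpt, subfolder="vae", torch_dtype=torch.bfloat16),
    text_encoder=T5EncoderModel.from_pretrained(ckpt, subfolder="text_encoder"),
    tokenizer=T5TokenizerFast.from_pretrained(ckpt, subfolder="tokenizer"),
).to("cuda")  # BriaPipeline.to() re-casts the T5 encoder to float16 as in __init__

image = pipe(
    prompt="A cat holding a sign that says hello world",
    num_inference_steps=30,  # the pipeline resamples the training sigmas to this many steps
    guidance_scale=5.0,      # CFG is active because guidance_scale > 1
    height=1024,
    width=1024,
).images[0]
image.save("bria.png")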
transformer_bria.py ADDED
@@ -0,0 +1,335 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock

# Support different diffusers versions
try:
    from diffusers.models.embeddings import FluxPosEmbed as EmbedND
except ImportError:
    from diffusers.models.transformers.transformer_flux import rope

    class EmbedND(nn.Module):
        def __init__(self, theta: int, axes_dim: List[int]):
            super().__init__()
            self.theta = theta
            self.axes_dim = axes_dim

        def forward(self, ids: torch.Tensor) -> torch.Tensor:
            n_axes = ids.shape[-1]
            emb = torch.cat(
                [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
                dim=-3,
            )
            return emb.unsqueeze(1)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Timesteps(nn.Module):
    def __init__(
        self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, max_period=10000
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            max_period=self.max_period,
        )
        return t_emb


class TimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, max_period):
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, max_period=max_period
        )
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype))  # (N, D)
        return timesteps_emb


"""
Based on FluxTransformer2DModel with several changes:
- no pooled embeddings
- we use zero padding for prompts
- no guidance embedding since this is not a distilled version
"""
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = None,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        rope_theta=10000,
        max_period=10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

        self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, max_period=max_period)

        # if pooled_projection_dim:
        #     self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")

        if guidance_embeds:
            # max_period is a required argument of TimestepProjEmbeddings
            self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, max_period=max_period)

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`BriaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, in_channels)`):
                Input `hidden_states` (packed latents).
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)
        else:
            guidance = None

        # temb = (
        #     self.time_text_embed(timestep, pooled_projections)
        #     if guidance is None
        #     else self.time_text_embed(timestep, guidance, pooled_projections)
        # )

        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        # if pooled_projections:
        #     temb += self.pooled_text_embed(pooled_projections)

        # guidance is a tensor, so check for None explicitly instead of relying on its truth value
        if guidance is not None:
            temb = temb + self.guidance_embed(guidance, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if len(txt_ids.shape) == 2:
            ids = torch.cat((txt_ids, img_ids), dim=0)
        else:
            ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
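A hedged shape-checking sketch for the transformer (illustrative only, not part of the commit). The tiny config values are chosen purely to make the forward pass cheap, and it assumes a diffusers version whose `FluxPosEmbed` accepts 2-D id tensors (the `txt_ids`/`img_ids` handling above supports both layouts).

# Hedged smoke test for BriaTransformer2DModel with a deliberately tiny config (illustrative only).
import torch
from transformer_bria import BriaTransformer2DModel

model = BriaTransformer2DModel(
    patch_size=1,
    in_channels=16,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=32,
    num_attention_heads=2,       # inner_dim = 2 * 32 = 64
    joint_attention_dim=64,
    axes_dims_rope=[8, 12, 12],  # must sum to attention_head_dim
)

batch, img_seq, txt_seq = 1, 64, 8
hidden_states = torch.randn(batch, img_seq, 16)          # packed latents: (B, seq, in_channels)
encoder_hidden_states = torch.randn(batch, txt_seq, 64)  # text features: (B, seq, joint_attention_dim)
timestep = torch.tensor([500.0])
img_ids = torch.zeros(img_seq, 3)                        # 2-D ids (recent diffusers layout)
txt_ids = torch.zeros(txt_seq, 3)

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        return_dict=False,
    )[0]
print(out.shape)  # (1, 64, 16): same sequence length, patch_size**2 * in_channels output channels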