Scaryplasmon96 committed on
Commit
203c467
·
verified ·
1 Parent(s): c7e8bdd
Files changed (1)
  1. DoodlePix_pipeline.py +977 -0
DoodlePix_pipeline.py ADDED
@@ -0,0 +1,977 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
9
+ import os
10
+
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
14
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
19
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
20
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
21
+
22
+
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+
26
+ XLA_AVAILABLE = True
27
+ else:
28
+ XLA_AVAILABLE = False
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
34
+ def preprocess(image):
35
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
36
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
37
+ if isinstance(image, torch.Tensor):
38
+ return image
39
+ elif isinstance(image, PIL.Image.Image):
40
+ image = [image]
41
+
42
+ if isinstance(image[0], PIL.Image.Image):
43
+ w, h = image[0].size
44
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
45
+
46
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
47
+ image = np.concatenate(image, axis=0)
48
+ image = np.array(image).astype(np.float32) / 255.0
49
+ image = image.transpose(0, 3, 1, 2)
50
+ image = 2.0 * image - 1.0
51
+ image = torch.from_numpy(image)
52
+ elif isinstance(image[0], torch.Tensor):
53
+ image = torch.cat(image, dim=0)
54
+ return image
55
+
56
+
57
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
58
+ def retrieve_latents(
59
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
60
+ ):
61
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
62
+ return encoder_output.latent_dist.sample(generator)
63
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
64
+ return encoder_output.latent_dist.mode()
65
+ elif hasattr(encoder_output, "latents"):
66
+ return encoder_output.latents
67
+ else:
68
+ raise AttributeError("Could not access latents of provided encoder_output")
69
+
70
+
71
+ class StableDiffusionInstructPix2PixPipeline(
72
+ DiffusionPipeline,
73
+ StableDiffusionMixin,
74
+ TextualInversionLoaderMixin,
75
+ StableDiffusionLoraLoaderMixin,
76
+ IPAdapterMixin,
77
+ FromSingleFileMixin,
78
+ ):
79
+ r"""
80
+ Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
81
+
82
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
83
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
84
+
85
+ The pipeline also inherits the following loading methods:
86
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
87
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
88
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
89
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
90
+
91
+ Args:
92
+ vae ([`AutoencoderKL`]):
93
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
94
+ text_encoder ([`~transformers.CLIPTextModel`]):
95
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
96
+ tokenizer ([`~transformers.CLIPTokenizer`]):
97
+ A `CLIPTokenizer` to tokenize text.
98
+ unet ([`UNet2DConditionModel`]):
99
+ A `UNet2DConditionModel` to denoise the encoded image latents.
100
+ scheduler ([`SchedulerMixin`]):
101
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
102
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
103
+ safety_checker ([`StableDiffusionSafetyChecker`]):
104
+ Classification module that estimates whether generated images could be considered offensive or harmful.
105
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
106
+ more details about a model's potential harms.
107
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
108
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
109
+ """
110
+
111
+ model_cpu_offload_seq = "text_encoder->unet->vae"
112
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
113
+ _exclude_from_cpu_offload = ["safety_checker"]
114
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]
115
+
116
+ def __init__(
117
+ self,
118
+ vae: AutoencoderKL,
119
+ text_encoder: CLIPTextModel,
120
+ tokenizer: CLIPTokenizer,
121
+ unet: UNet2DConditionModel,
122
+ scheduler: KarrasDiffusionSchedulers,
123
+ safety_checker: StableDiffusionSafetyChecker,
124
+ feature_extractor: CLIPImageProcessor,
125
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
126
+ requires_safety_checker: bool = False,
127
+ teacher_text_encoder: Optional[CLIPTextModel] = None,
128
+ teacher_loss_weight: float = 0.1,
129
+ fidelity_mlp=None,
130
+ ):
131
+ super().__init__()
132
+
133
+ if safety_checker is None and requires_safety_checker:
134
+ logger.warning(
135
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
136
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
137
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
138
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
139
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
140
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
141
+ )
142
+
143
+ if safety_checker is not None and feature_extractor is None:
144
+ raise ValueError(
145
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
146
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
147
+ )
148
+
149
+ self.register_modules(
150
+ vae=vae,
151
+ text_encoder=text_encoder,
152
+ tokenizer=tokenizer,
153
+ unet=unet,
154
+ scheduler=scheduler,
155
+ safety_checker=safety_checker,
156
+ feature_extractor=feature_extractor,
157
+ image_encoder=image_encoder,
158
+ )
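+ # DoodlePix-specific additions: an optional frozen teacher text encoder used for an MSE
+ # distillation loss on the prompt embeddings, and a FidelityMLP that injects a scalar
+ # fidelity control (parsed from prompt tags such as "f7") into the text embeddings.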
159
+ self.teacher_text_encoder = teacher_text_encoder
160
+ self.teacher_loss_weight = teacher_loss_weight
161
+ self.fidelity_mlp = fidelity_mlp
162
+ if self.teacher_text_encoder is not None:
163
+ self.teacher_text_encoder.eval()
164
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
165
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
166
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
167
+
168
+ @torch.no_grad()
169
+ def __call__(
170
+ self,
171
+ prompt: Union[str, List[str]] = None,
172
+ image: PipelineImageInput = None,
173
+ num_inference_steps: int = 100,
174
+ guidance_scale: float = 7.5,
175
+ image_guidance_scale: float = 1.5,
176
+ negative_prompt: Optional[Union[str, List[str]]] = None,
177
+ num_images_per_prompt: Optional[int] = 1,
178
+ eta: float = 0.0,
179
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
180
+ latents: Optional[torch.Tensor] = None,
181
+ prompt_embeds: Optional[torch.Tensor] = None,
182
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
183
+ ip_adapter_image: Optional[PipelineImageInput] = None,
184
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
185
+ output_type: Optional[str] = "pil",
186
+ return_dict: bool = True,
187
+ callback_on_step_end: Optional[
188
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
189
+ ] = None,
190
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
191
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
192
+ **kwargs,
193
+ ):
194
+ r"""
195
+ The call function to the pipeline for generation.
196
+
197
+ Args:
198
+ prompt (`str` or `List[str]`, *optional*):
199
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
200
+ image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
201
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
202
+ image latents as `image`, but if passing latents directly it is not encoded again.
203
+ num_inference_steps (`int`, *optional*, defaults to 100):
204
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
205
+ expense of slower inference.
206
+ guidance_scale (`float`, *optional*, defaults to 7.5):
207
+ A higher guidance scale value encourages the model to generate images closely linked to the text
208
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
209
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
210
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
211
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
212
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
213
+ value of at least `1`.
214
+ negative_prompt (`str` or `List[str]`, *optional*):
215
+ The prompt or prompts to guide what not to include in image generation. If not defined, you need to
216
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
217
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
218
+ The number of images to generate per prompt.
219
+ eta (`float`, *optional*, defaults to 0.0):
220
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
221
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
222
+ generator (`torch.Generator`, *optional*):
223
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
224
+ generation deterministic.
225
+ latents (`torch.Tensor`, *optional*):
226
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
227
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
228
+ tensor is generated by sampling using the supplied random `generator`.
229
+ prompt_embeds (`torch.Tensor`, *optional*):
230
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
231
+ provided, text embeddings are generated from the `prompt` input argument.
232
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
233
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
234
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
235
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
236
+ Optional image input to work with IP Adapters.
237
+ output_type (`str`, *optional*, defaults to `"pil"`):
238
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
239
+ return_dict (`bool`, *optional*, defaults to `True`):
240
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
241
+ plain tuple.
242
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
243
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
244
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
245
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
246
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
247
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
248
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
249
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
250
+ `._callback_tensor_inputs` attribute of your pipeline class.
251
+ cross_attention_kwargs (`dict`, *optional*):
252
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
253
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
254
+
255
+ Examples:
256
+
257
+ ```py
258
+ >>> import PIL
259
+ >>> import requests
260
+ >>> import torch
261
+ >>> from io import BytesIO
262
+
263
+ >>> from diffusers import StableDiffusionInstructPix2PixPipeline
264
+
265
+
266
+ >>> def download_image(url):
267
+ ... response = requests.get(url)
268
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
269
+
270
+
271
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
272
+
273
+ >>> image = download_image(img_url).resize((512, 512))
274
+
275
+ >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
276
+ ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
277
+ ... )
278
+ >>> pipe = pipe.to("cuda")
279
+
280
+ >>> prompt = "make the mountains snowy"
281
+ >>> image = pipe(prompt=prompt, image=image).images[0]
282
+ ```
283
+
284
+ Returns:
285
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
286
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
287
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
288
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
289
+ "not-safe-for-work" (nsfw) content.
290
+ """
291
+
292
+ callback = kwargs.pop("callback", None)
293
+ callback_steps = kwargs.pop("callback_steps", None)
294
+
295
+ if callback is not None:
296
+ deprecate(
297
+ "callback",
298
+ "1.0.0",
299
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
300
+ )
301
+ if callback_steps is not None:
302
+ deprecate(
303
+ "callback_steps",
304
+ "1.0.0",
305
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
306
+ )
307
+
308
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
309
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
310
+
311
+ # 0. Check inputs
312
+ self.check_inputs(
313
+ prompt,
314
+ callback_steps,
315
+ negative_prompt,
316
+ prompt_embeds,
317
+ negative_prompt_embeds,
318
+ ip_adapter_image,
319
+ ip_adapter_image_embeds,
320
+ callback_on_step_end_tensor_inputs,
321
+ )
322
+ self._guidance_scale = guidance_scale
323
+ self._image_guidance_scale = image_guidance_scale
324
+
325
+ device = self._execution_device
326
+
327
+ if image is None:
328
+ raise ValueError("`image` input cannot be undefined.")
329
+
330
+ # 1. Define call parameters
331
+ if prompt is not None and isinstance(prompt, str):
332
+ batch_size = 1
333
+ elif prompt is not None and isinstance(prompt, list):
334
+ batch_size = len(prompt)
335
+ else:
336
+ batch_size = prompt_embeds.shape[0]
337
+
338
+ device = self._execution_device
339
+
340
+ # 2. Encode input prompt
341
+ prompt_embeds = self._encode_prompt(
342
+ prompt,
343
+ device,
344
+ num_images_per_prompt,
345
+ self.do_classifier_free_guidance,
346
+ negative_prompt,
347
+ prompt_embeds=prompt_embeds,
348
+ negative_prompt_embeds=negative_prompt_embeds,
349
+ return_teacher_loss=False,
350
+ )
351
+
352
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
353
+ image_embeds = self.prepare_ip_adapter_image_embeds(
354
+ ip_adapter_image,
355
+ ip_adapter_image_embeds,
356
+ device,
357
+ batch_size * num_images_per_prompt,
358
+ self.do_classifier_free_guidance,
359
+ )
360
+ # 3. Preprocess image
361
+ image = self.image_processor.preprocess(image)
362
+
363
+ # 4. set timesteps
364
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
365
+ timesteps = self.scheduler.timesteps
366
+
367
+ # 5. Prepare Image latents
368
+ image_latents = self.prepare_image_latents(
369
+ image,
370
+ batch_size,
371
+ num_images_per_prompt,
372
+ prompt_embeds.dtype,
373
+ device,
374
+ self.do_classifier_free_guidance,
375
+ )
376
+
377
+ height, width = image_latents.shape[-2:]
378
+ height = height * self.vae_scale_factor
379
+ width = width * self.vae_scale_factor
380
+
381
+ # 6. Prepare latent variables
382
+ num_channels_latents = self.vae.config.latent_channels
383
+ latents = self.prepare_latents(
384
+ batch_size * num_images_per_prompt,
385
+ num_channels_latents,
386
+ height,
387
+ width,
388
+ prompt_embeds.dtype,
389
+ device,
390
+ generator,
391
+ latents,
392
+ )
393
+
394
+ # 7. Check that shapes of latents and image match the UNet channels
395
+ num_channels_image = image_latents.shape[1]
396
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
397
+ raise ValueError(
398
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
399
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
400
+ f" `num_channels_image`: {num_channels_image} "
401
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
402
+ " `pipeline.unet` or your `image` input."
403
+ )
404
+
405
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
406
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
407
+
408
+ # 8.1 Add image embeds for IP-Adapter
409
+ added_cond_kwargs = {"image_embeds": image_embeds} if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) else None
410
+
411
+ # 9. Denoising loop
412
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
413
+ self._num_timesteps = len(timesteps)
414
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
415
+ for i, t in enumerate(timesteps):
416
+ # Expand the latents if we are doing classifier free guidance.
417
+ # The latents are expanded 3 times because for pix2pix the guidance
418
+ # is applied for both the text and the input image.
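+ # The three copies line up as (cond. text + cond. image), (uncond. text + cond. image) and
+ # (uncond. text + uncond. image), which is why noise_pred is chunked into text / image / uncond below.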
419
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
420
+
421
+ # concat latents, image_latents in the channel dimension
422
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
423
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
424
+
425
+ # predict the noise residual
426
+ noise_pred = self.unet(
427
+ scaled_latent_model_input,
428
+ t,
429
+ encoder_hidden_states=prompt_embeds,
430
+ added_cond_kwargs=added_cond_kwargs,
431
+ cross_attention_kwargs=cross_attention_kwargs,
432
+ return_dict=False,
433
+ )[0]
434
+
435
+ # perform guidance
436
+ if self.do_classifier_free_guidance:
437
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
438
+ noise_pred = (
439
+ noise_pred_uncond
440
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
441
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
442
+ )
443
+
444
+ # compute the previous noisy sample x_t -> x_t-1
445
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
446
+
447
+ if callback_on_step_end is not None:
448
+ callback_kwargs = {}
449
+ for k in callback_on_step_end_tensor_inputs:
450
+ callback_kwargs[k] = locals()[k]
451
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
452
+
453
+ latents = callback_outputs.pop("latents", latents)
454
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
455
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
456
+ image_latents = callback_outputs.pop("image_latents", image_latents)
457
+
458
+ # call the callback, if provided
459
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
460
+ progress_bar.update()
461
+ if callback is not None and i % callback_steps == 0:
462
+ step_idx = i // getattr(self.scheduler, "order", 1)
463
+ callback(step_idx, t, latents)
464
+
465
+ if XLA_AVAILABLE:
466
+ xm.mark_step()
467
+
468
+ if not output_type == "latent":
469
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
470
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
471
+ else:
472
+ image = latents
473
+ has_nsfw_concept = None
474
+
475
+ if has_nsfw_concept is None:
476
+ do_denormalize = [True] * image.shape[0]
477
+ else:
478
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
479
+
480
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
481
+
482
+ # Offload all models
483
+ self.maybe_free_model_hooks()
484
+
485
+ if not return_dict:
486
+ return (image, has_nsfw_concept)
487
+
488
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
489
+
490
+ def _encode_prompt(
491
+ self,
492
+ prompt,
493
+ device,
494
+ num_images_per_prompt,
495
+ do_classifier_free_guidance,
496
+ negative_prompt=None,
497
+ prompt_embeds: Optional[torch.Tensor] = None,
498
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
499
+ return_teacher_loss: bool = False,
500
+ ):
501
+ r"""
502
+ Encodes the prompt into text encoder hidden states.
503
+
504
+ Args:
505
+ prompt (`str` or `List[str]`, *optional*):
506
+ prompt to be encoded
507
+ device: (`torch.device`):
508
+ torch device
509
+ num_images_per_prompt (`int`):
510
+ number of images that should be generated per prompt
511
+ do_classifier_free_guidance (`bool`):
512
+ whether to use classifier free guidance or not
513
+ negative_prompt (`str` or `List[str]`, *optional*):
514
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
515
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
516
+ less than `1`).
517
+ prompt_embeds (`torch.Tensor`, *optional*):
518
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
519
+ provided, text embeddings will be generated from `prompt` input argument.
520
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
521
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
522
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
523
+ argument.
524
+ """
525
+ if prompt is not None and isinstance(prompt, str):
526
+ batch_size = 1
527
+ elif prompt is not None and isinstance(prompt, list):
528
+ batch_size = len(prompt)
529
+ else:
530
+ batch_size = prompt_embeds.shape[0]
531
+
532
+ if prompt_embeds is None:
533
+ # textual inversion: process multi-vector tokens if necessary
534
+ if isinstance(self, TextualInversionLoaderMixin):
535
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
536
+
537
+ text_inputs = self.tokenizer(
538
+ prompt,
539
+ padding="max_length",
540
+ max_length=self.tokenizer.model_max_length,
541
+ truncation=True,
542
+ return_tensors="pt",
543
+ )
544
+ text_input_ids = text_inputs.input_ids
545
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
546
+
547
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
548
+ text_input_ids, untruncated_ids
549
+ ):
550
+ removed_text = self.tokenizer.batch_decode(
551
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
552
+ )
553
+ logger.warning(
554
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
555
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
556
+ )
557
+
558
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
559
+ attention_mask = text_inputs.attention_mask.to(device)
560
+ else:
561
+ attention_mask = None
562
+
563
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
564
+ prompt_embeds = prompt_embeds[0]
565
+
566
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
567
+
568
+ # Extract fidelity value from the prompt
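+ # A fidelity tag such as "f7" or "f=7" is searched for in the prompt; the digit is clamped to 0-9
+ # and mapped to [0, 1] by dividing by 9. If no tag is found, the default of 0.5 is used.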
569
+ fidelity_val = 0.5 # default fidelity
570
+ if self.fidelity_mlp is not None:
571
+ if prompt is not None:
572
+ if isinstance(prompt, str):
573
+ import re
574
+ match = re.search(r"f\s*=?\s*(\d+)|f(\d+)", prompt, re.IGNORECASE)
575
+ if match:
576
+ f_int = int(match.group(1) if match.group(1) else match.group(2))
577
+ f_int = max(0, min(f_int, 9))
578
+ fidelity_val = f_int / 9.0
579
+ elif isinstance(prompt, list):
580
+ import re
581
+ f_vals = []
582
+ for p in prompt:
583
+ match = re.search(r"f\s*=?\s*(\d+)|f(\d+)", p, re.IGNORECASE)
584
+ if match:
585
+ f_int = int(match.group(1) if match.group(1) else match.group(2))
586
+ f_int = max(0, min(f_int, 9))
587
+ f_vals.append(f_int / 9.0)
588
+ if f_vals:
589
+ fidelity_val = sum(f_vals) / len(f_vals)
590
+
591
+ # Create fidelity tensor
592
+ batch_size = prompt_embeds.shape[0]
593
+ fidelity_tensor = torch.full((batch_size, 1), fidelity_val, device=device, dtype=prompt_embeds.dtype)
594
+
595
+ # Get fidelity embedding
596
+ fidelity_embedding = self.fidelity_mlp(fidelity_tensor) # (batch, hidden_size)
597
+
598
+ # This keeps the sequence length the same (77 tokens)
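+ # The fidelity embedding is added onto the first (BOS) token embedding, scaled by 0.8
+ # (the negative prompt embedding below uses a smaller 0.2 scale).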
599
+ prompt_embeds[:, 0] = prompt_embeds[:, 0] + (0.8 * fidelity_embedding)
600
+
601
+ bs_embed, seq_len, _ = prompt_embeds.shape
602
+ # duplicate text embeddings for each generation per prompt
603
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
604
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
605
+
606
+ teacher_loss = None
607
+ if return_teacher_loss and self.teacher_text_encoder is not None:
608
+ # Compute teacher embeddings using the frozen teacher text encoder
609
+ with torch.no_grad():
610
+ teacher_outputs = self.teacher_text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
611
+ teacher_embeds = teacher_outputs[0]
612
+ teacher_embeds = teacher_embeds.to(dtype=self.text_encoder.dtype, device=device)
613
+ teacher_embeds = teacher_embeds.repeat(1, num_images_per_prompt, 1)
614
+ teacher_embeds = teacher_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
615
+ # Compute an MSE loss between teacher and student text embeddings
616
+ teacher_loss = F.mse_loss(prompt_embeds, teacher_embeds) * self.teacher_loss_weight
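+ # __call__ passes return_teacher_loss=False, so this distillation term is only used when
+ # _encode_prompt is invoked explicitly (e.g. during training).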
617
+
618
+ if do_classifier_free_guidance:
619
+ # get unconditional embeddings for classifier free guidance
620
+ if negative_prompt_embeds is None:
621
+ uncond_tokens: List[str]
622
+ if negative_prompt is None:
623
+ uncond_tokens = [""] * batch_size
624
+ elif isinstance(negative_prompt, str):
625
+ uncond_tokens = [negative_prompt]
626
+ elif batch_size != len(negative_prompt):
627
+ raise ValueError(
628
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
629
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
630
+ " the batch size of `prompt`."
631
+ )
632
+ else:
633
+ uncond_tokens = negative_prompt
634
+
635
+ if isinstance(self, TextualInversionLoaderMixin):
636
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
637
+
638
+ max_length = prompt_embeds.shape[1]
639
+ uncond_input = self.tokenizer(
640
+ uncond_tokens,
641
+ padding="max_length",
642
+ max_length=max_length,
643
+ truncation=True,
644
+ return_tensors="pt",
645
+ )
646
+
647
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
648
+ attention_mask = uncond_input.attention_mask.to(device)
649
+ else:
650
+ attention_mask = None
651
+
652
+ negative_prompt_embeds = self.text_encoder(
653
+ uncond_input.input_ids.to(device),
654
+ attention_mask=attention_mask,
655
+ )[0]
656
+
657
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
658
+
659
+ # Apply the same fidelity modification to negative prompt embeddings
660
+ if self.fidelity_mlp is not None:
661
+ bs_neg = negative_prompt_embeds.shape[0]
662
+ fidelity_tensor = torch.full((bs_neg, 1), fidelity_val, device=device, dtype=negative_prompt_embeds.dtype)
663
+ fidelity_embedding = self.fidelity_mlp(fidelity_tensor)
664
+
665
+ # Modify the first token of the negative embedding instead of concatenating
666
+ negative_prompt_embeds[:, 0] = negative_prompt_embeds[:, 0] + 0.2 * fidelity_embedding
667
+
668
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
669
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
670
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
671
+
672
+ # For classifier-free guidance, concatenate [conditional, unconditional, unconditional] text embeddings into one batch
673
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
674
+
675
+ if return_teacher_loss:
676
+ return prompt_embeds, teacher_loss
677
+ return prompt_embeds
678
+
679
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
680
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
681
+ dtype = next(self.image_encoder.parameters()).dtype
682
+
683
+ if not isinstance(image, torch.Tensor):
684
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
685
+
686
+ image = image.to(device=device, dtype=dtype)
687
+ if output_hidden_states:
688
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
689
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
690
+ uncond_image_enc_hidden_states = self.image_encoder(
691
+ torch.zeros_like(image), output_hidden_states=True
692
+ ).hidden_states[-2]
693
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
694
+ num_images_per_prompt, dim=0
695
+ )
696
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
697
+ else:
698
+ image_embeds = self.image_encoder(image).image_embeds
699
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
700
+ uncond_image_embeds = torch.zeros_like(image_embeds)
701
+
702
+ return image_embeds, uncond_image_embeds
703
+
704
+ def prepare_ip_adapter_image_embeds(
705
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
706
+ ):
707
+ if ip_adapter_image_embeds is None:
708
+ if not isinstance(ip_adapter_image, list):
709
+ ip_adapter_image = [ip_adapter_image]
710
+
711
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
712
+ raise ValueError(
713
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
714
+ )
715
+
716
+ image_embeds = []
717
+ for single_ip_adapter_image, image_proj_layer in zip(
718
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
719
+ ):
720
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
721
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
722
+ single_ip_adapter_image, device, 1, output_hidden_state
723
+ )
724
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
725
+ single_negative_image_embeds = torch.stack(
726
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
727
+ )
728
+
729
+ if do_classifier_free_guidance:
730
+ single_image_embeds = torch.cat(
731
+ [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
732
+ )
733
+ single_image_embeds = single_image_embeds.to(device)
734
+
735
+ image_embeds.append(single_image_embeds)
736
+ else:
737
+ repeat_dims = [1]
738
+ image_embeds = []
739
+ for single_image_embeds in ip_adapter_image_embeds:
740
+ if do_classifier_free_guidance:
741
+ (
742
+ single_image_embeds,
743
+ single_negative_image_embeds,
744
+ single_negative_image_embeds,
745
+ ) = single_image_embeds.chunk(3)
746
+ single_image_embeds = single_image_embeds.repeat(
747
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
748
+ )
749
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
750
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
751
+ )
752
+ single_image_embeds = torch.cat(
753
+ [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
754
+ )
755
+ else:
756
+ single_image_embeds = single_image_embeds.repeat(
757
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
758
+ )
759
+ image_embeds.append(single_image_embeds)
760
+
761
+ return image_embeds
762
+
763
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
764
+ def run_safety_checker(self, image, device, dtype):
765
+ if self.safety_checker is None:
766
+ has_nsfw_concept = None
767
+ else:
768
+ if torch.is_tensor(image):
769
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
770
+ else:
771
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
772
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
773
+ image, has_nsfw_concept = self.safety_checker(
774
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
775
+ )
776
+ return image, has_nsfw_concept
777
+
778
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
779
+ def prepare_extra_step_kwargs(self, generator, eta):
780
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
781
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
782
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
783
+ # and should be between [0, 1]
784
+
785
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
786
+ extra_step_kwargs = {}
787
+ if accepts_eta:
788
+ extra_step_kwargs["eta"] = eta
789
+
790
+ # check if the scheduler accepts generator
791
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
792
+ if accepts_generator:
793
+ extra_step_kwargs["generator"] = generator
794
+ return extra_step_kwargs
795
+
796
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
797
+ def decode_latents(self, latents):
798
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
799
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
800
+
801
+ latents = 1 / self.vae.config.scaling_factor * latents
802
+ image = self.vae.decode(latents, return_dict=False)[0]
803
+ image = (image / 2 + 0.5).clamp(0, 1)
804
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
805
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
806
+ return image
807
+
808
+ def check_inputs(
809
+ self,
810
+ prompt,
811
+ callback_steps,
812
+ negative_prompt=None,
813
+ prompt_embeds=None,
814
+ negative_prompt_embeds=None,
815
+ ip_adapter_image=None,
816
+ ip_adapter_image_embeds=None,
817
+ callback_on_step_end_tensor_inputs=None,
818
+ ):
819
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
820
+ raise ValueError(
821
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
822
+ f" {type(callback_steps)}."
823
+ )
824
+
825
+ if callback_on_step_end_tensor_inputs is not None and not all(
826
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
827
+ ):
828
+ raise ValueError(
829
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
830
+ )
831
+
832
+ if prompt is not None and prompt_embeds is not None:
833
+ raise ValueError(
834
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
835
+ " only forward one of the two."
836
+ )
837
+ elif prompt is None and prompt_embeds is None:
838
+ raise ValueError(
839
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
840
+ )
841
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
842
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
843
+
844
+ if negative_prompt is not None and negative_prompt_embeds is not None:
845
+ raise ValueError(
846
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
847
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
848
+ )
849
+
850
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
851
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
852
+ raise ValueError(
853
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
854
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
855
+ f" {negative_prompt_embeds.shape}."
856
+ )
857
+
858
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
859
+ raise ValueError(
860
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
861
+ )
862
+
863
+ if ip_adapter_image_embeds is not None:
864
+ if not isinstance(ip_adapter_image_embeds, list):
865
+ raise ValueError(
866
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
867
+ )
868
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
869
+ raise ValueError(
870
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
871
+ )
872
+
873
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
874
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
875
+ shape = (
876
+ batch_size,
877
+ num_channels_latents,
878
+ int(height) // self.vae_scale_factor,
879
+ int(width) // self.vae_scale_factor,
880
+ )
881
+ if isinstance(generator, list) and len(generator) != batch_size:
882
+ raise ValueError(
883
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
884
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
885
+ )
886
+
887
+ if latents is None:
888
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
889
+ else:
890
+ latents = latents.to(device)
891
+
892
+ # scale the initial noise by the standard deviation required by the scheduler
893
+ latents = latents * self.scheduler.init_noise_sigma
894
+ return latents
895
+
896
+ def prepare_image_latents(
897
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
898
+ ):
899
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
900
+ raise ValueError(
901
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
902
+ )
903
+
904
+ image = image.to(device=device, dtype=dtype)
905
+
906
+ batch_size = batch_size * num_images_per_prompt
907
+
908
+ if image.shape[1] == 4:
909
+ image_latents = image
910
+ else:
911
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
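+ # sample_mode="argmax" takes the mode of the VAE latent distribution, so the image encoding is deterministic.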
912
+
913
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
914
+ # expand image_latents for batch_size
915
+ deprecation_message = (
916
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
917
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
918
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
919
+ " your script to pass as many initial images as text prompts to suppress this warning."
920
+ )
921
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
922
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
923
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
924
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
925
+ raise ValueError(
926
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
927
+ )
928
+ else:
929
+ image_latents = torch.cat([image_latents], dim=0)
930
+
931
+ if do_classifier_free_guidance:
932
+ uncond_image_latents = torch.zeros_like(image_latents)
933
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
934
+
935
+ return image_latents
936
+
937
+ @property
938
+ def guidance_scale(self):
939
+ return self._guidance_scale
940
+
941
+ @property
942
+ def image_guidance_scale(self):
943
+ return self._image_guidance_scale
944
+
945
+ @property
946
+ def num_timesteps(self):
947
+ return self._num_timesteps
948
+
949
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
950
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
951
+ # corresponds to doing no classifier free guidance.
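+ # For this pipeline, classifier-free guidance is only enabled when both the text and the
+ # image guidance conditions hold, since the denoising loop triples the batch for both.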
952
+ @property
953
+ def do_classifier_free_guidance(self):
954
+ return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
955
+
956
+ @classmethod
957
+ def from_pretrained_with_fidelity(cls, pretrained_model_path, fidelity_mlp_path=None, **kwargs):
958
+ """Load the pipeline with a fidelity MLP model."""
959
+ pipeline = cls.from_pretrained(pretrained_model_path, **kwargs)
960
+
961
+ if fidelity_mlp_path is not None:
962
+ from fidelity_mlp import FidelityMLP
963
+ # Load the FidelityMLP
964
+ hidden_size = pipeline.text_encoder.config.hidden_size
965
+ if os.path.exists(fidelity_mlp_path):
966
+ pipeline.fidelity_mlp = FidelityMLP.from_pretrained(fidelity_mlp_path)
967
+ else:
968
+ # Create a new one if path doesn't exist
969
+ pipeline.fidelity_mlp = FidelityMLP(hidden_size)
970
+ logger.warning(
971
+ f"Fidelity MLP not found at {fidelity_mlp_path}, initialized a new one with hidden size {hidden_size}"
972
+ )
973
+
974
+ # Move to the same device as the pipeline
975
+ pipeline.fidelity_mlp.to(pipeline.device)
976
+
977
+ return pipeline
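
A minimal usage sketch, assuming `DoodlePix_pipeline.py` and the repository's `fidelity_mlp` module are importable from the working directory; the checkpoint name, the FidelityMLP path and the input drawing are placeholders:

```py
from PIL import Image

from DoodlePix_pipeline import StableDiffusionInstructPix2PixPipeline

# Load the pipeline together with the fidelity MLP (placeholder path; a fresh MLP is
# initialized with the text encoder's hidden size if the path does not exist).
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained_with_fidelity(
    "timbrooks/instruct-pix2pix",
    fidelity_mlp_path="path/to/fidelity_mlp",
    safety_checker=None,
)
pipe = pipe.to("cuda")
# fidelity_mlp is not registered via register_modules, so move it explicitly.
pipe.fidelity_mlp.to(pipe.device)

doodle = Image.open("doodle.png").convert("RGB").resize((512, 512))  # placeholder input drawing

result = pipe(
    prompt="f7, a red sports car, clean shading",  # "f7" is parsed into a fidelity of 7/9
    image=doodle,
    num_inference_steps=25,
    guidance_scale=7.5,
    image_guidance_scale=1.5,
).images[0]
result.save("car.png")
```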