lcybuaa commited on
Commit
400e019
·
verified ·
1 Parent(s): 96226bb

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. model_index.json +1 -1
  2. pipeline_text2earth_diffusion.py +1103 -0
model_index.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_class_name": "StableDiffusionPipeline",
3
  "_diffusers_version": "0.8.0",
4
  "feature_extractor": [
5
  "transformers",
 
1
  {
2
+ "_class_name": "Text2EarthDiffusionPipeline",
3
  "_diffusers_version": "0.8.0",
4
  "feature_extractor": [
5
  "transformers",
pipeline_text2earth_diffusion.py ADDED
@@ -0,0 +1,1103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+
17
+ import torch
18
+ from packaging import version
19
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
20
+
21
+
22
+
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.configuration_utils import FrozenDict
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
28
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
29
+ from diffusers.schedulers import KarrasDiffusionSchedulers
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ deprecate,
33
+ logging,
34
+ replace_example_docstring,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+ from diffusers.utils.torch_utils import randn_tensor
39
+ from diffusers import DiffusionPipeline, StableDiffusionMixin
40
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
41
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```py
49
+ >>> import torch
50
+ >>> from diffusers import Text2EarthDiffusionPipeline
51
+
52
+ >>> pipe = Text2EarthDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
53
+ >>> pipe = pipe.to("cuda")
54
+
55
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
56
+ >>> image = pipe(prompt).images[0]
57
+ ```
58
+ """
59
+
60
+
61
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
62
+ """
63
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
64
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
65
+ """
66
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
67
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
68
+ # rescale the results from guidance (fixes overexposure)
69
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
70
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
71
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
72
+ return noise_cfg
73
+
74
+
75
+ def retrieve_timesteps(
76
+ scheduler,
77
+ num_inference_steps: Optional[int] = None,
78
+ device: Optional[Union[str, torch.device]] = None,
79
+ timesteps: Optional[List[int]] = None,
80
+ sigmas: Optional[List[float]] = None,
81
+ **kwargs,
82
+ ):
83
+ """
84
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
85
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
86
+
87
+ Args:
88
+ scheduler (`SchedulerMixin`):
89
+ The scheduler to get timesteps from.
90
+ num_inference_steps (`int`):
91
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
92
+ must be `None`.
93
+ device (`str` or `torch.device`, *optional*):
94
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
95
+ timesteps (`List[int]`, *optional*):
96
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
97
+ `num_inference_steps` and `sigmas` must be `None`.
98
+ sigmas (`List[float]`, *optional*):
99
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
100
+ `num_inference_steps` and `timesteps` must be `None`.
101
+
102
+ Returns:
103
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
104
+ second element is the number of inference steps.
105
+ """
106
+ if timesteps is not None and sigmas is not None:
107
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
108
+ if timesteps is not None:
109
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
110
+ if not accepts_timesteps:
111
+ raise ValueError(
112
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
113
+ f" timestep schedules. Please check whether you are using the correct scheduler."
114
+ )
115
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
116
+ timesteps = scheduler.timesteps
117
+ num_inference_steps = len(timesteps)
118
+ elif sigmas is not None:
119
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
120
+ if not accept_sigmas:
121
+ raise ValueError(
122
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
123
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
124
+ )
125
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
126
+ timesteps = scheduler.timesteps
127
+ num_inference_steps = len(timesteps)
128
+ else:
129
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
130
+ timesteps = scheduler.timesteps
131
+ return timesteps, num_inference_steps
132
+
133
+
134
+ class Text2EarthDiffusionPipeline(
135
+ DiffusionPipeline,
136
+ StableDiffusionMixin,
137
+ TextualInversionLoaderMixin,
138
+ LoraLoaderMixin,
139
+ IPAdapterMixin,
140
+ FromSingleFileMixin,
141
+ ):
142
+ r"""
143
+ Pipeline for text-to-image generation using Stable Diffusion.
144
+
145
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
146
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
147
+
148
+ The pipeline also inherits the following loading methods:
149
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
150
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
151
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
152
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
153
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
154
+
155
+ Args:
156
+ vae ([`AutoencoderKL`]):
157
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
158
+ text_encoder ([`~transformers.CLIPTextModel`]):
159
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
160
+ tokenizer ([`~transformers.CLIPTokenizer`]):
161
+ A `CLIPTokenizer` to tokenize text.
162
+ unet ([`UNet2DConditionModel`]):
163
+ A `UNet2DConditionModel` to denoise the encoded image latents.
164
+ scheduler ([`SchedulerMixin`]):
165
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
166
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
167
+ safety_checker ([`StableDiffusionSafetyChecker`]):
168
+ Classification module that estimates whether generated images could be considered offensive or harmful.
169
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
170
+ about a model's potential harms.
171
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
172
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
173
+ """
174
+
175
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
176
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
177
+ _exclude_from_cpu_offload = ["safety_checker"]
178
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
179
+
180
+ def __init__(
181
+ self,
182
+ vae: AutoencoderKL,
183
+ text_encoder: CLIPTextModel,
184
+ tokenizer: CLIPTokenizer,
185
+ unet: UNet2DConditionModel,
186
+ scheduler: KarrasDiffusionSchedulers,
187
+ safety_checker: StableDiffusionSafetyChecker,
188
+ feature_extractor: CLIPImageProcessor,
189
+ image_encoder: CLIPVisionModelWithProjection = None,
190
+ requires_safety_checker: bool = True,
191
+ ):
192
+ super().__init__()
193
+
194
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
195
+ deprecation_message = (
196
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
199
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
+ " file"
202
+ )
203
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
204
+ new_config = dict(scheduler.config)
205
+ new_config["steps_offset"] = 1
206
+ scheduler._internal_dict = FrozenDict(new_config)
207
+
208
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
209
+ deprecation_message = (
210
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
211
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
212
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
213
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
214
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
215
+ )
216
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
217
+ new_config = dict(scheduler.config)
218
+ new_config["clip_sample"] = False
219
+ scheduler._internal_dict = FrozenDict(new_config)
220
+
221
+ if safety_checker is None and requires_safety_checker:
222
+ logger.warning(
223
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
224
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
225
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
226
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
227
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
228
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
229
+ )
230
+
231
+ if safety_checker is not None and feature_extractor is None:
232
+ raise ValueError(
233
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
234
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
235
+ )
236
+
237
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
238
+ version.parse(unet.config._diffusers_version).base_version
239
+ ) < version.parse("0.9.0.dev0")
240
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
241
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
242
+ deprecation_message = (
243
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
244
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
245
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
246
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
247
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
248
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
249
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
250
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
251
+ " the `unet/config.json` file"
252
+ )
253
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
254
+ new_config = dict(unet.config)
255
+ new_config["sample_size"] = 64
256
+ unet._internal_dict = FrozenDict(new_config)
257
+
258
+ self.register_modules(
259
+ vae=vae,
260
+ text_encoder=text_encoder,
261
+ tokenizer=tokenizer,
262
+ unet=unet,
263
+ scheduler=scheduler,
264
+ safety_checker=safety_checker,
265
+ feature_extractor=feature_extractor,
266
+ image_encoder=image_encoder,
267
+ )
268
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
269
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
270
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
271
+
272
+ def _encode_prompt(
273
+ self,
274
+ prompt,
275
+ device,
276
+ num_images_per_prompt,
277
+ do_classifier_free_guidance,
278
+ negative_prompt=None,
279
+ prompt_embeds: Optional[torch.Tensor] = None,
280
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
281
+ lora_scale: Optional[float] = None,
282
+ **kwargs,
283
+ ):
284
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
285
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
286
+
287
+ prompt_embeds_tuple = self.encode_prompt(
288
+ prompt=prompt,
289
+ device=device,
290
+ num_images_per_prompt=num_images_per_prompt,
291
+ do_classifier_free_guidance=do_classifier_free_guidance,
292
+ negative_prompt=negative_prompt,
293
+ prompt_embeds=prompt_embeds,
294
+ negative_prompt_embeds=negative_prompt_embeds,
295
+ lora_scale=lora_scale,
296
+ **kwargs,
297
+ )
298
+
299
+ # concatenate for backwards comp
300
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
301
+
302
+ return prompt_embeds
303
+
304
+ def encode_prompt(
305
+ self,
306
+ prompt,
307
+ device,
308
+ num_images_per_prompt,
309
+ do_classifier_free_guidance,
310
+ negative_prompt=None,
311
+ prompt_embeds: Optional[torch.Tensor] = None,
312
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
313
+ lora_scale: Optional[float] = None,
314
+ clip_skip: Optional[int] = None,
315
+ ):
316
+ r"""
317
+ Encodes the prompt into text encoder hidden states.
318
+
319
+ Args:
320
+ prompt (`str` or `List[str]`, *optional*):
321
+ prompt to be encoded
322
+ device: (`torch.device`):
323
+ torch device
324
+ num_images_per_prompt (`int`):
325
+ number of images that should be generated per prompt
326
+ do_classifier_free_guidance (`bool`):
327
+ whether to use classifier free guidance or not
328
+ negative_prompt (`str` or `List[str]`, *optional*):
329
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
330
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
331
+ less than `1`).
332
+ prompt_embeds (`torch.Tensor`, *optional*):
333
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
334
+ provided, text embeddings will be generated from `prompt` input argument.
335
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
336
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
337
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
338
+ argument.
339
+ lora_scale (`float`, *optional*):
340
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
341
+ clip_skip (`int`, *optional*):
342
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
343
+ the output of the pre-final layer will be used for computing the prompt embeddings.
344
+ """
345
+ # set lora scale so that monkey patched LoRA
346
+ # function of text encoder can correctly access it
347
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
348
+ self._lora_scale = lora_scale
349
+
350
+ # dynamically adjust the LoRA scale
351
+ if not USE_PEFT_BACKEND:
352
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
353
+ else:
354
+ scale_lora_layers(self.text_encoder, lora_scale)
355
+
356
+ if prompt is not None and isinstance(prompt, str):
357
+ batch_size = 1
358
+ elif prompt is not None and isinstance(prompt, list):
359
+ batch_size = len(prompt)
360
+ else:
361
+ batch_size = prompt_embeds.shape[0]
362
+
363
+ if prompt_embeds is None:
364
+ # textual inversion: process multi-vector tokens if necessary
365
+ if isinstance(self, TextualInversionLoaderMixin):
366
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
367
+
368
+ text_inputs = self.tokenizer(
369
+ prompt,
370
+ padding="max_length",
371
+ max_length=self.tokenizer.model_max_length,
372
+ truncation=True,
373
+ return_tensors="pt",
374
+ )
375
+ text_input_ids = text_inputs.input_ids
376
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
377
+
378
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
379
+ text_input_ids, untruncated_ids
380
+ ):
381
+ removed_text = self.tokenizer.batch_decode(
382
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
383
+ )
384
+ logger.warning(
385
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
386
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
387
+ )
388
+
389
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
390
+ attention_mask = text_inputs.attention_mask.to(device)
391
+ else:
392
+ attention_mask = None
393
+
394
+ if clip_skip is None:
395
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
396
+ prompt_embeds = prompt_embeds[0]
397
+ else:
398
+ prompt_embeds = self.text_encoder(
399
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
400
+ )
401
+ # Access the `hidden_states` first, that contains a tuple of
402
+ # all the hidden states from the encoder layers. Then index into
403
+ # the tuple to access the hidden states from the desired layer.
404
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
405
+ # We also need to apply the final LayerNorm here to not mess with the
406
+ # representations. The `last_hidden_states` that we typically use for
407
+ # obtaining the final prompt representations passes through the LayerNorm
408
+ # layer.
409
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
410
+
411
+ if self.text_encoder is not None:
412
+ prompt_embeds_dtype = self.text_encoder.dtype
413
+ elif self.unet is not None:
414
+ prompt_embeds_dtype = self.unet.dtype
415
+ else:
416
+ prompt_embeds_dtype = prompt_embeds.dtype
417
+
418
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
419
+
420
+ bs_embed, seq_len, _ = prompt_embeds.shape
421
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
422
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
423
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
424
+
425
+ # get unconditional embeddings for classifier free guidance
426
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
427
+ uncond_tokens: List[str]
428
+ if negative_prompt is None:
429
+ uncond_tokens = [""] * batch_size
430
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
431
+ raise TypeError(
432
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
433
+ f" {type(prompt)}."
434
+ )
435
+ elif isinstance(negative_prompt, str):
436
+ uncond_tokens = [negative_prompt]
437
+ elif batch_size != len(negative_prompt):
438
+ raise ValueError(
439
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
440
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
441
+ " the batch size of `prompt`."
442
+ )
443
+ else:
444
+ uncond_tokens = negative_prompt
445
+
446
+ # textual inversion: process multi-vector tokens if necessary
447
+ if isinstance(self, TextualInversionLoaderMixin):
448
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
449
+
450
+ max_length = prompt_embeds.shape[1]
451
+ uncond_input = self.tokenizer(
452
+ uncond_tokens,
453
+ padding="max_length",
454
+ max_length=max_length,
455
+ truncation=True,
456
+ return_tensors="pt",
457
+ )
458
+
459
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
460
+ attention_mask = uncond_input.attention_mask.to(device)
461
+ else:
462
+ attention_mask = None
463
+
464
+ negative_prompt_embeds = self.text_encoder(
465
+ uncond_input.input_ids.to(device),
466
+ attention_mask=attention_mask,
467
+ )
468
+ negative_prompt_embeds = negative_prompt_embeds[0]
469
+
470
+ if do_classifier_free_guidance:
471
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
472
+ seq_len = negative_prompt_embeds.shape[1]
473
+
474
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
475
+
476
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
477
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
478
+
479
+ if self.text_encoder is not None:
480
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
481
+ # Retrieve the original scale by scaling back the LoRA layers
482
+ unscale_lora_layers(self.text_encoder, lora_scale)
483
+
484
+ return prompt_embeds, negative_prompt_embeds
485
+
486
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
487
+ dtype = next(self.image_encoder.parameters()).dtype
488
+
489
+ if not isinstance(image, torch.Tensor):
490
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
491
+
492
+ image = image.to(device=device, dtype=dtype)
493
+ if output_hidden_states:
494
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
495
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
496
+ uncond_image_enc_hidden_states = self.image_encoder(
497
+ torch.zeros_like(image), output_hidden_states=True
498
+ ).hidden_states[-2]
499
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
500
+ num_images_per_prompt, dim=0
501
+ )
502
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
503
+ else:
504
+ image_embeds = self.image_encoder(image).image_embeds
505
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
506
+ uncond_image_embeds = torch.zeros_like(image_embeds)
507
+
508
+ return image_embeds, uncond_image_embeds
509
+
510
+ def prepare_ip_adapter_image_embeds(
511
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
512
+ ):
513
+ if ip_adapter_image_embeds is None:
514
+ if not isinstance(ip_adapter_image, list):
515
+ ip_adapter_image = [ip_adapter_image]
516
+
517
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
518
+ raise ValueError(
519
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
520
+ )
521
+
522
+ image_embeds = []
523
+ for single_ip_adapter_image, image_proj_layer in zip(
524
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
525
+ ):
526
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
527
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
528
+ single_ip_adapter_image, device, 1, output_hidden_state
529
+ )
530
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
531
+ single_negative_image_embeds = torch.stack(
532
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
533
+ )
534
+
535
+ if do_classifier_free_guidance:
536
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
537
+ single_image_embeds = single_image_embeds.to(device)
538
+
539
+ image_embeds.append(single_image_embeds)
540
+ else:
541
+ repeat_dims = [1]
542
+ image_embeds = []
543
+ for single_image_embeds in ip_adapter_image_embeds:
544
+ if do_classifier_free_guidance:
545
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
546
+ single_image_embeds = single_image_embeds.repeat(
547
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
548
+ )
549
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
550
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
551
+ )
552
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
553
+ else:
554
+ single_image_embeds = single_image_embeds.repeat(
555
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
556
+ )
557
+ image_embeds.append(single_image_embeds)
558
+
559
+ return image_embeds
560
+
561
+ def run_safety_checker(self, image, device, dtype):
562
+ if self.safety_checker is None:
563
+ has_nsfw_concept = None
564
+ else:
565
+ if torch.is_tensor(image):
566
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
567
+ else:
568
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
569
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
570
+ image, has_nsfw_concept = self.safety_checker(
571
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
572
+ )
573
+ return image, has_nsfw_concept
574
+
575
+ def decode_latents(self, latents):
576
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
577
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
578
+
579
+ latents = 1 / self.vae.config.scaling_factor * latents
580
+ image = self.vae.decode(latents, return_dict=False)[0]
581
+ image = (image / 2 + 0.5).clamp(0, 1)
582
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
583
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
584
+ return image
585
+
586
+ def prepare_extra_step_kwargs(self, generator, eta):
587
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
588
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
589
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
590
+ # and should be between [0, 1]
591
+
592
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
593
+ extra_step_kwargs = {}
594
+ if accepts_eta:
595
+ extra_step_kwargs["eta"] = eta
596
+
597
+ # check if the scheduler accepts generator
598
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
599
+ if accepts_generator:
600
+ extra_step_kwargs["generator"] = generator
601
+ return extra_step_kwargs
602
+
603
+ def check_inputs(
604
+ self,
605
+ prompt,
606
+ height,
607
+ width,
608
+ callback_steps,
609
+ negative_prompt=None,
610
+ prompt_embeds=None,
611
+ negative_prompt_embeds=None,
612
+ ip_adapter_image=None,
613
+ ip_adapter_image_embeds=None,
614
+ callback_on_step_end_tensor_inputs=None,
615
+ ):
616
+ if height % 8 != 0 or width % 8 != 0:
617
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
618
+
619
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
620
+ raise ValueError(
621
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
622
+ f" {type(callback_steps)}."
623
+ )
624
+ if callback_on_step_end_tensor_inputs is not None and not all(
625
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
626
+ ):
627
+ raise ValueError(
628
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
629
+ )
630
+
631
+ if prompt is not None and prompt_embeds is not None:
632
+ raise ValueError(
633
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
634
+ " only forward one of the two."
635
+ )
636
+ elif prompt is None and prompt_embeds is None:
637
+ raise ValueError(
638
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
639
+ )
640
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
641
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
642
+
643
+ if negative_prompt is not None and negative_prompt_embeds is not None:
644
+ raise ValueError(
645
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
646
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
647
+ )
648
+
649
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
650
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
651
+ raise ValueError(
652
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
653
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
654
+ f" {negative_prompt_embeds.shape}."
655
+ )
656
+
657
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
658
+ raise ValueError(
659
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
660
+ )
661
+
662
+ if ip_adapter_image_embeds is not None:
663
+ if not isinstance(ip_adapter_image_embeds, list):
664
+ raise ValueError(
665
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
666
+ )
667
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
668
+ raise ValueError(
669
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
670
+ )
671
+
672
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
673
+ shape = (
674
+ batch_size,
675
+ num_channels_latents,
676
+ int(height) // self.vae_scale_factor,
677
+ int(width) // self.vae_scale_factor,
678
+ )
679
+ if isinstance(generator, list) and len(generator) != batch_size:
680
+ raise ValueError(
681
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
682
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
683
+ )
684
+
685
+ if latents is None:
686
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
687
+ else:
688
+ latents = latents.to(device)
689
+
690
+ # scale the initial noise by the standard deviation required by the scheduler
691
+ latents = latents * self.scheduler.init_noise_sigma
692
+ return latents
693
+
694
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
695
+ def get_guidance_scale_embedding(
696
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
697
+ ) -> torch.Tensor:
698
+ """
699
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
700
+
701
+ Args:
702
+ w (`torch.Tensor`):
703
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
704
+ embedding_dim (`int`, *optional*, defaults to 512):
705
+ Dimension of the embeddings to generate.
706
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
707
+ Data type of the generated embeddings.
708
+
709
+ Returns:
710
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
711
+ """
712
+ assert len(w.shape) == 1
713
+ w = w * 1000.0
714
+
715
+ half_dim = embedding_dim // 2
716
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
717
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
718
+ emb = w.to(dtype)[:, None] * emb[None, :]
719
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
720
+ if embedding_dim % 2 == 1: # zero pad
721
+ emb = torch.nn.functional.pad(emb, (0, 1))
722
+ assert emb.shape == (w.shape[0], embedding_dim)
723
+ return emb
724
+
725
+ @property
726
+ def guidance_scale(self):
727
+ return self._guidance_scale
728
+
729
+ @property
730
+ def guidance_rescale(self):
731
+ return self._guidance_rescale
732
+
733
+ @property
734
+ def clip_skip(self):
735
+ return self._clip_skip
736
+
737
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
738
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
739
+ # corresponds to doing no classifier free guidance.
740
+ @property
741
+ def do_classifier_free_guidance(self):
742
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
743
+
744
+ @property
745
+ def cross_attention_kwargs(self):
746
+ return self._cross_attention_kwargs
747
+
748
+ @property
749
+ def num_timesteps(self):
750
+ return self._num_timesteps
751
+
752
+ @property
753
+ def interrupt(self):
754
+ return self._interrupt
755
+
756
+ @torch.no_grad()
757
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
758
+ def __call__(
759
+ self,
760
+ prompt: Union[str, List[str]] = None,
761
+ height: Optional[int] = None,
762
+ width: Optional[int] = None,
763
+ num_inference_steps: int = 50,
764
+ timesteps: List[int] = None,
765
+ sigmas: List[float] = None,
766
+ guidance_scale: float = 7.5,
767
+ negative_prompt: Optional[Union[str, List[str]]] = None,
768
+ num_images_per_prompt: Optional[int] = 1,
769
+ eta: float = 0.0,
770
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
771
+ latents: Optional[torch.Tensor] = None,
772
+ prompt_embeds: Optional[torch.Tensor] = None,
773
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
774
+ ip_adapter_image: Optional[PipelineImageInput] = None,
775
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
776
+ output_type: Optional[str] = "pil",
777
+ return_dict: bool = True,
778
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779
+ guidance_rescale: float = 0.0,
780
+ clip_skip: Optional[int] = None,
781
+ callback_on_step_end: Optional[
782
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
783
+ ] = None,
784
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
785
+ **kwargs,
786
+ ):
787
+ r"""
788
+ The call function to the pipeline for generation.
789
+
790
+ Args:
791
+ prompt (`str` or `List[str]`, *optional*):
792
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
793
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
794
+ The height in pixels of the generated image.
795
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
796
+ The width in pixels of the generated image.
797
+ num_inference_steps (`int`, *optional*, defaults to 50):
798
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
799
+ expense of slower inference.
800
+ timesteps (`List[int]`, *optional*):
801
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
802
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
803
+ passed will be used. Must be in descending order.
804
+ sigmas (`List[float]`, *optional*):
805
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
806
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
807
+ will be used.
808
+ guidance_scale (`float`, *optional*, defaults to 7.5):
809
+ A higher guidance scale value encourages the model to generate images closely linked to the text
810
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
811
+ negative_prompt (`str` or `List[str]`, *optional*):
812
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
813
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
814
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
815
+ The number of images to generate per prompt.
816
+ eta (`float`, *optional*, defaults to 0.0):
817
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
818
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
819
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
820
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
821
+ generation deterministic.
822
+ latents (`torch.Tensor`, *optional*):
823
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
824
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
825
+ tensor is generated by sampling using the supplied random `generator`.
826
+ prompt_embeds (`torch.Tensor`, *optional*):
827
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
828
+ provided, text embeddings are generated from the `prompt` input argument.
829
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
830
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
831
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
832
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
833
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
834
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
835
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
836
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
837
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
838
+ output_type (`str`, *optional*, defaults to `"pil"`):
839
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
840
+ return_dict (`bool`, *optional*, defaults to `True`):
841
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
842
+ plain tuple.
843
+ cross_attention_kwargs (`dict`, *optional*):
844
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
845
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
846
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
847
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
848
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
849
+ using zero terminal SNR.
850
+ clip_skip (`int`, *optional*):
851
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
852
+ the output of the pre-final layer will be used for computing the prompt embeddings.
853
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
854
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
855
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
856
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
857
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
858
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
859
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
860
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
861
+ `._callback_tensor_inputs` attribute of your pipeline class.
862
+
863
+ Examples:
864
+
865
+ Returns:
866
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
867
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
868
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
869
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
870
+ "not-safe-for-work" (nsfw) content.
871
+ """
872
+
873
+ callback = kwargs.pop("callback", None)
874
+ callback_steps = kwargs.pop("callback_steps", None)
875
+
876
+ if callback is not None:
877
+ deprecate(
878
+ "callback",
879
+ "1.0.0",
880
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
881
+ )
882
+ if callback_steps is not None:
883
+ deprecate(
884
+ "callback_steps",
885
+ "1.0.0",
886
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
887
+ )
888
+
889
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
890
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
891
+
892
+ # 0. Default height and width to unet
893
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
894
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
895
+ # to deal with lora scaling and other possible forward hooks
896
+
897
+ # 1. Check inputs. Raise error if not correct
898
+ self.check_inputs(
899
+ prompt,
900
+ height,
901
+ width,
902
+ callback_steps,
903
+ negative_prompt,
904
+ prompt_embeds,
905
+ negative_prompt_embeds,
906
+ ip_adapter_image,
907
+ ip_adapter_image_embeds,
908
+ callback_on_step_end_tensor_inputs,
909
+ )
910
+
911
+ self._guidance_scale = guidance_scale
912
+ self._guidance_rescale = guidance_rescale
913
+ self._clip_skip = clip_skip
914
+ self._cross_attention_kwargs = cross_attention_kwargs
915
+ self._interrupt = False
916
+
917
+ # 2. Define call parameters
918
+ if prompt is not None and isinstance(prompt, str):
919
+ batch_size = 1
920
+ elif prompt is not None and isinstance(prompt, list):
921
+ batch_size = len(prompt)
922
+ else:
923
+ batch_size = prompt_embeds.shape[0]
924
+
925
+ device = self._execution_device
926
+
927
+ # 3. Encode input prompt
928
+ lora_scale = (
929
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
930
+ )
931
+
932
+ # FIXME: 判断prompt是str还是list
933
+ if prompt is not None and isinstance(prompt, str):
934
+ # assert '_GOOGLE_LEVEL_' in prompt
935
+ if '_GOOGLE_LEVEL_' in prompt:
936
+ res = [int(prompt.split('_GOOGLE_LEVEL_')[0])]
937
+ prompt = prompt.split('_GOOGLE_LEVEL_')[-1]
938
+ else:
939
+ res = [0]
940
+ prompt = prompt.split('_GOOGLE_LEVEL_')[-1]
941
+ elif prompt is not None and isinstance(prompt, list):
942
+ res_list = []
943
+ prompt_buff = []
944
+ for p in prompt:
945
+ # assert '_GOOGLE_LEVEL_' in p
946
+ if '_GOOGLE_LEVEL_' in p:
947
+ res = int(p.split('_GOOGLE_LEVEL_')[0])
948
+ p = p.split('_GOOGLE_LEVEL_')[-1]
949
+ else:
950
+ res = 0
951
+ p = p.split('_GOOGLE_LEVEL_')[-1]
952
+ res_list.append(res)
953
+ prompt_buff.append(p)
954
+ res = res_list
955
+ prompt = prompt_buff
956
+
957
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
958
+ prompt,
959
+ device,
960
+ num_images_per_prompt,
961
+ self.do_classifier_free_guidance,
962
+ negative_prompt,
963
+ prompt_embeds=prompt_embeds,
964
+ negative_prompt_embeds=negative_prompt_embeds,
965
+ lora_scale=lora_scale,
966
+ clip_skip=self.clip_skip,
967
+ )
968
+
969
+ # For classifier free guidance, we need to do two forward passes.
970
+ # Here we concatenate the unconditional and text embeddings into a single batch
971
+ # to avoid doing two forward passes
972
+ if self.do_classifier_free_guidance:
973
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
974
+
975
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
976
+ image_embeds = self.prepare_ip_adapter_image_embeds(
977
+ ip_adapter_image,
978
+ ip_adapter_image_embeds,
979
+ device,
980
+ batch_size * num_images_per_prompt,
981
+ self.do_classifier_free_guidance,
982
+ )
983
+
984
+ # 4. Prepare timesteps
985
+ timesteps, num_inference_steps = retrieve_timesteps(
986
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
987
+ )
988
+
989
+ # 5. Prepare latent variables
990
+ num_channels_latents = self.unet.config.in_channels
991
+ latents = self.prepare_latents(
992
+ batch_size * num_images_per_prompt,
993
+ num_channels_latents,
994
+ height,
995
+ width,
996
+ prompt_embeds.dtype,
997
+ device,
998
+ generator,
999
+ latents,
1000
+ )
1001
+
1002
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1003
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1004
+
1005
+ # 6.1 Add image embeds for IP-Adapter
1006
+ added_cond_kwargs = (
1007
+ {"image_embeds": image_embeds}
1008
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1009
+ else None
1010
+ )
1011
+
1012
+ # 6.2 Optionally get Guidance Scale Embedding
1013
+ timestep_cond = None
1014
+ if self.unet.config.time_cond_proj_dim is not None:
1015
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1016
+ timestep_cond = self.get_guidance_scale_embedding(
1017
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1018
+ ).to(device=device, dtype=latents.dtype)
1019
+
1020
+ # 7. Denoising loop
1021
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1022
+ self._num_timesteps = len(timesteps)
1023
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1024
+ for i, t in enumerate(timesteps):
1025
+ if self.interrupt:
1026
+ continue
1027
+
1028
+ # expand the latents if we are doing classifier free guidance
1029
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1030
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1031
+
1032
+ # fixme
1033
+ assert num_images_per_prompt == 1
1034
+ res = torch.tensor(res, dtype=t.dtype, device=device).clone().detach()#torch.tensor(res).to(t.dtype).to(device).clone().detach()
1035
+ res_null = torch.tensor([0]*batch_size, dtype=t.dtype, device=device).clone().detach()
1036
+ # res_in = torch.cat([res]*2) if self.do_classifier_free_guidance else res
1037
+ res_in = torch.cat([res_null, res]) if self.do_classifier_free_guidance else res
1038
+ # TODO: assert num_images_per_prompt != 1
1039
+
1040
+ # predict the noise residual
1041
+ noise_pred = self.unet(
1042
+ latent_model_input,
1043
+ t,
1044
+ encoder_hidden_states=prompt_embeds,
1045
+ timestep_cond=timestep_cond,
1046
+ class_labels=res_in if self.unet.class_embedding is not None else None, # FIXME: res_in
1047
+ cross_attention_kwargs=self.cross_attention_kwargs,
1048
+ added_cond_kwargs=added_cond_kwargs,
1049
+ return_dict=False,
1050
+ )[0]
1051
+
1052
+ # perform guidance
1053
+ if self.do_classifier_free_guidance:
1054
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1055
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1056
+
1057
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1058
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1059
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1060
+
1061
+ # compute the previous noisy sample x_t -> x_t-1
1062
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1063
+
1064
+ if callback_on_step_end is not None:
1065
+ callback_kwargs = {}
1066
+ for k in callback_on_step_end_tensor_inputs:
1067
+ callback_kwargs[k] = locals()[k]
1068
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1069
+
1070
+ latents = callback_outputs.pop("latents", latents)
1071
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1072
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1073
+
1074
+ # call the callback, if provided
1075
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1076
+ progress_bar.update()
1077
+ if callback is not None and i % callback_steps == 0:
1078
+ step_idx = i // getattr(self.scheduler, "order", 1)
1079
+ callback(step_idx, t, latents)
1080
+
1081
+ if not output_type == "latent":
1082
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1083
+ 0
1084
+ ]
1085
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1086
+ else:
1087
+ image = latents
1088
+ has_nsfw_concept = None
1089
+
1090
+ if has_nsfw_concept is None:
1091
+ do_denormalize = [True] * image.shape[0]
1092
+ else:
1093
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1094
+
1095
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1096
+
1097
+ # Offload all models
1098
+ self.maybe_free_model_hooks()
1099
+
1100
+ if not return_dict:
1101
+ return (image, has_nsfw_concept)
1102
+
1103
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)