# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy
import numpy as np
import torch
import torch.distributed
import torch.utils.checkpoint
import transformers
from PIL import Image
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UNet2DConditionModel,
ImagePipelineOutput
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, \
retrieve_timesteps, rescale_noise_cfg
from diffusers.schedulers import KarrasDiffusionSchedulers, LCMScheduler
from diffusers.utils import deprecate
from einops import rearrange
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from .unet.modules import UNet2p5DConditionModel, \
compute_multi_resolution_mask, compute_multi_resolution_discrete_voxel_indice
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
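# c_skip and c_out are the consistency-model boundary-condition scalings: as the
# (scaled) timestep approaches zero, c_skip -> 1 and c_out -> 0, so the consistency
# function reduces to the identity at t = 0.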
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
scaled_timestep = timestep_scaling * timestep
c_skip = sigma_data ** 2 / (scaled_timestep ** 2 + sigma_data ** 2)
c_out = scaled_timestep / (scaled_timestep ** 2 + sigma_data ** 2) ** 0.5
return c_skip, c_out
# Compare LCMScheduler.step, Step 4
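# Recovers the predicted clean sample x_0 from the UNet output, given the scheduler's
# per-timestep alpha/sigma coefficients and the model's prediction type. The model
# output is reshaped to (batch, num_views, ...) so the coefficients broadcast per view.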
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas, N_gen):
alphas = extract_into_tensor(alphas, timesteps, sample.shape, N_gen)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, N_gen)
model_output = rearrange(model_output, '(b n) c h w -> b n c h w', n=N_gen)
if prediction_type == "epsilon":
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_x_0
# Based on step 4 in DDIMScheduler.step
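# Recovers the predicted noise epsilon from the UNet output under the same
# parameterizations (`epsilon`, `sample`, `v_prediction`) as above.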
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas, N_gen):
alphas = extract_into_tensor(alphas, timesteps, sample.shape, N_gen)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, N_gen)
model_output = rearrange(model_output, '(b n) c h w -> b n c h w', n=N_gen)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
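# Looks up per-timestep coefficients in `a`, tiles them across the generated views,
# and reshapes them for broadcasting over (batch, num_views, c, h, w) samples.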
def extract_into_tensor(a, t, x_shape, N_gen):
# b, *_ = t.shape
out = a.gather(-1, t)
out = out.repeat(N_gen)
out = rearrange(out, '(b n) -> b n', n=N_gen)
b, c, *_ = out.shape
return out.reshape(b, c, *((1,) * (len(x_shape) - 2)))
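# Precomputes a uniformly spaced DDIM timestep schedule together with the matching
# (and previous-step) cumulative alpha products, so deterministic DDIM updates can be
# taken with simple index lookups.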
class DDIMSolver:
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
# DDIM sampling parameters
step_ratio = timesteps // ddim_timesteps
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
# convert to torch tensors
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
def to(self, device):
self.ddim_timesteps = self.ddim_timesteps.to(device)
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
return self
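    # Deterministic (eta = 0) DDIM update:
    #   x_prev = sqrt(alpha_bar_prev) * x_0 + sqrt(1 - alpha_bar_prev) * eps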
def ddim_step(self, pred_x0, pred_noise, timestep_index, N_gen):
alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape, N_gen)
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
return x_prev
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
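# Converts an RGBA image to RGB by compositing it over a mid-gray (value 127)
# background; plain RGB images are returned unchanged.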
def to_rgb_image(maybe_rgba: Image.Image):
if maybe_rgba.mode == 'RGB':
return maybe_rgba
elif maybe_rgba.mode == 'RGBA':
rgba = maybe_rgba
img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
img = Image.fromarray(img, 'RGB')
img.paste(rgba, mask=rgba.getchannel('A'))
return img
else:
raise ValueError("Unsupported image type.", maybe_rgba.mode)
class HunyuanPaintPipeline(StableDiffusionPipeline):
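    """Multiview texture-generation ("HunyuanPaint") pipeline built on Stable Diffusion.

    Given a single reference image and optional per-view conditions (normal maps,
    position maps, camera indices) passed as extra keyword arguments, the pipeline
    denoises several target views jointly with a 2.5D-conditioned UNet
    (`UNet2p5DConditionModel`).

    Illustrative call sketch (the checkpoint path and all inputs below are
    placeholders, not values shipped with this module)::

        pipe = HunyuanPaintPipeline.from_pretrained("<hunyuanpaint-checkpoint>")  # hypothetical path
        out = pipe(
            image=ref_image,                    # reference PIL image
            normal_imgs=[normal_views],         # list (batch) of lists of per-view PIL normal maps
            position_imgs=[position_views],     # list (batch) of lists of per-view PIL position maps
            camera_info_gen=[gen_camera_ids],   # per-view camera indices for generated views
            camera_info_ref=[ref_camera_ids],   # camera index of the reference view(s)
            num_in_batch=len(normal_views),     # number of views generated per sample
        )
        views = out.images
    """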
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2p5DConditionModel,
scheduler: KarrasDiffusionSchedulers,
feature_extractor: CLIPImageProcessor,
safety_checker=None,
use_torch_compile=False,
):
DiffusionPipeline.__init__(self)
safety_checker = None
self.register_modules(
vae=torch.compile(vae) if use_torch_compile else vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=torch.compile(feature_extractor) if use_torch_compile else feature_extractor,
)
self.solver = DDIMSolver(
scheduler.alphas_cumprod.numpy(),
timesteps=scheduler.config.num_train_timesteps,
ddim_timesteps=30,
).to('cuda')
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.is_turbo = False
def set_turbo(self, is_turbo: bool):
self.is_turbo = is_turbo
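    # Encodes a (B, N, C, H, W) batch of per-view images in [0, 1] into scaled VAE
    # latents of shape (B, N, c, h, w).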
@torch.no_grad()
def encode_images(self, images):
B = images.shape[0]
images = rearrange(images, 'b n c h w -> (b n) c h w')
dtype = next(self.vae.parameters()).dtype
images = (images - 0.5) * 2.0
posterior = self.vae.encode(images.to(dtype)).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
latents = rearrange(latents, '(b n) c h w -> b n c h w', b=B)
return latents
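    # Entry point: converts the reference image (plus optional per-view normal/position
    # maps and camera indices in `cached_condition`) into latent conditions, duplicates
    # them for classifier-free guidance when needed, and delegates the denoising loop
    # to `self.denoise`.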
@torch.no_grad()
def __call__(
self,
image: Image.Image = None,
prompt=None,
negative_prompt='watermark, ugly, deformed, noisy, blurry, low contrast',
*args,
num_images_per_prompt: Optional[int] = 1,
guidance_scale=2.0,
output_type: Optional[str] = "pil",
width=512,
height=512,
num_inference_steps=28,
return_dict=True,
**cached_condition,
):
device = self._execution_device
if image is None:
raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
assert not isinstance(image, torch.Tensor)
if not isinstance(image, List):
image = [image]
image = [to_rgb_image(img) for img in image]
image_vae = [torch.tensor(np.array(img) / 255.0) for img in image]
image_vae = [img_vae.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(0) for img_vae in image_vae]
image_vae = torch.cat(image_vae, dim=1)
image_vae = image_vae.to(device=device, dtype=self.vae.dtype)
batch_size, N_ref = image_vae.shape[0], image_vae.shape[1]
assert batch_size == 1
assert num_images_per_prompt == 1
ref_latents = self.encode_images(image_vae)
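        # Composites per-view RGBA images onto a white background and stacks them into
        # a (B, N, C, H, W) half-precision tensor on the GPU.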
def convert_pil_list_to_tensor(images):
bg_c = [1., 1., 1.]
images_tensor = []
for batch_imgs in images:
view_imgs = []
for pil_img in batch_imgs:
img = numpy.asarray(pil_img, dtype=numpy.float32) / 255.
if img.shape[2] > 3:
alpha = img[:, :, 3:]
img = img[:, :, :3] * alpha + bg_c * (1 - alpha)
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).contiguous().half().to("cuda")
view_imgs.append(img)
view_imgs = torch.cat(view_imgs, dim=0)
images_tensor.append(view_imgs.unsqueeze(0))
images_tensor = torch.cat(images_tensor, dim=0)
return images_tensor
if "normal_imgs" in cached_condition:
if isinstance(cached_condition["normal_imgs"], List):
cached_condition["normal_imgs"] = convert_pil_list_to_tensor(cached_condition["normal_imgs"])
cached_condition['normal_imgs'] = self.encode_images(cached_condition["normal_imgs"])
if "position_imgs" in cached_condition:
if isinstance(cached_condition["position_imgs"], List):
cached_condition["position_imgs"] = convert_pil_list_to_tensor(cached_condition["position_imgs"])
cached_condition['position_maps'] = cached_condition['position_imgs']
cached_condition["position_imgs"] = self.encode_images(cached_condition["position_imgs"])
if 'camera_info_gen' in cached_condition:
camera_info = cached_condition['camera_info_gen'] # B,N
if isinstance(camera_info, List):
camera_info = torch.tensor(camera_info)
camera_info = camera_info.to(device).to(torch.int64)
cached_condition['camera_info_gen'] = camera_info
if 'camera_info_ref' in cached_condition:
camera_info = cached_condition['camera_info_ref'] # B,N
if isinstance(camera_info, List):
camera_info = torch.tensor(camera_info)
camera_info = camera_info.to(device).to(torch.int64)
cached_condition['camera_info_ref'] = camera_info
cached_condition['ref_latents'] = ref_latents
if self.is_turbo:
if 'position_maps' in cached_condition:
cached_condition['position_attn_mask'] = (
compute_multi_resolution_mask(cached_condition['position_maps'])
)
cached_condition['position_voxel_indices'] = (
compute_multi_resolution_discrete_voxel_indice(cached_condition['position_maps'])
)
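        # Classifier-free guidance (non-turbo): duplicate every conditioning tensor along
        # the batch dimension, using a zeroed reference latent (ref_scale 0) for the
        # unconditional half.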
if (guidance_scale > 1) and (not self.is_turbo):
negative_ref_latents = torch.zeros_like(cached_condition['ref_latents'])
cached_condition['ref_latents'] = torch.cat([negative_ref_latents, cached_condition['ref_latents']])
cached_condition['ref_scale'] = torch.as_tensor([0.0, 1.0]).to(cached_condition['ref_latents'])
if "normal_imgs" in cached_condition:
cached_condition['normal_imgs'] = torch.cat(
(cached_condition['normal_imgs'], cached_condition['normal_imgs']))
if "position_imgs" in cached_condition:
cached_condition['position_imgs'] = torch.cat(
(cached_condition['position_imgs'], cached_condition['position_imgs']))
if 'position_maps' in cached_condition:
cached_condition['position_maps'] = torch.cat(
(cached_condition['position_maps'], cached_condition['position_maps']))
if 'camera_info_gen' in cached_condition:
cached_condition['camera_info_gen'] = torch.cat(
(cached_condition['camera_info_gen'], cached_condition['camera_info_gen']))
if 'camera_info_ref' in cached_condition:
cached_condition['camera_info_ref'] = torch.cat(
(cached_condition['camera_info_ref'], cached_condition['camera_info_ref']))
prompt_embeds = self.unet.learned_text_clip_gen.repeat(num_images_per_prompt, 1, 1)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
latents: torch.Tensor = self.denoise(
None,
*args,
cross_attention_kwargs=None,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_inference_steps=num_inference_steps,
output_type='latent',
width=width,
height=height,
**cached_condition
).images
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
def denoise(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`]
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference, with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated,"
                " consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated,"
                " consider using `callback_on_step_end`",
            )
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance if self.is_turbo else False,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if (self.do_classifier_free_guidance) and (not self.is_turbo):
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance if self.is_turbo else False,
)
# 4. Prepare
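        # Turbo mode steps through a fixed subset of the precomputed DDIM solver
        # timesteps; otherwise the scheduler's own timestep schedule is used.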
if self.is_turbo:
bsz = 3
N_gen = 15
            index = torch.arange(29, 0, -bsz, device='cuda').long()
timesteps = self.solver.ddim_timesteps[index]
self.scheduler.set_timesteps(timesteps=timesteps.cpu(), device='cuda')
else:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
assert num_images_per_prompt == 1
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * kwargs['num_in_batch'], # num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
else None
)
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
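                # group the flat (b*n) latent batch back into (b, n, c, h, w) so the
                # multiview UNet sees all views of a sample together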
latents = rearrange(latents, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])
latent_model_input = (
torch.cat([latents] * 2)
if ((self.do_classifier_free_guidance) and (not self.is_turbo))
else latents
)
latent_model_input = rearrange(latent_model_input, 'b n c h w -> (b n) c h w')
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = rearrange(latent_model_input, '(b n) c h w -> b n c h w', n=kwargs['num_in_batch'])
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False, **kwargs
)[0]
latents = rearrange(latents, 'b n c h w -> (b n) c h w')
# perform guidance
if (self.do_classifier_free_guidance) and (not self.is_turbo):
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if (self.do_classifier_free_guidance) and (self.guidance_rescale > 0.0) and (not self.is_turbo):
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = \
self.scheduler.step(noise_pred, t, latents[:, :num_channels_latents, :, :], **extra_step_kwargs,
return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)