Create inversion.py

inversion.py ADDED (+125 -0)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from typing import Callable

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
from tqdm import tqdm


# Shorthand tensor aliases used throughout this module.
T = torch.Tensor
TN = T | None
# Per-step callback signature used by diffusers' `callback_on_step_end` hook:
# (pipeline, step index, timestep, tensor kwargs) -> updated tensor kwargs.
InversionCallback = Callable[[StableDiffusionXLPipeline, int, T, dict[str, T]], dict[str, T]]


def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
    # Tokenize the prompt and run it through one of the SDXL text encoders.
    text_inputs = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length,
                            truncation=True, return_tensors='pt')
    text_input_ids = text_inputs.input_ids

    with torch.no_grad():
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            output_hidden_states=True,
        )

    # SDXL uses the pooled output plus the penultimate hidden state.
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    if prompt == '':
        # Empty prompt: return zeroed embeddings as the unconditional branch.
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        return negative_prompt_embeds, negative_pooled_prompt_embeds
    return prompt_embeds, pooled_prompt_embeds


def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    # Encode the prompt with both SDXL text encoders and concatenate the results.
    device = model._execution_device
    prompt_embeds, pooled_prompt_embeds = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
    prompt_embeds_2, pooled_prompt_embeds_2 = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device)
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    # Micro-conditioning for a 1024x1024 image with no cropping.
    add_time_ids = model._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), torch.float16,
                                           text_encoder_projection_dim).to(device)
    added_cond_kwargs = {"text_embeds": pooled_prompt_embeds_2, "time_ids": add_time_ids}
    return added_cond_kwargs, prompt_embeds


def _encode_text_sdxl_with_negative(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    # Stack unconditional (empty-prompt) and conditional embeddings for classifier-free guidance.
    added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
    added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(model, "")
    prompt_embeds = torch.cat((prompt_embeds_uncond, prompt_embeds))
    added_cond_kwargs = {"text_embeds": torch.cat((added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])),
                         "time_ids": torch.cat((added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"]))}
    return added_cond_kwargs, prompt_embeds


def _encode_image(model: StableDiffusionXLPipeline, image: np.ndarray) -> T:
    # Encode a uint8 HxWx3 image into a scaled VAE latent. The VAE runs in
    # float32 for numerical stability, then is returned to float16.
    model.vae.to(dtype=torch.float32)
    image = torch.from_numpy(image).float() / 255.
    image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
    latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
    model.vae.to(dtype=torch.float16)
    return latent


def _next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T:
    # One reversed DDIM step: given z_t and the noise prediction, estimate z_{t+1}.
    timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = model.scheduler.alphas_cumprod[int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
    alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
    beta_prod_t = 1 - alpha_prod_t
    # Predict the clean sample x0, then re-noise it to the next (higher) noise level.
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def _get_noise_pred(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]) -> T:
    # Classifier-free guidance: run the UNet on the unconditional and conditional
    # batches and blend the two noise predictions.
    latents_input = torch.cat([latent] * 2)
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, added_cond_kwargs=added_cond_kwargs)["sample"]
    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    return noise_pred


def _ddim_loop(model: StableDiffusionXLPipeline, z0: T, prompt: str, guidance_scale: float) -> T:
    # Walk the DDIM trajectory backwards from the clean latent z0 towards pure
    # noise, recording every intermediate latent.
    all_latent = [z0]
    added_cond_kwargs, text_embedding = _encode_text_sdxl_with_negative(model, prompt)
    latent = z0.clone().detach().half()
    for i in tqdm(range(model.scheduler.num_inference_steps)):
        t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
        noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale, added_cond_kwargs)
        latent = _next_step(model, noise_pred, t, latent)
        all_latent.append(latent)
    # Flip so that index 0 is the noisiest latent zT and index -1 is z0.
    return torch.cat(all_latent).flip(0)


def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:
    # Returns the starting latent zts[offset] and a callback that, at each denoising
    # step, overwrites the first latent in the batch with the stored inverted latent.

    def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]:
        latents = callback_kwargs['latents']
        latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
        return {'latents': latents}
    return zts[offset], callback_on_step_end


@torch.no_grad()
def ddim_inversion(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale: float) -> T:
    # DDIM-invert an image: encode it to a latent, then run the reversed sampling
    # loop to obtain the full trajectory of latents [zT, ..., z0].
    z0 = _encode_image(model, x0)
    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
    zs = _ddim_loop(model, z0, prompt, guidance_scale)
    return zs
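
For context, a minimal usage sketch (not part of this commit) showing how the two public entry points fit together: invert a source image with ddim_inversion, then replay the stored trajectory during sampling through diffusers' callback_on_step_end hook. The checkpoint name, image path, prompts, and hyperparameter values below are illustrative assumptions, not values from this file.

import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

from inversion import ddim_inversion, make_inversion_callback

# Load SDXL and swap in a DDIM scheduler, which _next_step assumes
# (it reads alphas_cumprod and final_alpha_cumprod).
pipeline = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16,
).to('cuda')
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# x0 must be an HxWx3 uint8 array; 1024x1024 matches the hard-coded time ids.
x0 = np.array(Image.open('input.jpg').convert('RGB').resize((1024, 1024)))
zts = ddim_inversion(pipeline, x0, 'a photo of a cat',
                     num_inference_steps=50, guidance_scale=2.0)

# zT seeds the batch; the callback keeps element 0 on the inverted trajectory.
zT, callback = make_inversion_callback(zts)
image = pipeline('a photo of a cat', latents=zT[None],
                 num_inference_steps=50, guidance_scale=5.0,
                 callback_on_step_end=callback,
                 callback_on_step_end_tensor_inputs=['latents']).images[0]
image.save('reconstruction.png')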