import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Optional, Tuple

from utils import normalize
from model import freq_exp, gen_nn_map
from src.ddpm_step import deterministic_ddpm_step
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
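
# `normalize`, `freq_exp`, `gen_nn_map`, and `deterministic_ddpm_step` are
# project-local helpers (not diffusers APIs); this module assumes they are
# importable from the repository root.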

# Kernel sizes for the DIFT correction at successive timestep ranges
DIFT_KERNELS: Tuple[int, int, int, int] = (12, 7, 5, 3)


def _get_kernel_for_timestep(timestep: int) -> Tuple[int, int]:
    """Return (kernel_size, stride) for the DIFT nearest-neighbour correction."""
    if timestep >= 799:
        return DIFT_KERNELS[0], 1
    if timestep >= 599:
        return DIFT_KERNELS[1], 1
    if timestep >= 299:
        return DIFT_KERNELS[2], 1
    return DIFT_KERNELS[3], 1
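
# Illustrative values for the schedule above (stride is always 1), e.g.:
#   _get_kernel_for_timestep(850) -> (12, 1)
#   _get_kernel_for_timestep(450) -> (5, 1)
#   _get_kernel_for_timestep(100) -> (3, 1)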


def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    return_dict: bool = True,
    noise_pred_uncond: Optional[torch.FloatTensor] = None,
    **kwargs,
):
    timestep_index = self._timesteps.index(timestep)
    next_timestep_index = timestep_index + 1

    # Deterministic DDPM prediction of x_{t-1} from the current sample.
    u_hat_t, beta_coef = deterministic_ddpm_step(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        scheduler=self,
    )
    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_predicted.append(u_hat_t)

    # Residual between the exact inversion latent and the prediction.
    z_t = x_t_minus_1 - u_hat_t
    self.latents.append(z_t)

    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
    x_t_minus_1_predicted = u_hat_t + z_t

    if not return_dict:
        return (x_t_minus_1_predicted,)

    return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
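

# step_save_latents runs on the inversion (source) branch: it stores the
# predicted trajectory in `self.x_ts_c_predicted` and the residual noise maps
# `z_t` in `self.latents`. step_use_latents below replays those tensors on the
# editing branch.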
def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    return_dict: bool = True,
    noise_pred_uncond: Optional[torch.FloatTensor] = None,
    **kwargs,
):
    timestep_index = self._timesteps.index(timestep)
    next_timestep_index = timestep_index + 1

    # Residual noise map saved by step_save_latents at this timestep.
    z_t = self.latents[next_timestep_index]
    _, normalize_coefficient = normalize(
        z_t,
        timestep_index,
        self._config.max_norm_zs,
    )

    # Deterministic DDPM prediction on the editing branch.
    x_t_hat_c_hat, beta_coef = deterministic_ddpm_step(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        scheduler=self,
    )

    x_t_minus_1_exact = self.x_ts[next_timestep_index]
    x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)

    x_t_c_predicted: torch.Tensor = self.x_ts_c_predicted[next_timestep_index]
    x_t_c = x_t_c_predicted[0].expand_as(x_t_hat_c_hat)

    # Optional spatial mask: restrict the edit to the masked region; at the
    # noisiest timesteps also apply the frequency-based movement intensifier.
    mask: Optional[Tensor] = kwargs.get("mask", None)
    if mask is not None and timestep > 300:
        mask = mask.to(x_t_hat_c_hat.device)
        movement_intensifier = kwargs.get("movement_intensifier", 0.0)
        if timestep > 900 and movement_intensifier > 0.0:
            latent_mask_h, *_ = freq_exp(
                x_t_hat_c_hat[1:],
                "auto_mask",
                None,
                mask.unsqueeze(0),
                movement_intensifier,
            )
            x_t_hat_c_hat[1:] = latent_mask_h
        # Outside the mask, fall back to the saved source prediction.
        x_t_hat_c_hat[-1] = x_t_hat_c_hat[-1] * mask + (1 - mask) * x_t_c[-1]

    # The batch is split in half: the first half carries the source-prompt
    # prediction, the second half the edit-prompt prediction.
    edit_prompts_num = model_output.size(0) // 2
    x_t_hat_c_indices = (0, edit_prompts_num)
    edit_images_indices = (edit_prompts_num, model_output.size(0))

    x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
    x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
        x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
    ]

    # w1 weights the cross-prompt term, i.e. the strength of the edit.
    w1 = kwargs.get("w1", 1.9)
    cross_prompt_term = x_t_hat_c_hat - x_t_hat_c
    cross_trajectory_term = x_t_hat_c - normalize_coefficient * x_t_c
    x_t_minus_1_hat_ = (
        normalize_coefficient * x_t_minus_1_exact
        + cross_trajectory_term
        + w1 * cross_prompt_term
    )
    x_t_minus_1_hat_[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1_hat_[
        edit_images_indices[0] : edit_images_indices[1]
    ]

    # Optional DIFT correction: remap the saved residual for the edited image
    # with a nearest-neighbour map between source and target DIFT features.
    dift_timestep = kwargs.get("dift_timestep", 700)
    if timestep < dift_timestep and kwargs.get("apply_dift_correction", False):
        z_t = torch.cat([z_t] * x_t_hat_c_hat.shape[0], dim=0)
        dift_features: Optional[Tensor] = kwargs.get("dift_features", None)
        dift_s, _, dift_t = dift_features.chunk(3)
        resized_src_features = F.interpolate(
            dift_s[0].unsqueeze(0), size=z_t.shape[-1], mode="bilinear", align_corners=False
        ).squeeze(0)
        resized_tgt_features = F.interpolate(
            dift_t[0].unsqueeze(0), size=z_t.shape[-1], mode="bilinear", align_corners=False
        ).squeeze(0)
        kernel_size, stride = _get_kernel_for_timestep(timestep)
        torch.cuda.empty_cache()
        updated_z_t = gen_nn_map(
            z_t[1],
            resized_src_features,
            resized_tgt_features,
            kernel_size=kernel_size,
            stride=stride,
            device=z_t.device,
            timestep=timestep,
        )
        # alpha = 1.0 fully replaces the residual with the remapped one.
        alpha = 1.0
        z_t[1] = alpha * updated_z_t + (1 - alpha) * z_t[1]
        x_t_minus_1_hat = x_t_hat_c_hat + z_t * normalize_coefficient
    else:
        x_t_minus_1_hat = x_t_minus_1_hat_

    if not return_dict:
        return (x_t_minus_1_hat,)

    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1_hat,
        pred_original_sample=None,
    )
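

# get_ddpm_inversion_scheduler wires the two branches together: it attaches the
# saved trajectories to the scheduler and replaces `scheduler.step` with a
# closure that routes the first batch element through step_save_latents and the
# remaining elements through step_use_latents.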
def get_ddpm_inversion_scheduler(
    scheduler,
    config,
    timesteps,
    latents,
    x_ts,
    **kwargs,
):
    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        noise_pred_uncond: Optional[torch.FloatTensor] = None,
        dift_features: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # Inversion branch: predict and save x_t_c for the source image.
        res_inv = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            return_dict,
            noise_pred_uncond[:1, :, :, :],
            **kwargs,
        )
        # Editing branch: reuse the saved latents for the remaining batch items.
        res_inf = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            return_dict,
            noise_pred_uncond[1:, :, :, :],
            dift_features=dift_features,
            **kwargs,
        )
        res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
        return res

    scheduler._timesteps = timesteps
    scheduler._config = config
    scheduler.latents = latents
    scheduler.x_ts = x_ts
    scheduler.x_ts_c_predicted = [None]
    scheduler.step = step
    return scheduler
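

# Minimal usage sketch (illustrative; `pipe`, `config`, `timesteps`, `latents`,
# and `x_ts` are assumed to come from the surrounding inversion/editing code,
# not from this module):
#
#   pipe.scheduler = get_ddpm_inversion_scheduler(
#       pipe.scheduler,   # DDIM-style scheduler of a diffusers pipeline
#       config,           # must expose `max_norm_zs` (used by `normalize`)
#       timesteps,        # timestep sequence aligned with `x_ts`
#       latents,          # list that collects the z_t residuals
#       x_ts,             # noisy latents from the DDPM inversion pass
#       w1=1.9,           # edit strength (weight of the cross-prompt term)
#   )
#   # Later denoising calls made by the pipeline now go through `step` above.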