import torch
import torch.nn as nn


class ModelSamplingDiscreteFlow(nn.Module):
    """Helper for sampler scheduling (i.e. timestep/sigma calculations) for Discrete Flow models."""

    def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs):
        super().__init__()
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.num_inference_steps = None
        self.timesteps = None
        # Sigmas for the discrete training timesteps 1..num_train_timesteps, in ascending order.
        sigmas = self.to_sigma(torch.arange(1, num_train_timesteps + 1, 1))
        self.register_buffer("sigmas", sigmas)

    @property
    def sigma_min(self):
        # `sigmas` is ascending, so the first entry is the smallest sigma.
        return self.sigmas[0]

    @property
    def sigma_max(self):
        # `sigmas` is ascending, so the last entry is the largest sigma.
        return self.sigmas[-1]

    def to_timestep(self, sigma):
        # Exact inverse of `to_sigma`: undo the shift transform, then rescale
        # back to the discrete timestep range.
        if self.shift != 1.0:
            sigma = sigma / (self.shift - (self.shift - 1.0) * sigma)
        return sigma * self.num_train_timesteps

    def to_sigma(self, timestep: torch.Tensor):
        timestep = timestep / self.num_train_timesteps
        if self.shift == 1.0:
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def uniform_sample_t(self, batch_size, device):
        # Sample sigmas uniformly in [sigma_min, sigma_max].
        sigmas = (self.sigma_max - self.sigma_min) * torch.rand(batch_size, device=device) + self.sigma_min
        return sigmas

    def calculate_denoised(self, sigma, model_output, model_input):
        # The model predicts the flow direction v = (noise - x_0), so the
        # denoised sample is recovered as x_0 = x_t - sigma * v.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image):
        # Linear interpolation between noise and data: x_t = sigma * noise + (1 - sigma) * x_0.
        return sigma * noise + (1.0 - sigma) * latent_image

    def add_noise(self, sample, noise=None, timesteps=None):
        if timesteps is None:
            # No timesteps given: draw sigmas uniformly and derive timesteps from them.
            batch_size = sample.shape[0]
            sigmas = self.uniform_sample_t(batch_size, device=sample.device).to(dtype=sample.dtype)
            timesteps = self.to_timestep(sigmas)
        else:
            timesteps = timesteps.to(device=sample.device, dtype=sample.dtype)
            sigmas = self.to_sigma(timesteps)

        if noise is None:
            noise = torch.randn_like(sample)
        sigmas = self.append_dims(sigmas, sample.ndim)
        noisy_samples = sigmas * noise + (1.0 - sigmas) * sample
        # `noise - sample` is the flow-matching target the model is trained to predict.
        return noisy_samples, noise, noise - sample, timesteps

    def set_timesteps(self, num_inference_steps, device=None):
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps}, as the model trained with this scheduler can only handle"
                f" up to {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # Step from the largest sigma down to the smallest, evenly spaced in timestep space.
        start = float(self.to_timestep(self.sigma_max))
        end = float(self.to_timestep(self.sigma_min))
        self.timesteps = torch.linspace(start, end, num_inference_steps, device=device)

    def append_dims(self, x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
        return x[(...,) + (None,) * dims_to_append]

    def to_d(self, x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / self.append_dims(sigma, x.ndim)

    @torch.no_grad()
    def step(self, model_output, timestep, sample, method="euler", **kwargs):
        """
        Args:
            model_output (`torch.Tensor`):
                The direct output of the learned diffusion model, i.e. the direction (noise - x_0).
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                The current instance of a sample created by the diffusion process, x_t.
            method (`str`):
                ODE solver, `euler` or `dpmpp_2m`.

        Returns:
            `tuple`:
                `(prev_sample, pred_original_sample)`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        sigma = self.to_sigma(timestep)
        prev_sigma = sigma - (self.sigma_max - self.sigma_min) / (self.num_inference_steps - 1)
        prev_sigma = 0.0 if prev_sigma < 0.0 else prev_sigma

        if method == "euler":
            # Algorithm 2 (Euler steps) from Karras et al. (2022). For this flow
            # parameterization the ODE derivative is `model_output` itself.
            dt = prev_sigma - sigma
            prev_sample = sample + model_output * dt
        elif method == "dpmpp_2m":
            # DPM-Solver++(2M).
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported ODE solver: {method}, only supports `euler` or `dpmpp_2m`")

        # Predicted x_0 under the flow parameterization: x_0 = x_t - sigma * v.
        pred_original_sample = sample - model_output * sigma

        return (
            prev_sample,
            pred_original_sample,
        )

    def get_pred_original_sample(self, model_output, timestep, sample):
        sigma = self.append_dims(self.to_sigma(timestep), sample.ndim)
        pred_original_sample = sample - model_output * sigma
        return pred_original_sample
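

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # exercises the training-side noising and an Euler sampling loop with a
    # stand-in model. The (2, 4, 16) latent shape and shift=3.0 are arbitrary
    # assumptions for demonstration.
    scheduler = ModelSamplingDiscreteFlow(num_train_timesteps=1000, shift=3.0)

    # Training side: noisy latents plus the flow-matching target (noise - x_0).
    x0 = torch.randn(2, 4, 16)
    noisy, noise, target, timesteps = scheduler.add_noise(x0)

    # Inference side: step from sigma_max down toward sigma_min.
    scheduler.set_timesteps(num_inference_steps=10)
    sample = torch.randn(2, 4, 16)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # stand-in for model(sample, t)
        sample, pred_x0 = scheduler.step(model_output, t, sample)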