from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero


def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
    """Return ``num_substeps`` integer timesteps spread roughly evenly over
    ``[0, max_step - 1]``, in ascending order."""
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]


class Discretization:
    """Base class mapping a step count ``n`` to a decreasing sigma (noise-level) schedule."""

    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        """Build the sigma schedule.

        Args:
            n: number of sampling steps.
            do_append_zero: append a trailing zero sigma (final noise-free level).
            device: torch device for the returned tensor.
            flip: reverse the schedule (ascending instead of descending).

        Returns:
            1-D tensor of sigmas.
        """
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

    @abstractmethod
    def get_sigmas(self, n, device):
        """Return a 1-D tensor of sigmas for ``n`` steps, highest noise first."""
        pass


class EDMDiscretization(Discretization):
    """EDM (Karras et al.) rho-schedule: linear interpolation in ``sigma**(1/rho)`` space."""

    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device="cpu"):
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        # Linear ramp in 1/rho-power space mapped back; descends from sigma_max to sigma_min.
        # (Leftover debug `print(sigmas)` removed.)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas


class AYSDiscretization(Discretization):
    """"Align Your Steps"-style schedule: log-linear interpolation of a fixed
    table of reference sigmas (defined for >= 10 steps)."""

    def __init__(self):
        self.sigma_min = 0.002
        self.sigma_max = 700.0
        # Reference noise levels, descending from sigma_max to sigma_min.
        self.base_sigmas = np.array([700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002])

    def loglinear_interp(self, t_steps, num_steps):
        """Log-linearly resample the decreasing array ``t_steps`` to ``num_steps`` values."""
        xs = np.linspace(0, 1, len(t_steps))
        ys = np.log(t_steps[::-1])  # interpolate over ascending log-sigmas
        new_xs = np.linspace(0, 1, num_steps)
        new_ys = np.interp(new_xs, xs, ys)
        interped_ys = np.exp(new_ys)[::-1].copy()  # back to descending order
        return interped_ys

    def get_sigmas(self, n, device="cpu"):
        # Explicit raise instead of `assert` (asserts are stripped under -O); the original
        # message also said "greater than 10" although the check accepts n == 10.
        if n < 10:
            raise ValueError("Number of timesteps must be at least 10 for AYS discretization.")
        if n > 10:
            sigmas = self.loglinear_interp(self.base_sigmas, n)
        else:
            # NOTE(review): n == 10 yields the 11-entry reference table while n > 10
            # yields exactly n sigmas — confirm this off-by-one is intended.
            # Copy so callers cannot mutate self.base_sigmas through the tensor that
            # torch.from_numpy would otherwise share with it.
            sigmas = self.base_sigmas.copy()
        return torch.from_numpy(sigmas).to(device)


class LegacyDDPMDiscretization(Discretization):
    """DDPM-style discretization derived from a linear beta schedule (legacy SD behavior)."""

    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device="cpu"):
        if n < self.num_timesteps:
            # Subsample the full alpha-bar trajectory down to n timesteps.
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"n ({n}) cannot exceed num_timesteps ({self.num_timesteps})")

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        # sigma = sqrt((1 - alpha_bar) / alpha_bar); flip so the largest sigma comes first.
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))