Spaces:
Running
Running
import torch | |
from diffusion_model.sampler.base_sampler import BaseSampler | |
class DDIM(BaseSampler):
    """DDIM sampler that runs on a strided sub-sequence of the training steps.

    Relies on ``self.config``, ``self.T`` and ``self.alpha_bar`` being
    available after ``BaseSampler.__init__`` (presumably the parsed config,
    the number of training timesteps and the cumulative alpha products --
    confirm against ``BaseSampler``).
    """

    def __init__(self, config_path):
        """Precompute the per-step DDIM coefficients.

        Args:
            config_path: Forwarded to ``BaseSampler.__init__``. The resulting
                config must provide ``'sampling_T'`` and ``'eta'``.

        Raises:
            ValueError: If ``sampling_T`` is not in ``[1, T]``; the stride
                ``T // sampling_T`` would otherwise be zero or undefined.
        """
        super().__init__(config_path)
        self.sampling_T = self.config['sampling_T']
        if not 1 <= self.sampling_T <= self.T:
            raise ValueError(
                f"sampling_T must be in [1, {self.T}], got {self.sampling_T}"
            )

        # Evenly strided sub-sequence of the training timesteps. When T is not
        # a multiple of sampling_T the sequence has ceil(T / step) entries,
        # which can be more than sampling_T.
        step = self.T // self.sampling_T
        self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)

        # Index with int64: older PyTorch releases reject int32 index tensors.
        self.ddim_alpha = self.alpha_bar[self.timesteps.long()]
        self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()
        # "Previous" alpha for sub-step i is the alpha of sub-step i - 1; the
        # first sub-step reuses its own alpha, which makes sigma[0] == 0.
        self.alpha_bar_prev = torch.cat(
            [self.ddim_alpha[0:1], self.ddim_alpha[:-1]]
        )
        # eta == 0 gives the deterministic DDIM update (sigma == 0 everywhere).
        self.sigma = self.config['eta'] * torch.sqrt(
            (1 - self.alpha_bar_prev) / (1 - self.ddim_alpha)
            * (1 - self.ddim_alpha / self.alpha_bar_prev)
        )

    @staticmethod
    def _ddim_coef(table, tau, x):
        """Look up ``table[tau]`` and make it broadcastable against ``x``.

        A scalar lookup is returned unchanged. A 1-D lookup (one index per
        sample) is reshaped to ``(-1, 1, ..., 1)`` -- this assumes the batch
        axis of ``x`` is dim 0 (TODO confirm with callers).
        """
        coef = table[tau]
        if coef.dim() > 0:
            coef = coef.reshape(-1, *([1] * (x.dim() - 1)))
        return coef

    def get_x_prev(self, x, tau, eps_hat):
        """Perform one DDIM update from sub-step ``tau`` to the previous one.

        Args:
            x: Current noisy sample.
            tau: Index into the strided schedule (not a raw training
                timestep). Either a scalar or a 1-D index tensor with one
                entry per sample.
            eps_hat: Predicted noise, same shape as ``x``.

        Returns:
            The sample at the previous sub-step.
        """
        alpha = self._ddim_coef(self.ddim_alpha, tau, x)
        alpha_prev = self._ddim_coef(self.alpha_bar_prev, tau, x)
        sigma = self._ddim_coef(self.sigma, tau, x)
        sqrt_one_minus_alpha = self._ddim_coef(
            self.sqrt_one_minus_alpha_bar, tau, x
        )

        # Estimate of the clean sample implied by the predicted noise.
        x0_hat = (x - sqrt_one_minus_alpha * eps_hat) / (alpha ** 0.5)
        # Clamp so float rounding cannot push the radicand below zero (NaN).
        dir_xt = (1. - alpha_prev - sigma ** 2).clamp(min=0.).sqrt() * eps_hat

        # Skip the random draw when no noise is injected, so deterministic
        # sampling does not consume RNG state.
        if bool((sigma == 0.).all()):
            noise = 0.
        else:
            noise = torch.randn_like(x)

        return alpha_prev.sqrt() * x0_hat + dir_xt + sigma * noise