JuyeopDang's picture
Upload 35 files
5ab5cab verified
import torch
from diffusion_model.sampler.base_sampler import BaseSampler
class DDIM(BaseSampler):
    """DDIM sampler operating on a strided subset of the training timesteps.

    Relies on ``self.config``, ``self.T`` and ``self.alpha_bar`` being
    available once ``BaseSampler.__init__`` has run (presumably the parsed
    config, the number of training steps and the cumulative alpha products
    -- confirm against ``BaseSampler``).
    """

    def __init__(self, config_path):
        """Precompute the per-step DDIM coefficients.

        Args:
            config_path: Forwarded unchanged to ``BaseSampler.__init__``.
                The resulting ``self.config`` must provide ``'sampling_T'``
                and ``'eta'``.

        Raises:
            ValueError: If ``sampling_T`` is not in ``[1, self.T]``.
        """
        super().__init__(config_path)
        self.sampling_T = self.config['sampling_T']
        # Fail early with a readable message instead of a ZeroDivisionError
        # (sampling_T == 0) or an opaque torch.arange error (step == 0).
        if not 0 < self.sampling_T <= self.T:
            raise ValueError(
                f"sampling_T must be in [1, {self.T}], got {self.sampling_T}"
            )
        step = self.T // self.sampling_T
        # NOTE(review): when T is not a multiple of sampling_T this yields
        # more than sampling_T entries -- confirm how the caller iterates.
        self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)

        # Index with int64: int32 index tensors are rejected by older
        # PyTorch releases. ``self.timesteps`` itself keeps its dtype.
        self.ddim_alpha = self.alpha_bar[self.timesteps.long()]
        self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()

        # alpha_bar of the preceding sampling step. The first entry repeats
        # ddim_alpha[0], so sigma[0] == 0 and the update at tau == 0 returns
        # its input unchanged.
        # NOTE(review): confirm this is the intended final-step behaviour.
        self.alpha_bar_prev = torch.cat(
            [self.ddim_alpha[0:1], self.ddim_alpha[:-1]]
        )

        # sigma = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
        variance = (
            (1. - self.alpha_bar_prev) / (1. - self.ddim_alpha)
            * (1. - self.ddim_alpha / self.alpha_bar_prev)
        )
        self.sigma = self.config['eta'] * torch.sqrt(variance)

    def get_x_prev(self, x: torch.Tensor, tau, eps_hat: torch.Tensor) -> torch.Tensor:
        """Perform one DDIM update from sampling step ``tau`` to the previous one.

        Args:
            x: Current sample.
            tau: Index into the precomputed DDIM tables (not a raw training
                timestep). Must select a single entry, because the result of
                comparing ``sigma`` with zero is used as a boolean.
            eps_hat: Noise estimate for ``x``; combined element-wise with
                ``x``.

        Returns:
            The sample for the previous sampling step.
        """
        alpha_prev = self.alpha_bar_prev[tau]
        sigma = self.sigma[tau]

        # Estimate of the clean sample implied by the noise prediction.
        x0_hat = (
            (x - self.sqrt_one_minus_alpha_bar[tau] * eps_hat)
            / (self.ddim_alpha[tau] ** 0.5)
        )
        # Clamp so rounding error (or eta > 1) cannot push the radicand
        # below zero and poison the whole sample with NaNs.
        dir_xt = (1. - alpha_prev - sigma ** 2).clamp(min=0.).sqrt() * eps_hat

        x_prev = alpha_prev.sqrt() * x0_hat + dir_xt
        # Only draw noise when it contributes, so the deterministic case
        # (sigma == 0) does not consume random numbers.
        if sigma != 0.:
            # randn_like already inherits dtype and device from x.
            x_prev = x_prev + sigma * torch.randn_like(x)
        return x_prev