File size: 1,240 Bytes
5ab5cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch

from diffusion_model.sampler.base_sampler import BaseSampler

class DDIM(BaseSampler):
    """Denoising Diffusion Implicit Models (DDIM) sampler.

    Runs the reverse process on a strided subsequence of the base sampler's
    ``T`` training timesteps. The config must provide ``sampling_T`` (requested
    number of sampling steps) and ``eta`` (scale of the injected noise:
    ``eta == 0`` gives the deterministic DDIM update).

    Attributes built here, all indexed by the schedule index ``tau``:
        timesteps: training timesteps ``0, step, 2*step, ... < T`` with
            ``step = T // sampling_T``; this has ``ceil(T / step)`` entries,
            which can exceed ``sampling_T`` when ``T`` is not a multiple of it.
        ddim_alpha: ``alpha_bar`` at each entry of ``timesteps``.
        sqrt_one_minus_alpha_bar: ``sqrt(1 - ddim_alpha)``.
        alpha_bar_prev: ``alpha_bar`` of the position each update moves to;
            ``1`` for ``tau == 0`` so the last update lands on the clean estimate.
        sigma: per-step noise scale of the update.
    """

    def __init__(self, config_path):
        super().__init__(config_path)
        self.sampling_T = self.config['sampling_T']
        # A non-positive sampling_T or one larger than T would give a zero
        # (or undefined) stride and fail later with an opaque error.
        if not 1 <= self.sampling_T <= self.T:
            raise ValueError(
                f"sampling_T must be in [1, T={self.T}], got {self.sampling_T}"
            )
        step = self.T // self.sampling_T
        self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)

        # Index with int64: older PyTorch releases reject int32 index tensors.
        self.ddim_alpha = self.alpha_bar[self.timesteps.long()]
        self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()
        # alpha_bar of each update's destination. For tau == 0 the destination
        # is the clean sample (alpha_bar = 1). Reusing ddim_alpha[0] there made
        # the final update an exact identity (x_prev == x).
        self.alpha_bar_prev = torch.cat(
            [torch.ones_like(self.ddim_alpha[:1]), self.ddim_alpha[:-1]]
        )
        # sigma = eta * sqrt((1 - a_prev) / (1 - a) * (1 - a / a_prev)).
        # It is 0 at tau == 0 because a_prev == 1 there.
        self.sigma = self.config['eta'] * torch.sqrt(
            (1 - self.alpha_bar_prev) / (1 - self.ddim_alpha)
            * (1 - self.ddim_alpha / self.alpha_bar_prev)
        )

    def get_x_prev(self, x, tau, eps_hat):
        """Take one DDIM update from schedule index ``tau`` to the previous one.

        Args:
            x: current sample, at training timestep ``self.timesteps[tau]``.
            tau: index into the strided schedule (not a raw training timestep).
            eps_hat: predicted noise for ``x``; must broadcast with ``x``.

        Returns:
            The sample at the previous (less noisy) schedule position. For
            ``tau == 0`` this is the predicted clean sample ``x0_hat``.
        """
        alpha = self.ddim_alpha[tau]
        alpha_prev = self.alpha_bar_prev[tau]
        sigma = self.sigma[tau]

        # Predicted clean sample under the epsilon parameterization.
        x0_hat = (x - self.sqrt_one_minus_alpha_bar[tau] * eps_hat) / alpha.sqrt()
        # Deterministic "direction pointing to x_t" term.
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * eps_hat
        # Skip the RNG draw when the step is deterministic (e.g. eta == 0 or tau == 0).
        noise = 0. if sigma == 0. else torch.randn_like(x)
        return alpha_prev.sqrt() * x0_hat + dir_xt + sigma * noise