File size: 2,481 Bytes
5ab5cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import warnings
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

from helper.util import extract
from helper.beta_generator import BetaGenerator

class BaseSampler(nn.Module, ABC):
    """Abstract base for diffusion samplers.

    Loads the ``sampler`` section of a YAML config, builds the noise-schedule
    buffers and implements the generic reverse (sampling) loop. Subclasses
    must implement :meth:`get_x_prev` and assign ``self.timesteps`` (any
    reversible sequence with ``len``) before sampling.

    Registered buffers:
        beta: schedule returned by ``BetaGenerator``.
        alpha: ``1 - beta``.
        alpha_sqrt: ``sqrt(alpha)``.
        alpha_bar: cumulative product of ``alpha`` along dim 0.
    """

    def __init__(self, config_path: str):
        """Build the sampler from the ``sampler`` section of a YAML file.

        Args:
            config_path: Path to a YAML file whose top-level ``sampler``
                mapping contains ``T`` (passed to ``BetaGenerator`` as the
                schedule length) and ``beta`` (schedule name; the method
                ``BetaGenerator.<beta>_beta_schedule`` is used).
        """
        super().__init__()
        with open(config_path, "r", encoding="utf-8") as file:
            self.config = yaml.safe_load(file)['sampler']
        self.T = self.config['T']
        # Must be assigned (e.g. by a subclass) before reverse_process runs.
        self.timesteps = None

        beta_generator = BetaGenerator(T=self.T)
        schedule_name = f"{self.config['beta']}_beta_schedule"
        schedule_fn = getattr(beta_generator, schedule_name, None)
        if schedule_fn is None:
            # Keep the historical fallback to a linear schedule, but make it
            # visible: a typo in the config would otherwise silently change
            # the noise schedule.
            warnings.warn(
                f"BetaGenerator has no '{schedule_name}'; "
                "falling back to linear_beta_schedule.",
                stacklevel=2,
            )
            schedule_fn = beta_generator.linear_beta_schedule
        self.register_buffer('beta', schedule_fn())

        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_sqrt', self.alpha.sqrt())
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim=0))

    @abstractmethod
    @torch.no_grad()
    def get_x_prev(self, x, idx, eps_hat):
        """Return the sample for the previous step of the reverse process.

        Implementations must accept exactly the arguments :meth:`p_sample`
        passes: ``(x, idx, eps_hat)``.

        Args:
            x: Current sample.
            idx: Position of the current step within ``self.timesteps``.
            eps_hat: Network noise prediction (already guided when ``gamma``
                was passed to :meth:`p_sample`).
        """

    def set_network(self, network: nn.Module):
        """Attach the noise-prediction network called by :meth:`p_sample`.

        Assigning an ``nn.Module`` attribute registers it as a submodule, so
        it follows ``.to()`` / ``.state_dict()`` on the sampler.
        """
        self.network = network

    def q_sample(self, x0, t, eps=None):
        """Sample the forward-noised ``x_t`` from ``x0`` in closed form.

        Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps``.

        Args:
            x0: Clean input.
            t: Step index/indices into ``alpha_bar``; ``extract`` presumably
                gathers ``alpha_bar[t]`` and reshapes it to broadcast against
                ``x0.shape`` (TODO confirm against helper.util).
            eps: Optional noise; standard normal noise like ``x0`` if omitted.

        Returns:
            The noised sample, same shape as ``x0``.
        """
        alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
        if eps is None:
            eps = torch.randn_like(x0)
        return alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps

    @torch.no_grad()
    def reverse_process(self, x_T, only_last=True, **kwargs):
        """Run the reverse process over ``self.timesteps`` in reverse order.

        Args:
            x_T: Starting sample.
            only_last: If True return only the final sample; otherwise return
                the full trajectory.
            **kwargs: Forwarded to :meth:`p_sample` (e.g. ``gamma`` and any
                network conditioning arguments).

        Returns:
            The final sample, or a list of ``len(self.timesteps) + 1`` samples
            starting with ``x_T``.
        """
        n_steps = len(self.timesteps)
        x = x_T
        x_seq = [x]
        for i, t in tqdm(enumerate(reversed(self.timesteps)), total=n_steps):
            # Index of t in the original (non-reversed) timesteps.
            idx = n_steps - i - 1
            x = self.p_sample(x, t, idx, **kwargs)
            if not only_last:
                x_seq.append(x)
        return x if only_last else x_seq

    @torch.no_grad()
    def p_sample(self, x, t, idx, gamma=None, **kwargs):
        """Predict noise with the network and take one reverse step.

        If ``gamma`` is given, the conditional prediction is mixed with a
        second prediction made with ``cond_drop_all=True`` as
        ``gamma * eps_cond + (1 - gamma) * eps_uncond`` (presumably
        classifier-free guidance; the flag's meaning is up to the network).

        Args:
            x: Current sample.
            t: Timestep value passed to the network.
            idx: Position of ``t`` within ``self.timesteps``.
            gamma: Optional guidance weight.
            **kwargs: Extra network arguments.

        Returns:
            The result of :meth:`get_x_prev`.
        """
        eps_hat = self.network(x=x, t=t, **kwargs)
        if gamma is not None:
            eps_null = self.network(x=x, t=t, cond_drop_all=True, **kwargs)
            eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
        return self.get_x_prev(x, idx, eps_hat)

    @torch.no_grad()
    def forward(self, x_T, **kwargs):
        """Sample from ``x_T``; alias for :meth:`reverse_process`."""
        return self.reverse_process(x_T, **kwargs)