from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

from helper.beta_generator import BetaGenerator
from helper.util import extract


class BaseSampler(nn.Module, ABC):
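    """Abstract base for diffusion samplers.

    Loads the `sampler` section of a YAML config, registers the noise
    schedule (beta, alpha, alpha_bar) as buffers, and provides the shared
    forward process (q_sample) and reverse loop (reverse_process / p_sample).
    Concrete samplers must set `self.timesteps` and implement `get_x_prev`.
    """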

    def __init__(self, config_path: str):
        super().__init__()
        with open(config_path, "r") as file:
            self.config = yaml.safe_load(file)["sampler"]
        self.T = self.config["T"]
        beta_generator = BetaGenerator(T=self.T)
        self.timesteps = None  # populated by the concrete sampler before sampling
        # Resolve the beta schedule named in the config; fall back to the
        # linear schedule if no generator method matches that name.
        self.register_buffer("beta", getattr(beta_generator,
                                             f"{self.config['beta']}_beta_schedule",
                                             beta_generator.linear_beta_schedule)())
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_sqrt", self.alpha.sqrt())
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    @abstractmethod
    @torch.no_grad()
    def get_x_prev(self, x, t, idx, eps_hat):
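        """One reverse step: map x_t at timestep t (schedule index idx) to the
        previous latent, given the predicted noise eps_hat. Implemented by
        concrete samplers (e.g., a DDPM or DDIM update rule).
        """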
        pass

    def set_network(self, network: nn.Module):
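        """Attach the noise-prediction network used during sampling."""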
        self.network = network

    def q_sample(self, x0, t, eps=None):
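        """Forward (noising) process q(x_t | x_0).

        Returns sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,
        drawing eps ~ N(0, I) when it is not supplied.
        """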
        alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
        if eps is None:
            eps = torch.randn_like(x0)
        q_xt_x0 = alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps
        return q_xt_x0

    @torch.no_grad()
    def reverse_process(self, x_T, only_last=True, **kwargs):
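        """Run the reverse diffusion over self.timesteps, starting from x_T.

        Returns only the final sample when only_last is True; otherwise the
        full list of intermediate samples, beginning with x_T.
        """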
        x = x_T
        if only_last:
            for i, t in tqdm(enumerate(reversed(self.timesteps))):
                idx = len(self.timesteps) - i - 1
                x = self.p_sample(x, t, idx, **kwargs)
            return x
        else:
            x_seq = [x]
            for i, t in tqdm(enumerate(reversed(self.timesteps))):
                idx = len(self.timesteps) - i - 1
                x_seq.append(self.p_sample(x_seq[-1], t, idx, **kwargs))
            return x_seq

    @torch.no_grad()
    def p_sample(self, x, t, idx, gamma=None, **kwargs):
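        """Predict the noise at timestep t and take one reverse step.

        When gamma is given, the conditional and unconditional predictions are
        mixed as gamma * eps_cond + (1 - gamma) * eps_uncond, i.e.
        classifier-free guidance with guidance scale gamma.
        """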
        eps_hat = self.network(x=x, t=t, **kwargs)
        if gamma is not None:
            # Unconditional prediction with all conditioning dropped, for
            # classifier-free guidance.
            eps_null = self.network(x=x, t=t, cond_drop_all=True, **kwargs)
            eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
        # Pass t along with the schedule index to match get_x_prev's signature.
        x = self.get_x_prev(x, t, idx, eps_hat)
        return x

    @torch.no_grad()
    def forward(self, x_T, **kwargs):
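        """Sampling entry point: calling the module runs the reverse process."""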
        return self.reverse_process(x_T, **kwargs)