Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
import yaml | |
from helper.util import extract | |
from helper.beta_generator import BetaGenerator | |
from abc import ABC, abstractmethod | |
class BaseSampler(nn.Module, ABC):
    """Shared machinery for samplers built on a beta schedule.

    The constructor reads the ``sampler`` section of a YAML config (keys
    ``T`` and ``beta``) and registers ``beta``, ``alpha``, ``alpha_sqrt`` and
    ``alpha_bar`` as buffers. A subclass supplies ``get_x_prev`` and fills
    ``self.timesteps``; the model used for prediction is attached with
    ``set_network``.
    """

    def __init__(self, config_path: str):
        """Load the sampler config and register the schedule buffers.

        Args:
            config_path: Path to a YAML file with a top-level ``sampler``
                mapping that contains ``T`` and ``beta``.
        """
        super().__init__()
        with open(config_path, "r", encoding="utf-8") as file:
            self.config = yaml.safe_load(file)['sampler']
        self.T = self.config['T']
        beta_generator = BetaGenerator(T=self.T)
        # Nothing in this class assigns real timesteps; presumably subclasses
        # do so before reverse_process() runs.
        self.timesteps = None
        # Attached later with set_network(); p_sample() checks for it.
        self.network = None

        # Look up "<beta>_beta_schedule" on the generator. Unknown names still
        # fall back to the linear schedule, but no longer silently.
        schedule = getattr(beta_generator, f"{self.config['beta']}_beta_schedule", None)
        if schedule is None:
            import warnings

            warnings.warn(
                f"Unknown beta schedule {self.config['beta']!r}; "
                "falling back to linear_beta_schedule.",
                stacklevel=2,
            )
            schedule = beta_generator.linear_beta_schedule
        self.register_buffer('beta', schedule())
        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_sqrt', self.alpha.sqrt())
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim=0))

    def get_x_prev(self, x, idx, eps_hat):
        """Compute the next sample of the reverse loop from x, idx and eps_hat.

        Must be overridden. p_sample() calls it positionally as
        ``get_x_prev(x, idx, eps_hat)``, so overrides must accept exactly those
        three arguments. The previously declared ``t`` parameter was never
        passed by the caller.

        Deliberately not an ``@abstractmethod``: that would forbid
        instantiating BaseSampler directly, which callers may rely on
        (e.g. for q_sample) — unverified from here.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_x_prev(x, idx, eps_hat)"
        )

    def set_network(self, network: nn.Module):
        """Attach the model queried by p_sample (registered as a submodule)."""
        self.network = network

    def q_sample(self, x0, t, eps=None):
        """Return sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps.

        Args:
            x0: Input tensor.
            t: Step index/indices handed to ``extract`` together with
                ``alpha_bar``. ``extract`` is not shown here; presumably it
                gathers alpha_bar[t] shaped to broadcast against x0.
            eps: Tensor mixed in with x0; drawn with ``torch.randn_like(x0)``
                when omitted.
        """
        alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
        if eps is None:
            eps = torch.randn_like(x0)
        return alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps

    def reverse_process(self, x_T, only_last=True, **kwargs):
        """Apply p_sample once per entry of self.timesteps, last entry first.

        Args:
            x_T: Starting tensor.
            only_last: If truthy, return only the final result; otherwise
                return the list ``[x_T, after_step_1, ..., final]``.
            **kwargs: Forwarded to p_sample (e.g. ``gamma``) and from there
                to the network.

        Raises:
            RuntimeError: If ``self.timesteps`` has not been set.
        """
        if self.timesteps is None:
            raise RuntimeError(
                "self.timesteps is not set; assign it before calling reverse_process()"
            )
        n_steps = len(self.timesteps)
        x = x_T
        x_seq = None if only_last else [x]
        # total= gives tqdm a length; wrapping enumerate() alone hides it.
        for i, t in tqdm(enumerate(reversed(self.timesteps)), total=n_steps):
            # idx is t's position in self.timesteps in its stored order.
            idx = n_steps - i - 1
            x = self.p_sample(x, t, idx, **kwargs)
            if x_seq is not None:
                x_seq.append(x)
        return x if x_seq is None else x_seq

    def p_sample(self, x, t, idx, gamma=None, **kwargs):
        """Run the network at (x, t) and hand its output to get_x_prev.

        Args:
            x: Current tensor.
            t: Timestep value passed to the network as ``t``.
            idx: Position of t in self.timesteps, passed to get_x_prev.
            gamma: Optional guidance weight. When given, the prediction is
                ``gamma * eps_cond + (1 - gamma) * eps_null``, where eps_null
                comes from a second network call with ``cond_drop_all=True``.
            **kwargs: Extra keyword arguments for the network.

        Raises:
            RuntimeError: If no network has been attached with set_network().
        """
        # getattr default also covers subclasses that bypass __init__.
        if getattr(self, 'network', None) is None:
            raise RuntimeError("No network attached; call set_network() first")
        eps_hat = self.network(x=x, t=t, **kwargs)
        if gamma is not None:
            # Override (rather than duplicate) a caller-supplied cond_drop_all,
            # which previously raised "got multiple values for keyword".
            eps_null = self.network(x=x, t=t, **{**kwargs, 'cond_drop_all': True})
            eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
        return self.get_x_prev(x, idx, eps_hat)

    def forward(self, x_T, **kwargs):
        """Delegate to reverse_process so the sampler can be called like a module."""
        return self.reverse_process(x_T, **kwargs)