Spaces:
Running
Running
import torch | |
from diffusion_model.sampler.base_sampler import BaseSampler | |
class DDPM(BaseSampler):
    """Sampler that steps from ``x_t`` to ``x_{t-1}`` with per-step noise.

    Reads ``self.T``, ``self.alpha_bar``, ``self.beta`` and
    ``self.alpha_sqrt``; these are presumably populated by
    ``BaseSampler.__init__`` from the config file (base class not visible
    here -- confirm against it).
    """

    def __init__(self, config_path):
        """Precompute the per-timestep tables used by ``get_x_prev``.

        Args:
            config_path: Forwarded unchanged to ``BaseSampler.__init__``.
        """
        super().__init__(config_path)

        # Every timestep index 0 .. T-1, in ascending order.
        self.timesteps = torch.arange(self.T, dtype=torch.int)

        # sqrt(1 - alpha_bar_t): divisor of the predicted-noise term.
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1. - self.alpha_bar)

        # alpha_bar shifted right by one step; slot 0 repeats alpha_bar[0]
        # because there is no earlier step to take the value from.
        first_entry = self.alpha_bar[:1]
        shifted_tail = self.alpha_bar[:-1]
        self.alpha_bar_prev = torch.cat((first_entry, shifted_tail))

        # sigma_t = sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t)
        variance_ratio = (1 - self.alpha_bar_prev) / (1 - self.alpha_bar)
        self.sigma = torch.sqrt(variance_ratio * self.beta)

    def get_x_prev(self, x, t, eps_hat):
        """Compute the sample for the step preceding ``t``.

        Args:
            x: Current sample tensor.
            t: Timestep index into the precomputed tables. It is used in a
                boolean test (``t > 0``), so it must be an int or a
                single-element tensor.
            eps_hat: Noise estimate for ``x``; must broadcast against ``x``.

        Returns:
            ``(x - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / alpha_sqrt_t``
            plus ``sigma_t * z``, where ``z`` is standard normal noise for
            ``t > 0`` and zero otherwise.
        """
        inv_alpha_sqrt = 1 / self.alpha_sqrt[t]
        eps_scale = self.beta[t] / self.sqrt_one_minus_alpha_bar[t]
        mean = inv_alpha_sqrt * (x - eps_scale * eps_hat)

        # Only inject fresh noise while t > 0; at t == 0 the added term is zero.
        if t > 0:
            noise = torch.randn_like(mean)
        else:
            noise = 0.

        return mean + self.sigma[t] * noise