Spaces:
Running
Running
File size: 773 Bytes
5ab5cab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from diffusion_model.sampler.base_sampler import BaseSampler
class DDPM(BaseSampler):
    """Sampler that applies the DDPM reverse update one step at a time.

    The schedule attributes ``T``, ``beta``, ``alpha_bar`` and ``alpha_sqrt``
    are read from ``self`` but not defined here; they are expected to be set
    up by ``BaseSampler.__init__`` from ``config_path``.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        # Integer step indices 0, 1, ..., T-1 (ascending, int32).
        # NOTE(review): presumably the sampling loop walks these in reverse;
        # confirm against BaseSampler.
        self.timesteps = torch.arange(self.T, dtype=torch.int)

        # sqrt(1 - alpha_bar_t): denominator of the eps_hat term in get_x_prev.
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1. - self.alpha_bar)

        # alpha_bar shifted right by one step. Entry 0 repeats alpha_bar[0]
        # instead of the conventional 1.0, so sigma[0] reduces to
        # sqrt(beta[0]) rather than 0. get_x_prev multiplies sigma[0] by
        # zero, so the sampled result is unaffected.
        head = self.alpha_bar[:1]
        shifted = self.alpha_bar[:-1]
        self.alpha_bar_prev = torch.cat((head, shifted))

        # Per-step noise scale:
        # sigma_t = sqrt((1 - alpha_bar_prev_t) / (1 - alpha_bar_t) * beta_t).
        variance_ratio = (1 - self.alpha_bar_prev) / (1 - self.alpha_bar)
        self.sigma = torch.sqrt(variance_ratio * self.beta)

    @torch.no_grad()
    def get_x_prev(self, x, t, eps_hat):
        """Take one reverse-diffusion step from ``x`` at step ``t``.

        Computes::

            x_prev = (1 / alpha_sqrt[t])
                     * (x - beta[t] / sqrt(1 - alpha_bar[t]) * eps_hat)
                     + sigma[t] * z

        where ``z`` is fresh standard-normal noise when ``t > 0`` and ``0``
        when ``t == 0``. Runs under ``torch.no_grad()``.

        Args:
            x: Current sample tensor.
            t: Step index used to index the schedule tensors. It is used
                directly as a condition (``t > 0``), so it must be a scalar,
                e.g. a Python int or a single-element tensor.
            eps_hat: Noise estimate combined elementwise with ``x``.
                NOTE(review): presumably the model's epsilon prediction for
                (x, t); confirm with the caller.
                NOTE(review): ``alpha_sqrt`` is presumably sqrt(alpha_t);
                confirm in BaseSampler.

        Returns:
            A new tensor holding the sample for the previous step. ``x`` is
            not modified in place.
        """
        inv_sqrt_alpha = 1 / self.alpha_sqrt[t]
        eps_coef = self.beta[t] / self.sqrt_one_minus_alpha_bar[t]
        # Deterministic part of the update.
        mean = inv_sqrt_alpha * (x - eps_coef * eps_hat)

        # Fresh noise is drawn only for t > 0; the final step adds none.
        if t > 0:
            noise = torch.randn_like(mean)
        else:
            noise = 0.
        return mean + self.sigma[t] * noise
|