import torch
from torch import nn


class EDMPrecond(torch.nn.Module):
    """EDM preconditioning wrapper (Karras et al., 2022, "Elucidating the
    Design Space of Diffusion-Based Generative Models").

    Scales the wrapped network's input, output, and noise conditioning so the
    network sees roughly unit-variance signals at every noise level, and
    returns the denoised estimate D(x; sigma).
    """

    def __init__(
        self,
        network,  # raw score/denoiser network to be wrapped
        label_dim=0,  # dimensionality of conditioning labels (0 = unconditional)
        sigma_min=0,  # minimum supported noise level
        sigma_max=float("inf"),  # maximum supported noise level
        sigma_data=0.5,  # expected standard deviation of the training data
    ):
        super().__init__()
        self.label_dim = label_dim
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.network = network

    def forward(self, x, sigma, conditioning=None, **network_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        # For a class-conditional model with no labels supplied, fall back to
        # an all-zero label vector; otherwise cast the labels to float32.
        conditioning = (
            None
            if self.label_dim == 0
            else torch.zeros([1, self.label_dim], device=x.device)
            if conditioning is None
            else conditioning.to(torch.float32)
        )

        # EDM preconditioning coefficients (Karras et al., 2022, Table 1).
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        F_x = self.network(
            c_in * x,
            c_noise.flatten(),
            conditioning=conditioning,
            **network_kwargs,
        )
        # Denoised estimate: skip connection on x plus scaled network output.
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        # Noise levels are not quantized here; just return sigma as a tensor.
        return torch.as_tensor(sigma)

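
# Usage sketch (illustrative, not part of the original module). The `_ToyNet`
# below is a hypothetical stand-in: any module whose forward matches the call
# EDMPrecond makes, network(c_in * x, c_noise, conditioning=None, **kwargs),
# would work in its place.


class _ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, noise_labels, conditioning=None):
        # A real score network would also consume noise_labels/conditioning.
        return self.conv(x)


def _edm_precond_demo():
    precond = EDMPrecond(_ToyNet())
    x = torch.randn(2, 3, 32, 32)  # batch of noisy images
    sigma = torch.full((2,), 0.8)  # one noise level per sample
    d_x = precond(x, sigma)  # denoised estimate, same shape as x
    assert d_x.shape == x.shape
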
|
class DDPMPrecond(nn.Module):
    """Identity preconditioning: the network output is used as-is."""

    def __init__(self):
        super().__init__()

    def forward(self, network, batch):
        # No input/output scaling; simply apply the network to the batch.
        F_x = network(batch)
        return F_x
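

# Usage sketch (illustrative, not part of the original module): DDPMPrecond
# applies no scaling, so it can wrap any callable that maps a batch to an
# output of the same shape. The linear layer here is a hypothetical stand-in.
def _ddpm_precond_demo():
    precond = DDPMPrecond()
    net = nn.Linear(8, 8)  # stand-in denoiser
    batch = torch.randn(4, 8)
    out = precond(net, batch)
    assert out.shape == batch.shape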