"""Noise schedules and gamma networks for the diffusion model.

Defines predefined (cosine / polynomial) noise schedules and a learnable
monotonic gamma network, following the construction in the VDM paper.
"""
import math

import numpy as np
import torch
import torch.nn.functional as F
def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    For a noise schedule given by alpha^2, this clips the step ratio
    alpha_t^2 / alpha_{t-1}^2. This may help improve stability during sampling.
    """
    # Prepend alpha_0^2 = 1 so that a ratio is defined for every step.
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
    alphas_step = alphas2[1:] / alphas2[:-1]

    # Clip each per-step ratio, then rebuild the cumulative product.
    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
    alphas2 = np.cumprod(alphas_step, axis=0)

    return alphas2
def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
    """
    A noise schedule based on a simple polynomial equation: 1 - x^power.
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2

    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)

    # Rescale so that alpha^2 stays within [s, 1 - s] for numerical stability.
    precision = 1 - 2 * s
    alphas2 = precision * alphas2 + s

    return alphas2
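
# Illustrative sketch (not part of the original module): one way to inspect the
# polynomial schedule. The timesteps / s / power values below are demo
# assumptions, not values mandated by the code above.
def _demo_polynomial_schedule(timesteps=1000, s=1e-5, power=2.):
    alphas2 = polynomial_schedule(timesteps, s=s, power=power)
    # alphas2 has timesteps + 1 entries, one per discrete time 0..timesteps;
    # it starts at 1 - s and decays monotonically towards s.
    print('first/last alpha^2:', alphas2[0], alphas2[-1])
    print('monotone non-increasing:', bool(np.all(np.diff(alphas2) <= 0)))
    return alphas2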
def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """
    Cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
    """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas = np.clip(betas, a_min=0, a_max=0.999)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    if raise_to_power != 1:
        alphas_cumprod = np.power(alphas_cumprod, raise_to_power)

    return alphas_cumprod
class PositiveLinear(torch.nn.Module):
    """Linear layer with weights forced to be positive."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # Shift the raw weights towards negative values so that the softplus
        # applied in forward() starts out close to zero.
        with torch.no_grad():
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # softplus guarantees strictly positive effective weights.
        positive_weight = F.softplus(self.weight)
        return F.linear(x, positive_weight, self.bias)
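
# Illustrative sketch (not part of the original module): checking that the
# effective weights of PositiveLinear are strictly positive. The layer sizes
# below are arbitrary demo values.
def _demo_positive_linear():
    layer = PositiveLinear(4, 3)
    effective_weight = F.softplus(layer.weight)
    # softplus(w) = log(1 + exp(w)) > 0 for any real w, so the effective weight
    # matrix used in forward() is always positive, which is what makes stacks
    # of these layers monotone in their input.
    assert bool((effective_weight > 0).all())
    return layer(torch.randn(2, 4))  # output shape (2, 3)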
class PredefinedNoiseSchedule(torch.nn.Module):
    """
    Predefined noise schedule. Essentially creates a lookup array for
    predefined (non-learned) noise schedules.
    """

    def __init__(self, noise_schedule, timesteps, precision):
        super(PredefinedNoiseSchedule, self).__init__()
        self.timesteps = timesteps

        if noise_schedule == 'cosine':
            alphas2 = cosine_beta_schedule(timesteps)
        elif 'polynomial' in noise_schedule:
            # e.g. 'polynomial_2' selects a polynomial schedule with power 2.
            splits = noise_schedule.split('_')
            assert len(splits) == 2
            power = float(splits[1])
            alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
        else:
            raise ValueError(noise_schedule)

        sigmas2 = 1 - alphas2

        log_alphas2 = np.log(alphas2)
        log_sigmas2 = np.log(sigmas2)
        log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2

        # gamma_t = -log(alpha_t^2 / sigma_t^2), stored as a fixed lookup table.
        self.gamma = torch.nn.Parameter(
            torch.from_numpy(-log_alphas2_to_sigmas2).float(),
            requires_grad=False)

    def forward(self, t):
        # t is a normalized time in [0, 1]; map it to an integer table index.
        t_int = torch.round(t * self.timesteps).long()
        return self.gamma[t_int]
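
# Illustrative sketch (not part of the original module): querying a predefined
# schedule. The constructor arguments are demo assumptions; 'polynomial_2'
# selects the polynomial schedule with power 2, and precision plays the role
# of `s` above.
def _demo_predefined_schedule():
    schedule = PredefinedNoiseSchedule('polynomial_2', timesteps=1000,
                                       precision=1e-5)
    t = torch.rand(8)                  # normalized times in [0, 1]
    gamma = schedule(t)                # gamma_t = -log(alpha_t^2 / sigma_t^2)
    alpha2 = torch.sigmoid(-gamma)     # recover alpha_t^2 = sigmoid(-gamma_t)
    sigma2 = torch.sigmoid(gamma)      # recover sigma_t^2 = 1 - alpha_t^2
    return gamma, alpha2, sigma2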
class GammaNetwork(torch.nn.Module):
    """The gamma network models a monotonically increasing function.
    Construction as in the VDM paper."""

    def __init__(self):
        super().__init__()

        self.l1 = PositiveLinear(1, 1)
        self.l2 = PositiveLinear(1, 1024)
        self.l3 = PositiveLinear(1024, 1)

        self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
        self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
        self.show_schedule()

    def show_schedule(self, num_steps=50):
        t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
        gamma = self.forward(t)
        print('Gamma schedule:')
        print(gamma.detach().cpu().numpy().reshape(num_steps))

    def gamma_tilde(self, t):
        l1_t = self.l1(t)
        return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))

    def forward(self, t):
        zeros, ones = torch.zeros_like(t), torch.ones_like(t)
        # Not super efficient.
        gamma_tilde_0 = self.gamma_tilde(zeros)
        gamma_tilde_1 = self.gamma_tilde(ones)
        gamma_tilde_t = self.gamma_tilde(t)

        # Normalize to [0, 1].
        normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
            gamma_tilde_1 - gamma_tilde_0)

        # Rescale to [gamma_0, gamma_1].
        gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma

        return gamma
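
# Illustrative sketch (not part of the original module): the learnable schedule
# can be queried the same way as the predefined one. The values below are demo
# choices.
def _demo_gamma_network(num_steps=11):
    gamma_net = GammaNetwork()         # prints its initial schedule on init
    t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
    gamma = gamma_net(t)
    # By construction gamma(0) == gamma_0 and gamma(1) == gamma_1, and the
    # positive-weight layers keep gamma monotonically increasing in t.
    return gamma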