import torch
import torch.nn as nn


class Swish(nn.Module):
    """Swish activation, x * sigmoid(x) (also known as SiLU)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


def cycle_interval(starting_value, num_frames, min_val, max_val):
    """Cycle through [min_val, max_val] in a single triangle-wave sweep."""
    starting_in_01 = ((starting_value - min_val) / (max_val - min_val)).cpu()
    grid = torch.linspace(starting_in_01.item(), starting_in_01.item() + 2., steps=num_frames + 1)[:-1]
    # Fold the [0, 2] ramp back into [0, 1] (triangle wave), using torch ops
    # throughout rather than mixing in numpy.
    grid -= torch.clamp(2 * grid - 2, min=0)
    grid += torch.clamp(-2 * grid, min=0)
    return grid * (max_val - min_val) + min_val
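
# Illustration (added, not in the original file): starting at 0.5 in [0, 1]
# with four frames, the sweep visits the start, the max, the start again,
# and the min:
#   cycle_interval(torch.tensor(0.5), 4, 0., 1.)
#   -> tensor([0.5000, 1.0000, 0.5000, 0.0000])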


class BetaVAE_Linear(nn.Module):
    def __init__(self, in_dim=1024, n_hidden=64, latent=8):
        super(BetaVAE_Linear, self).__init__()
        self.n_hidden = n_hidden
        self.latent = latent
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, n_hidden), Swish(),
        )
        # Latent
        self.mu = nn.Linear(n_hidden, latent)
        self.lv = nn.Linear(n_hidden, latent)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent, n_hidden), Swish(),
            nn.Linear(n_hidden, in_dim), Swish()
        )

    def BottomUp(self, x):
        # Encode x into the parameters of q(z|x).
        out = self.encoder(x)
        mu, lv = self.mu(out), self.lv(out)
        return mu, lv

    def reparameterize(self, mu, lv):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * lv)
        eps = torch.randn_like(std)
        return mu + std * eps

    def TopDown(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return out, mu, lv

    def calc_loss(self, x, beta):
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = torch.sigmoid(self.TopDown(z))
        # Per-sample reconstruction term (Bernoulli negative log-likelihood)
        # and KL divergence from the standard-normal prior; the small epsilon
        # keeps the KL strictly positive.
        nll = nn.functional.binary_cross_entropy(out, x, reduction='sum') / x.shape[0]
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / x.shape[0]
        return nll + kl * beta, kl, nll

    def LT_fitted_gauss_2std(self, x, num_var=6, num_traversal=5):
        # Latent traversal: cycle each latent linearly through +/- 2 std dev
        # of a Gaussian fitted to the aggregate posterior.
        x = x.view(x.shape[0], -1)
        mu, lv = self.BottomUp(x)
        images = []
        for i, batch_mu in enumerate(mu[:num_var]):
            images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0))
            for latent_var in range(batch_mu.shape[0]):
                new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
                loc = mu[:, latent_var].mean()
                # Total variance = mean posterior variance + variance of the means.
                total_var = lv[:, latent_var].exp().mean() + mu[:, latent_var].var()
                scale = total_var.sqrt()
                new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal,
                                                       loc - 2 * scale, loc + 2 * scale)
                images.append(torch.sigmoid(self.TopDown(new_mu)))
        return images
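

# A minimal training-loop sketch (an assumption, not part of the original
# file): fit the model with Adam on the beta-weighted loss above. The batch
# size, learning rate, step count, and random placeholder data are all
# illustrative stand-ins for a real DataLoader.
def train_demo(model, in_dim=784, beta=0.05, steps=100, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        x = torch.rand(32, in_dim)  # placeholder batch; inputs must lie in [0, 1] for BCE
        loss, kl, nll = model.calc_loss(x, beta)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model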


if __name__ == "__main__":
    # Smoke test on random data; in_dim must match the 784-dim inputs below.
    model = BetaVAE_Linear(in_dim=784)
    x = torch.rand(10, 784)
    out, mu, lv = model(x)
    print(out.shape)
    loss, kl, nll = model.calc_loss(x, 0.05)
    print(loss, kl, nll)
    images = model.LT_fitted_gauss_2std(x)
    print(len(images), images[0].shape)
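
    # A hedged visualization sketch (an assumption, not in the original file):
    # if the 784-dim outputs are flattened 28x28 grayscale images (e.g. MNIST)
    # and torchvision is installed, the traversals can be tiled into an image
    # grid. "traversals.png" is a hypothetical output path.
    try:
        from torchvision.utils import save_image
        grid = torch.cat(images, dim=0).view(-1, 1, 28, 28)
        save_image(grid, "traversals.png", nrow=5)
    except ImportError:
        pass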