# Source: https://github.com/lissomx/MSP/blob/master/M_ModelAE_Cnn.py

import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

class Encoder(nn.Module):
    # Only for square images whose width/height is a power of 2.
    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Encoder, self).__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        sequens = [
            nn.Conv2d(nc, nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Halve the spatial size and double the channel count until the
        # feature map is 4x4, then collapse it into a single hidden vector.
        while True:
            image_size = image_size // 2
            if image_size > 4:
                sequens.append(nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False))
                sequens.append(nn.BatchNorm2d(nf * 2))
                sequens.append(nn.LeakyReLU(0.2, inplace=True))
                nf = nf * 2
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                sequens.append(nn.Conv2d(nf, self.hidden_size, int(image_size), 1, 0, bias=False))
                break
        self.main = nn.Sequential(*sequens)

    def forward(self, input):
        # (B, nc, H, W) -> (B, hidden_size)
        return self.main(input).squeeze(3).squeeze(2)
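

# Shape sketch (added for illustration; not part of the original file).
# image_size=64 and nf=32 are assumed example values: the encoder halves
# 64 -> 32 -> 16 -> 8 -> 4 while doubling nf 32 -> 64 -> 128 -> 256, so
# hidden_size defaults to 256 here.
def _demo_encoder_shapes():
    enc = Encoder(image_size=64, nf=32)
    x = torch.randn(8, 3, 64, 64)
    h = enc(x)
    assert h.shape == (8, enc.hidden_size)  # (8, 256) for these sizes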


class Decoder(nn.Module):
    # Only for square images whose width/height is a power of 2.
    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Decoder, self).__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        # The layer list is built from the output end backwards and reversed
        # at the end, so the network mirrors the Encoder exactly.
        sequens = [
            nn.Tanh(),
            nn.ConvTranspose2d(nf, nc, 4, 2, 1, bias=False),
        ]
        while True:
            image_size = image_size // 2
            sequens.append(nn.ReLU(True))
            sequens.append(nn.BatchNorm2d(nf))
            if image_size > 4:
                sequens.append(nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1, bias=False))
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                sequens.append(nn.ConvTranspose2d(self.hidden_size, nf, int(image_size), 1, 0, bias=False))
                break
            nf = nf * 2
        sequens.reverse()
        self.main = nn.Sequential(*sequens)

    def forward(self, z):
        # (B, hidden_size) -> (B, nc, H, W)
        z = z.unsqueeze(2).unsqueeze(2)
        output = self.main(z)
        return output

    def loss(self, predict, orig):
        # Sum of squared errors between reconstruction and original.
        batch_size = predict.shape[0]
        a = predict.view(batch_size, -1)
        b = orig.view(batch_size, -1)
        L = F.mse_loss(a, b, reduction='sum')
        return L
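

# Round-trip sketch (added for illustration; not part of the original file).
# Because Decoder mirrors Encoder, encode -> decode preserves the image
# shape. The sizes below are assumed example values.
def _demo_autoencoder_shapes():
    enc = Encoder(image_size=64, nf=32)
    dec = Decoder(image_size=64, nf=32)
    x = torch.randn(4, 3, 64, 64)
    z = enc(x)        # (4, 256) with these sizes
    x_rec = dec(z)    # back to (4, 3, 64, 64)
    assert x_rec.shape == x.shape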


class CnnVae(nn.Module):
    def __init__(self, learning_rate, image_size, label_size, nf, hidden_size=None, nc=3):
        super(CnnVae, self).__init__()
        self.encoder = Encoder(image_size, nf, hidden_size, nc)
        self.decoder = Decoder(image_size, nf, hidden_size, nc)
        self.image_size = image_size
        self.nc = nc
        self.label_size = label_size
        self.hidden_size = self.encoder.hidden_size
        self.learning_rate = learning_rate
        # fc1/fc2 map the encoder output to the posterior mean and log-variance.
        self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        # M is the MSP matrix that projects latent vectors onto label space.
        self.M = nn.Parameter(torch.empty(label_size, self.hidden_size))
        nn.init.xavier_normal_(self.M)

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc1(h)
        logvar = self.fc2(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Standard VAE reparameterization trick: z = mu + sigma * eps.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        prod = self.decoder(z)
        outputs = {'output': prod}  # DISENTANGLER_DECODER_OUTPUT
        outputs['embedding'] = z
        outputs['vae_mu'] = mu
        outputs['vae_logvar'] = logvar
        return outputs

    def _loss_vae(self, mu, logvar):
        # KL divergence, Appendix B of https://arxiv.org/abs/1312.6114:
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return KLD

    def _loss_msp(self, label, z):
        # Encode labels as +/-1 codes: one-hot with zeros flipped to -1.
        labels_one_hot = F.one_hot(label, num_classes=self.label_size)
        labels_one_hot = labels_one_hot.to(dtype=torch.float32)
        labels_one_hot[labels_one_hot == 0.0] = -1
        # L1: the projection z @ M^T should reproduce the label code.
        L1 = F.mse_loss((z @ self.M.t()).view(-1), labels_one_hot.view(-1), reduction="none").sum()
        # L2: the label code mapped back through M should reproduce z.
        L2 = F.mse_loss((labels_one_hot @ self.M).view(-1), z.view(-1), reduction="none").sum()
        return L1 + L2, L1, L2
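
    # Note (added; not in the original file): together L1 and L2 push M toward
    # a (semi-)orthogonal projection, so z @ M.t() reads attributes out of z
    # and label codes map back into the latent space. With the assumed sizes
    # hidden_size=256 and label_size=10:
    #     z = torch.randn(8, 256); label = torch.randint(0, 10, (8,))
    #     L, L1, L2 = model._loss_msp(label, z)   # three scalar tensors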

    def loss(self, prod, orig, label, z, mu, logvar):
        L_rec = self.decoder.loss(prod, orig)
        L_vae = self._loss_vae(mu, logvar)
        L_msp, L1_msp, L2_msp = self._loss_msp(label, z)
        # Scale the MSP term so it is comparable to the pixel-level terms.
        _msp_weight = orig.numel() / (label.numel() + z.numel())
        Loss = L_rec + L_vae + L_msp * _msp_weight
        loss_dict = {'L1': L1_msp, 'L2': L2_msp, 'L_msp': L_msp,
                     'L_rec': L_rec, 'L_vae': L_vae}
        return Loss, loss_dict

    def acc(self, z, l):
        # Accuracy of recovering the +/-1 label code l from z @ M^T.
        zl = z @ self.M.t()
        a = zl.clamp(-1, 1) * l * 0.5 + 0.5
        return a.round().mean().item()

    def predict(self, x, new_ls=None, weight=1.0):
        # Deterministic reconstruction: use the posterior mean as z.
        z, _ = self.encode(x)
        if new_ls is not None:
            # For each (index, value) pair, shift z so that its projection
            # onto label dimension `index` becomes `value * weight`.
            zl = z @ self.M.t()
            d = torch.zeros_like(zl)
            for i, v in new_ls:
                d[:, i] = v * weight - zl[:, i]
            z += d @ self.M
        prod = self.decoder(z)
        return prod

    def predict_ex(self, x, label, new_ls=None, weight=1.0):
        return self.predict(x, new_ls, weight)
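
    # Usage sketch (added; not in the original file): to edit attributes of a
    # batch x, pass (index, target_value) pairs, e.g. forcing the (assumed)
    # attribute 3 to +1:
    #     edited = model.predict(x, new_ls=[(3, 1.0)])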

    def get_U(self, eps=1e-5):
        # Get the null-space matrix N of M such that U = [M; N] is orthogonal.
        from scipy import linalg
        M = self.M.detach().cpu()
        A = torch.zeros(M.shape[1] - M.shape[0], M.shape[1])
        A = torch.cat([M, A])
        u, s, vh = linalg.svd(A.numpy())
        null_mask = (s <= eps)
        # `compress` is no longer re-exported by scipy; use numpy's instead.
        null_space = np.compress(null_mask, vh, axis=0)
        N = torch.tensor(null_space)
        return torch.cat([self.M, N.to(self.M.device)])
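

if __name__ == "__main__":
    # Minimal smoke test (added; not part of the original file). All sizes and
    # hyper-parameters below are assumed example values.
    model = CnnVae(learning_rate=1e-3, image_size=64, label_size=10, nf=32)
    x = torch.randn(2, 3, 64, 64)
    label = torch.randint(0, 10, (2,))
    out = model(x)
    loss, parts = model.loss(out['output'], x, label,
                             out['embedding'], out['vae_mu'], out['vae_logvar'])
    loss.backward()
    print({k: round(v.item(), 4) for k, v in parts.items()})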