Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py

    Computes a weighted sum of L1 distances between the VGG19 activations of
    two inputs, taken at relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1.

    Args:
        weights: per-layer weights, indexed in the layer order listed above
            (at least five entries are required). Defaults to 1.0 for every
            layer.

    NOTE(review): inputs are passed to VGG19 unchanged; any ImageNet
    mean/std normalization has to be done by the caller — confirm.
    """

    # Activations compared by the loss, in the same order as ``weights``.
    _LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')

    def __init__(self, weights=None):
        super(PerceptualLoss, self).__init__()
        self.vgg = VGG19()
        self.criterion = torch.nn.L1Loss()
        # A fresh list per instance; a list literal as the default argument
        # would be one object shared by every default-constructed loss.
        self.weights = [1.0] * len(self._LAYERS) if weights is None else weights

    def forward(self, x, y):
        """Return the weighted L1 perceptual distance between ``x`` and ``y``.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and registered hooks run.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        # Index (rather than zip) so a too-short ``weights`` still raises.
        for i, layer in enumerate(self._LAYERS):
            content_loss += self.weights[i] * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
class VGG19(torch.nn.Module):
    """Frozen VGG19 feature extractor that returns every stage output by name.

    The convolutional trunk is split into 16 consecutive stages, ``relu1_1``
    through ``relu5_4``. Each stage consumes the previous stage's output, and
    ``forward`` returns a dict mapping each stage name to its output.

    Args:
        features: optional replacement for ``torchvision.models.vgg19().features``.
            It must be indexable by int and hold at least 36 modules. When
            omitted, torchvision's ImageNet-pretrained VGG19 is used
            (torchvision may download the weights on first use).

    All parameters are frozen (``requires_grad = False``).
    """

    # (stage name, start index, end index exclusive) into vgg19().features.
    # Each stage is named after the activation it ends on.
    _SLICES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        ('relu3_4', 16, 18),
        ('relu4_1', 18, 21),
        ('relu4_2', 21, 23),
        ('relu4_3', 23, 25),
        ('relu4_4', 25, 27),
        ('relu5_1', 27, 30),
        ('relu5_2', 30, 32),
        ('relu5_3', 32, 34),
        ('relu5_4', 34, 36),
    )

    def __init__(self, features=None):
        super(VGG19, self).__init__()
        if features is None:
            features = self._pretrained_features()
        for name, start, end in self._SLICES:
            stage = torch.nn.Sequential()
            for idx in range(start, end):
                # Keep the original feature index as the child name, so
                # state_dict keys look like 'relu1_1.0.weight'.
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _pretrained_features():
        """Return the feature trunk of torchvision's ImageNet-pretrained VGG19."""
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.vgg19(pretrained=True).features
        # Same weights as the deprecated ``pretrained=True``.
        return models.vgg19(weights=weights_enum.IMAGENET1K_V1).features

    def forward(self, x):
        """Run ``x`` through all stages and return ``{stage_name: output}``."""
        out = {}
        for name, _, _ in self._SLICES:
            x = getattr(self, name)(x)
            out[name] = x
        return out