import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor
class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.

    Using feature maps from the later, high-level layers of the network makes
    the loss focus on the perceptual and texture content of an image rather
    than on per-pixel differences.

    Paper references:
    - `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>`_
    - `ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>`_
    - `Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>`_
    """
    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model pretrained on ImageNet and switch it to eval mode.
        # (On torchvision >= 0.13, `weights=models.VGG19_Weights.IMAGENET1K_V1`
        # replaces the deprecated `pretrained=True` argument.)
        vgg19 = models.vgg19(pretrained=True).eval()
        # Keep the first 36 layers of `vgg19.features` (through the activation
        # after conv5_4) and use their output for the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze the parameters so the extractor is never updated during training.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False
        # ImageNet normalization statistics used to preprocess VGG19 inputs.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Normalize both images with the ImageNet statistics expected by VGG19.
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)
        # L1 distance between the VGG19 feature maps of the two images.
        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
        return loss
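
# A minimal usage sketch for ContentLoss (illustrative; the 1x3x96x96 shapes and
# the [0, 1] value range are assumptions, not part of the original file):
# content_criterion = ContentLoss()
# sr = torch.rand(1, 3, 96, 96)  # super-resolved output in [0, 1]
# hr = torch.rand(1, 3, 96, 96)  # ground-truth high-resolution image in [0, 1]
# print(content_criterion(sr, hr))  # scalar perceptual loss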
class GenGaussLoss(nn.Module):
    """Negative log-likelihood of a generalized Gaussian distribution.

    The exact per-pixel NLL (up to an additive constant) is
    (|mean - target| / alpha)**beta - log(1/alpha) + lgamma(1/beta) - log(beta).
    The forward pass below uses a clamped surrogate for the first term (the
    exact power form is kept as a commented-out alternative), and
    ``alpha_eps``/``beta_eps`` keep 1/alpha and beta strictly positive.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max
    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        # Keep the inverse scale and shape parameters strictly positive.
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        resi = torch.abs(mean - target)
        # Exact generalized Gaussian residual term:
        # resi = torch.pow(resi * one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
        # Clamped surrogate used instead for numerical stability:
        resi = (resi * one_over_alpha1 * beta1).clamp(min=self.resi_min, max=self.resi_max)
        # Check whether resi contains NaNs.
        if torch.isnan(resi).any():
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        if torch.isnan(log_one_over_alpha).any():
            print('log_one_over_alpha has nan')
        if torch.isnan(lgamma_beta).any():
            print('lgamma_beta has nan')
        if torch.isnan(log_beta).any():
            print('log_beta has nan')

        # Per-pixel negative log-likelihood (up to an additive constant).
        l = resi - log_one_over_alpha + lgamma_beta - log_beta

        if self.reduction == 'mean':
            return l.mean()
        elif self.reduction == 'sum':
            return l.sum()
        else:
            raise ValueError(f"Reduction '{self.reduction}' is not supported; use 'mean' or 'sum'.")
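
# A minimal sketch of how a network head might produce the positive inputs that
# GenGaussLoss expects (the softplus heads and shapes here are illustrative
# assumptions, not part of the original file). Positivity matters because both
# one_over_alpha and beta are passed through torch.log:
# head_out = torch.randn(4, 2, 32, 32)          # hypothetical 2-channel uncertainty head
# one_over_alpha = F.softplus(head_out[:, :1])  # inverse scale, > 0
# beta = F.softplus(head_out[:, 1:])            # shape parameter, > 0
# criterion = GenGaussLoss()
# print(criterion(torch.randn(4, 1, 32, 32), one_over_alpha, beta, torch.randn(4, 1, 32, 32)))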
class TempCombLoss(nn.Module):
    """Temperature-weighted combination of an L1 term and ``GenGaussLoss``.

    The forward pass returns ``T1 * L1(mean, target) + T2 * GenGauss(...)``,
    where ``T1`` and ``T2`` weight the fidelity and uncertainty terms.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max
        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        # Weighted sum of the L1 fidelity term and the generalized Gaussian NLL.
        l1 = self.L_l1(mean, target)
        l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
        l = T1 * l1 + T2 * l2
        return l
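
# A hedged sketch of a training step with TempCombLoss (the three-headed `model`
# and the T1=1e0, T2=1e-2 weighting mirror the smoke test below and are
# illustrative assumptions, not part of the original file):
# criterion = TempCombLoss()
# out_mean, out_one_over_alpha, out_beta = model(inp)  # hypothetical 3-head model
# loss = criterion(out_mean, out_one_over_alpha, out_beta, target, 1e0, 1e-2)
# loss.backward()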
if __name__ == '__main__':
    # Smoke test with random tensors (positive tensors for 1/alpha and beta).
    x1 = torch.randn(4, 3, 32, 32)  # predicted mean
    x2 = torch.rand(4, 3, 32, 32)   # one_over_alpha
    x3 = torch.rand(4, 3, 32, 32)   # beta
    x4 = torch.randn(4, 3, 32, 32)  # target
    L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
    L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
    print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))