import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from models.ade20k import ModelBuilder
from saicinpainting.utils import check_and_warn_input_range

# ImageNet channel statistics, shaped (1, 3, 1, 1) so they broadcast over NCHW batches
IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]


class PerceptualLoss(nn.Module):
    def __init__(self, normalize_inputs=True):
        super(PerceptualLoss, self).__init__()

        self.normalize_inputs = normalize_inputs
        self.mean_ = IMAGENET_MEAN
        self.std_ = IMAGENET_STD

        # Frozen VGG-19 feature stack; note pretrained=True is deprecated in
        # recent torchvision in favor of the weights= argument
        vgg = torchvision.models.vgg19(pretrained=True).features
        vgg_avg_pooling = []

        for weights in vgg.parameters():
            weights.requires_grad = False

        # Rebuild the stack, swapping every MaxPool2d for AvgPool2d; Sequential
        # containers are skipped since modules() already yields their children
        for module in vgg.modules():
            if module.__class__.__name__ == 'Sequential':
                continue
            elif module.__class__.__name__ == 'MaxPool2d':
                vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
            else:
                vgg_avg_pooling.append(module)

        self.vgg = nn.Sequential(*vgg_avg_pooling)

    def do_normalize_inputs(self, x):
        return (x - self.mean_.to(x.device)) / self.std_.to(x.device)

    def partial_losses(self, input, target, mask=None):
        check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')

        # we expect input and target to be in [0, 1] range
        losses = []

        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
            features_target = self.do_normalize_inputs(target)
        else:
            features_input = input
            features_target = target

        # Push both images through the first 30 layers, collecting a masked,
        # per-sample MSE after every ReLU activation
        for layer in self.vgg[:30]:
            features_input = layer(features_input)
            features_target = layer(features_target)

            if layer.__class__.__name__ == 'ReLU':
                loss = F.mse_loss(features_input, features_target, reduction='none')

                if mask is not None:
                    # Resize the mask to the current feature resolution and
                    # zero the loss inside the masked (missing) region
                    cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
                                             mode='bilinear', align_corners=False)
                    loss = loss * (1 - cur_mask)

                # Reduce over channel and spatial dims, keeping one value per sample
                loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
                losses.append(loss)

        return losses

    def forward(self, input, target, mask=None):
        losses = self.partial_losses(input, target, mask=mask)
        return torch.stack(losses).sum(dim=0)

    def get_global_features(self, input):
        check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')

        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
        else:
            features_input = input

        # Unlike partial_losses, this runs the full modified VGG stack
        features_input = self.vgg(features_input)
        return features_input
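

# Illustrative usage sketch (not part of the original module): exercises
# PerceptualLoss on random tensors in the expected [0, 1] range. It assumes
# mask == 1 marks missing (inpainted) pixels, which partial_losses then
# excludes via the (1 - cur_mask) weighting.
def _demo_perceptual_loss():
    loss_fn = PerceptualLoss()
    pred = torch.rand(2, 3, 256, 256)    # e.g. generator output, values in [0, 1]
    target = torch.rand(2, 3, 256, 256)  # ground-truth image, values in [0, 1]
    mask = torch.zeros(2, 1, 256, 256)   # single-channel mask, same spatial size
    per_sample = loss_fn(pred, target, mask=mask)
    print(per_sample.shape)  # torch.Size([2]): one summed per-layer loss per sample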


class ResNetPL(nn.Module):
    def __init__(self, weight=1,
                 weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
        super().__init__()
        # ADE20K-pretrained segmentation encoder used as a fixed feature extractor
        self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
                                             arch_encoder=arch_encoder,
                                             arch_decoder='ppm_deepsup',
                                             fc_dim=2048,
                                             segmentation=segmentation)
        self.impl.eval()
        for w in self.impl.parameters():
            w.requires_grad_(False)

        self.weight = weight

    def forward(self, pred, target):
        # Normalize with ImageNet statistics; .to(tensor) matches both device and dtype
        pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
        target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)

        pred_feats = self.impl(pred, return_feature_maps=True)
        target_feats = self.impl(target, return_feature_maps=True)

        # Sum the MSE between corresponding feature maps from every encoder stage
        result = torch.stack([F.mse_loss(cur_pred, cur_target)
                              for cur_pred, cur_target
                              in zip(pred_feats, target_feats)]).sum() * self.weight
        return result
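

# Illustrative usage sketch (not part of the original module): ResNetPL needs
# the ADE20K-pretrained encoder checkpoint, so the weights_path below is a
# hypothetical placeholder to be replaced with a real checkpoint location.
def _demo_resnet_pl():
    loss_fn = ResNetPL(weight=30, weights_path='/path/to/ade20k')  # hypothetical path
    pred = torch.rand(1, 3, 256, 256)
    target = torch.rand(1, 3, 256, 256)
    print(loss_fn(pred, target))  # scalar tensor: weighted sum of per-stage MSEs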