Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 5,470 Bytes
			
			| 1b2a9b1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import swapae.util as util
from .stylegan2_layers import Downsample
def gan_loss(pred, should_be_classified_as_real):
    """Return the per-sample softplus GAN loss for discriminator outputs.

    Args:
        pred: Tensor of discriminator outputs whose first dimension is the
            batch; every remaining dimension is averaged away.
        should_be_classified_as_real: Truthy to compute ``softplus(-pred)``
            (small when ``pred`` is large), falsy to compute
            ``softplus(pred)`` (small when ``pred`` is very negative).

    Returns:
        A 1-D tensor holding one loss value per batch element.
    """
    # Flipping the sign of the outputs selects which side gets penalised.
    signed_pred = -pred if should_be_classified_as_real else pred
    per_element = F.softplus(signed_pred)
    return per_element.view(pred.size(0), -1).mean(dim=1)
def feature_matching_loss(xs, ys, equal_weights=False, num_layers=6):
    """Weighted sum of per-sample mean absolute differences between features.

    Args:
        xs: Sequence of feature tensors (batch first).
        ys: Sequence of feature tensors paired element-wise with ``xs``.
        equal_weights: When true every compared layer gets weight ``1 / n``;
            otherwise layer ``i`` gets ``1 / 2 ** (n - i)``, so later layers
            weigh more.  ``n`` is ``min(num_layers, len(xs))``.
        num_layers: Maximum number of leading layers to compare.

    Returns:
        A 1-D tensor with one value per batch element, or the float ``0.0``
        when there is nothing to compare.
    """
    depth = min(num_layers, len(xs))
    total = 0.0
    layer_pairs = zip(xs[:num_layers], ys[:num_layers])
    for idx, (feat_x, feat_y) in enumerate(layer_pairs):
        layer_weight = 1.0 / depth if equal_weights else 1 / (2 ** (depth - idx))
        # Mean absolute difference over everything except the batch dimension.
        per_sample_l1 = (feat_x - feat_y).abs().flatten(1).mean(1)
        total = total + per_sample_l1 * layer_weight
    return total
class IntraImageNCELoss(nn.Module):
    """Cross-entropy contrastive loss between matching locations of two maps.

    For a random subset of spatial locations, the feature of ``query`` at a
    location is scored (dot product) against the features of ``target`` at
    every sampled location of the same image; the matching location is the
    correct class.

    ``opt`` must expose ``intraimage_num_locations``, the maximum number of
    locations sampled per call.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target):
        """Compute the loss for two 4-D feature maps.

        Args:
            query: 4-D tensor (dims 2 and 3 are flattened into locations).
            target: Tensor that flattens to the same layout as ``query``.
                NOTE(review): the dot products are only cosine similarities
                if both inputs are already normalised along dim 1 - confirm
                against the caller.

        Returns:
            Scalar tensor: mean cross-entropy over all sampled locations.
        """
        total_locations = query.size(2) * query.size(3)
        num_locations = min(total_locations, self.opt.intraimage_num_locations)
        bs = query.size(0)
        # Permute ALL locations and keep the first ``num_locations``.
        # Permuting only ``num_locations`` indices would restrict sampling to
        # the first few flattened positions and never look at the rest.
        patch_ids = torch.randperm(total_locations, device=query.device)[:num_locations]
        # both query and target are of size B x C x N
        query = query.flatten(2, 3)[:, :, patch_ids]
        target = target.flatten(2, 3)[:, :, patch_ids]
        # B x N x N: entry (i, j) scores query location i against target j.
        similarity = torch.bmm(query.transpose(1, 2), target)
        similarity = similarity.flatten(0, 1)
        # Row i of every image should select column i (the same location).
        target_label = torch.arange(num_locations, dtype=torch.long, device=query.device).repeat(bs)
        # Logits are divided by a fixed 0.07 before the softmax.
        return self.cross_entropy_loss(similarity / 0.07, target_label)
class VGG16Loss(torch.nn.Module):
    """L1 distance between intermediate VGG16 activations of two inputs.

    Activations are captured right before each of the first three
    ``MaxPool2d`` layers of ``vgg16().features``; the pooling layers
    themselves are replaced by ``self.downsample``.

    NOTE(review): ``mean`` is shifted by 0.5 while ``stdev`` is doubled,
    which looks like two different assumptions about the input range -
    confirm against the range the callers actually feed in.
    """

    def __init__(self):
        super().__init__()
        self.vgg_convs = torchvision.models.vgg16(pretrained=True).features
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        # Stored as 1 x 3 x 1 x 1 buffers so they broadcast over the input.
        self.register_buffer('mean', channel_mean[None, :, None, None] - 0.5)
        self.register_buffer('stdev', channel_std[None, :, None, None] * 2)
        self.downsample = Downsample([1, 2, 1], factor=2)

    def copy_section(self, source, start, end):
        """Return a ``Sequential`` holding ``source[start:end]``.

        The sub-modules are shared (not copied) and keep their original
        index as their name.
        """
        section = torch.nn.Sequential()
        for idx in range(start, end):
            section.add_module(str(idx), source[idx])
        return section

    def vgg_forward(self, x):
        """Return the activations seen at the first three pooling layers."""
        x = (x - self.mean) / self.stdev
        collected = []
        for layer in self.vgg_convs.children():
            if type(layer).__name__ != "MaxPool2d":
                x = layer(x)
                continue
            # Record the activation, then downsample in place of max-pooling.
            collected.append(x)
            if len(collected) == 3:
                break
            x = self.downsample(x)
        return collected

    def forward(self, x, y):
        """Weighted mean of per-layer L1 losses; no gradient flows into ``y``."""
        y = y.detach()
        layer_weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0]
        weighted_sum = 0
        weight_total = 0.0
        feature_pairs = zip(self.vgg_forward(x), self.vgg_forward(y))
        for idx, (feat_x, feat_y) in enumerate(feature_pairs):
            weighted_sum += F.l1_loss(feat_x, feat_y) * layer_weights[idx]
            weight_total += layer_weights[idx]
        # Normalise by the weights actually used (only three layers are hit).
        return weighted_sum / weight_total
class NCELoss(nn.Module):
    """Cross-entropy contrastive loss with one positive and shared negatives.

    Each query row is scored against its own target row (the positive) and
    against every row of ``negatives``; the positive sits in column 0.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target, negatives):
        """Return the mean cross-entropy over the batch.

        All three inputs are flattened to 2-D (batch first) and passed
        through ``util.normalize`` - presumably an L2 normalisation; verify
        against ``swapae.util``.
        """
        query, target, negatives = (
            util.normalize(tensor.flatten(1))
            for tensor in (query, target, negatives)
        )
        positive_logit = (query * target).sum(dim=1, keepdim=True)
        negative_logits = torch.mm(query, negatives.t())
        # Column 0 is the positive; logits are divided by a fixed 0.07.
        logits = torch.cat([positive_logit, negative_logits], dim=1) / 0.07
        labels = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
        return self.cross_entropy_loss(logits, labels)
class ScaleInvariantReconstructionLoss(nn.Module):
    """Per-location dot-product reconstruction loss plus a cross-image penalty.

    The loss is the sum of two terms:

    * ``1 - <query, target>`` averaged over every spatial location, where the
      dot product runs over dim 1 of the inputs;
    * the mean positive part of dot products between randomly sampled
      location vectors of ``target`` and randomly sampled location vectors of
      the batch-reversed ``target``.
    """

    def forward(self, query, target):
        """Compute the loss for two 4-D tensors of identical shape.

        Args:
            query: 4-D tensor; assumed B x C x H x W - TODO confirm.
            target: 4-D tensor with the same shape as ``query``.

        Returns:
            Scalar tensor.
        """
        # Swap dims 1 and 3 so the last axis holds dim 1 of the input
        # (B x C x H x W -> B x W x H x C under the layout assumed above).
        query_flat = query.transpose(1, 3)
        target_flat = target.transpose(1, 3)
        # Batched 1 x C times C x 1 product: one dot product per location.
        dist = 1.0 - torch.bmm(
            query_flat[:, :, :, None, :].flatten(0, 2),
            target_flat[:, :, :, :, None].flatten(0, 2),
        )
        # Flatten the TRANSPOSED target so every row is one location's vector,
        # matching the vectors used for ``dist``.  Flattening the
        # untransposed tensor would mix dims 1 and 2 instead.
        target_spatially_flat = target_flat.flatten(1, 2)
        num_locations = target_spatially_flat.size(1)
        num_samples = min(num_locations, 64)
        # Permute all locations, then truncate, so that the sample is drawn
        # from the whole map rather than only from its first 64 positions.
        random_indices = torch.randperm(
            num_locations, dtype=torch.long, device=target.device)[:num_samples]
        randomly_sampled = target_spatially_flat[:, random_indices]
        random_indices = torch.randperm(
            num_locations, dtype=torch.long, device=target.device)[:num_samples]
        another_random_sample = target_spatially_flat[:, random_indices]
        # Pair each sampled vector with one from the batch-reversed target.
        random_similarity = torch.bmm(
            randomly_sampled[:, :, None, :].flatten(0, 1),
            torch.flip(another_random_sample, [0])[:, :, :, None].flatten(0, 1)
        )
        return dist.mean() + random_similarity.clamp(min=0.0).mean()
 |