# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from imaginaire.utils.distributed import get_world_size, get_rank, \
    dist_all_reduce_tensor


class GatherLayer(torch.autograd.Function):
    r"""All-gather a tensor across processes while keeping it in the autograd
    graph, so gradients flow back to every rank's local input
    (``torch.distributed.all_gather`` alone does not backpropagate)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Collect the input tensor from every process.
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Sum the incoming gradients over all processes, then keep the slice
        # that corresponds to this rank's original input.
        all_grads = torch.stack(grads)
        all_grads = dist_all_reduce_tensor(all_grads, reduce='sum')
        grad_out[:] = all_grads[get_rank()]
        return grad_out


class InfoNCELoss(nn.Module):
    r"""InfoNCE (contrastive) loss between two sets of paired features, with a
    CLIP-style learnable logit scale stored as ``log(1 / temperature)``.
    ``gather_distributed`` gathers negatives from all processes,
    ``single_direction`` skips the b -> a term, and ``flatten`` controls
    whether features are flattened or spatially averaged before matching."""

    def __init__(self,
                 temperature=0.07,
                 gather_distributed=True,
                 learn_temperature=True,
                 single_direction=False,
                 flatten=True):
        super(InfoNCELoss, self).__init__()
        # Inverse temperature in log space, as in CLIP.
        self.logit_scale = nn.Parameter(torch.tensor([math.log(1 / temperature)]))
        self.logit_scale.requires_grad = learn_temperature
        self.gather_distributed = gather_distributed
        self.single_direction = single_direction
        self.flatten = flatten

    def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8):
        if gather_distributed is None:
            gather_distributed = self.gather_distributed

        # If either side is missing, return a zero loss (a single scalar, to
        # match the tensor returned on the normal path).
        if features_a is None or features_b is None:
            return torch.tensor(0.0, device=self.logit_scale.device)

        bs_a, bs_b = features_a.size(0), features_b.size(0)
        if self.flatten:
            # Flatten each sample into a single feature vector.
            features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1)
        else:
            # Average over spatial locations, keeping one value per channel.
            features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1)
            features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1)

        # Clamp the log temperature scale to ln(100) ~= 4.6052, as in CLIP.
        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)

        # L2-normalize the features along the channel dimension.
        features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps)
        features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps)

        # a -> b loss, plus the symmetric b -> a loss unless single_direction.
        loss_a = self._forward_single_direction(features_a, features_b, gather_distributed)
        if self.single_direction:
            return loss_a
        else:
            loss_b = self._forward_single_direction(features_b, features_a, gather_distributed)
            return loss_a + loss_b

    def _forward_single_direction(
            self, features_a, features_b, gather_distributed):
        bs_a = features_a.shape[0]
        logit_scale = self.logit_scale.exp()
        if get_world_size() > 1 and gather_distributed:
            # Use features_b from every process as negatives; the positive for
            # sample i on this rank sits at index rank * batch_size + i.
            gather_features_b = torch.cat(GatherLayer.apply(features_b))
            gather_labels_a = torch.arange(bs_a, device=features_a.device) + get_rank() * bs_a
            logits_a = logit_scale * features_a @ gather_features_b.t()
        else:
            # Single-process case: negatives come from the local batch only.
            gather_labels_a = torch.arange(bs_a, device=features_a.device)
            logits_a = logit_scale * features_a @ features_b.t()
        loss_a = F.cross_entropy(logits_a, gather_labels_a)
        return loss_a
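

# A minimal usage sketch (not part of the original loss module): the batch
# size, feature dimension, and hyper-parameter values below are illustrative
# assumptions, shown only to demonstrate how the loss is called on a single
# process without distributed gathering.
if __name__ == '__main__':
    torch.manual_seed(0)
    criterion = InfoNCELoss(temperature=0.07,
                            gather_distributed=False,
                            learn_temperature=False)
    features_a = torch.randn(8, 256)  # e.g. embeddings of one modality
    features_b = torch.randn(8, 256)  # paired embeddings of the other modality
    loss = criterion(features_a, features_b)
    print('InfoNCE loss:', loss.item())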