File size: 3,436 Bytes
9e15541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from scenedino.losses.base_loss import BaseLoss
import torch
import torch.nn.functional as F


class StegoLoss(BaseLoss):
    """Correlation-based loss combined with the segmentation head losses.

    Three correlation terms ("self", "knn" and "random") are computed from
    pairs of DINO / STEGO correlation tensors found under
    ``data["segmentation"]["stego_corr"]``. Each term has its own weight and
    shift, read from ``config``. The per-head losses found under
    ``data["segmentation"]["results"]`` are added on top.

    Config keys (all optional):
        random_weight, knn_weight, self_weight: multipliers, default ``1.0``.
        random_shift, knn_shift, self_shift: values subtracted from the DINO
            correlation before it is multiplied in, default ``0.0``.
        pointwise: whether to re-centre the DINO correlation along its last
            dimension before computing the loss, default ``True``.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.random_weight = config.get("random_weight", 1.0)
        self.knn_weight = config.get("knn_weight", 1.0)
        self.self_weight = config.get("self_weight", 1.0)

        self.random_shift = config.get("random_shift", 0.0)
        self.knn_shift = config.get("knn_shift", 0.0)
        self.self_shift = config.get("self_shift", 0.0)

        self.pointwise = config.get("pointwise", True)

    def get_loss_metric_names(self) -> list[str]:
        """Return the keys of the dictionary produced by ``__call__``."""
        return [
            "total_loss",
            "self_loss", "knn_loss", "random_loss",
            "direct_cluster_loss", "direct_linear_loss", "stego_cluster_loss", "stego_linear_loss"
        ]

    def __call__(self, data) -> dict[str, torch.Tensor]:
        """Compute all loss terms for one batch.

        Args:
            data: Nested mapping. ``data["segmentation"]["results"]`` must
                contain ``"direct_cluster"`` and ``"stego_cluster"`` entries;
                ``"direct_linear"`` and ``"stego_linear"`` are optional. A
                missing ``"loss"`` entry in any head counts as ``0.0``.
                ``data["segmentation"]["stego_corr"]`` is optional; when it is
                absent the three correlation terms are the plain number ``0``.

        Returns:
            Mapping from each name in ``get_loss_metric_names()`` to its
            value. ``"total_loss"`` is the sum of every other entry.
        """
        segmentation = data["segmentation"]

        if "stego_corr" not in segmentation:
            self_loss, knn_loss, random_loss, total_loss = 0, 0, 0, 0
        else:
            corr = segmentation["stego_corr"]

            self_loss = self._compute_stego_loss(
                corr["dino_self_corr"], corr["stego_self_corr"],
                self.self_weight, self.self_shift,
            )
            knn_loss = self._compute_stego_loss(
                corr["dino_nn_corr"], corr["stego_nn_corr"],
                self.knn_weight, self.knn_shift,
            )
            random_loss = self._compute_stego_loss(
                corr["dino_random_corr"], corr["stego_random_corr"],
                self.random_weight, self.random_shift,
            )
            total_loss = self_loss + knn_loss + random_loss

        results = segmentation["results"]
        direct_cluster_loss = results["direct_cluster"].get("loss", 0.0)
        stego_cluster_loss = results["stego_cluster"].get("loss", 0.0)

        # If linear heads present
        direct_linear_loss = results.get("direct_linear", {}).get("loss", 0.0)
        stego_linear_loss = results.get("stego_linear", {}).get("loss", 0.0)

        # Out-of-place add: never modify a tensor that is also returned below.
        total_loss = total_loss + (
            direct_cluster_loss + direct_linear_loss + stego_cluster_loss + stego_linear_loss
        )

        losses = {
            "total_loss": total_loss,

            "self_loss": self_loss,
            "knn_loss": knn_loss,
            "random_loss": random_loss,

            "direct_cluster_loss": direct_cluster_loss,
            "direct_linear_loss": direct_linear_loss,
            "stego_cluster_loss": stego_cluster_loss,
            "stego_linear_loss": stego_linear_loss,
        }
        return losses

    def _compute_stego_loss(self, dino_corr, stego_corr, weight, shift):
        """Return the mean of ``-weight * clamp(stego_corr, min=0) * (dino_corr - shift)``.

        When ``self.pointwise`` is set, ``dino_corr`` is first centred along
        its last dimension and then shifted so that its global mean equals the
        mean it had before centring.

        The inputs are never modified. Centring used to be done in place on
        the tensor stored in ``data``, so evaluating the loss a second time on
        the same batch gave a different value.

        Args:
            dino_corr: Correlation tensor of the reference (DINO) features.
            stego_corr: Correlation tensor of the learned (STEGO) features;
                presumably broadcast-compatible with ``dino_corr`` -- TODO
                confirm against the producer of ``stego_corr``.
            weight: Scalar multiplier applied to the loss.
            shift: Scalar subtracted from ``dino_corr``.

        Returns:
            Scalar tensor holding the mean loss.
        """
        if self.pointwise:
            old_mean = dino_corr.mean()
            # Out-of-place on purpose: keeps the caller's tensor intact.
            centered = dino_corr - dino_corr.mean(dim=-1, keepdim=True)
            dino_corr = centered - centered.mean() + old_mean

        loss = -weight * stego_corr.clamp(0) * (dino_corr - shift)

        return loss.mean()