Spaces:
Running
on
Zero
Running
on
Zero
from scenedino.losses.base_loss import BaseLoss | |
import torch | |
import torch.nn.functional as F | |
class StegoLoss(BaseLoss):
    """Correlation-based segmentation loss plus any per-head segmentation losses.

    Three correlation terms ("self", "knn", "random") are each computed by
    ``_compute_stego_loss`` from a pair of correlation tensors (a DINO one and a
    STEGO one) read from ``data["segmentation"]["stego_corr"]``. Any head losses
    stored under ``data["segmentation"]["results"]`` are added to the total.

    Config keys (all optional, read with ``config.get``):
        ``self_weight`` / ``knn_weight`` / ``random_weight`` (default 1.0),
        ``self_shift`` / ``knn_shift`` / ``random_shift`` (default 0.0),
        ``pointwise`` (default True).
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        # Per-term scale applied to the (negative) correlation product.
        self.random_weight = config.get("random_weight", 1.0)
        self.knn_weight = config.get("knn_weight", 1.0)
        self.self_weight = config.get("self_weight", 1.0)
        # Per-term offset subtracted from the DINO correlation before weighting.
        self.random_shift = config.get("random_shift", 0.0)
        self.knn_shift = config.get("knn_shift", 0.0)
        self.self_shift = config.get("self_shift", 0.0)
        # If True, the DINO correlation is centred along its last dim first.
        self.pointwise = config.get("pointwise", True)

    def get_loss_metric_names(self) -> list[str]:
        """Return the keys of the dict produced by ``__call__``."""
        return [
            "total_loss",
            "self_loss", "knn_loss", "random_loss",
            "direct_cluster_loss", "direct_linear_loss", "stego_cluster_loss", "stego_linear_loss",
        ]

    def __call__(self, data) -> dict[str, torch.Tensor]:
        """Compute all loss terms for one batch.

        Args:
            data: Nested dict. Reads ``data["segmentation"]["stego_corr"]``
                (optional; when absent the three correlation losses are 0) and
                ``data["segmentation"]["results"]``. The latter must contain
                ``"direct_cluster"`` and ``"stego_cluster"``; the
                ``"direct_linear"`` / ``"stego_linear"`` entries are optional.
                Each entry's ``"loss"`` value defaults to 0.0 when missing.

        Returns:
            Dict keyed by ``get_loss_metric_names()``. ``"total_loss"`` is the
            sum of every other entry.
        """
        segmentation = data["segmentation"]

        if "stego_corr" not in segmentation:
            self_loss, knn_loss, random_loss, total_loss = 0, 0, 0, 0
        else:
            corr = segmentation["stego_corr"]
            self_loss = self._compute_stego_loss(corr["dino_self_corr"], corr["stego_self_corr"],
                                                 self.self_weight, self.self_shift)
            knn_loss = self._compute_stego_loss(corr["dino_nn_corr"], corr["stego_nn_corr"],
                                                self.knn_weight, self.knn_shift)
            random_loss = self._compute_stego_loss(corr["dino_random_corr"], corr["stego_random_corr"],
                                                   self.random_weight, self.random_shift)
            total_loss = self_loss + knn_loss + random_loss

        results = segmentation["results"]
        direct_cluster_loss = results["direct_cluster"].get("loss", 0.0)
        stego_cluster_loss = results["stego_cluster"].get("loss", 0.0)
        # Linear heads are optional; contribute 0.0 when absent.
        direct_linear_loss = results.get("direct_linear", {}).get("loss", 0.0)
        stego_linear_loss = results.get("stego_linear", {}).get("loss", 0.0)

        total_loss += direct_cluster_loss + direct_linear_loss + stego_cluster_loss + stego_linear_loss

        return {
            "total_loss": total_loss,
            "self_loss": self_loss,
            "knn_loss": knn_loss,
            "random_loss": random_loss,
            "direct_cluster_loss": direct_cluster_loss,
            "direct_linear_loss": direct_linear_loss,
            "stego_cluster_loss": stego_cluster_loss,
            "stego_linear_loss": stego_linear_loss,
        }

    def _compute_stego_loss(self, dino_corr, stego_corr, weight, shift):
        """Return ``mean(-weight * clamp(stego_corr, min=0) * (dino_corr - shift))``.

        When ``self.pointwise`` is set, ``dino_corr`` is first centred along its
        last dimension. It is then offset so that its global mean equals the
        original global mean. Neither input tensor is modified.
        """
        if self.pointwise:
            old_mean = dino_corr.mean()
            # Out-of-place on purpose: dino_corr is the caller's tensor (taken
            # from ``data``). An in-place ``-=`` would corrupt it, and on
            # autograd-tracked tensors it can raise in-place modification errors.
            dino_corr = dino_corr - dino_corr.mean(dim=-1, keepdim=True)
            dino_corr = dino_corr - dino_corr.mean() + old_mean
        # clamp(0) zeroes negative stego correlations, so those entries contribute nothing.
        loss = -weight * stego_corr.clamp(0) * (dino_corr - shift)
        return loss.mean()