import torch
from torch import nn, Tensor
from typing import Any, List, Tuple, Dict

from .dm_loss import DMLoss
from .utils import _reshape_density


class DACELoss(nn.Module):
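    """Joint classification-and-counting loss.

    Applies a per-pixel cross-entropy loss over discretized density bins and adds
    a weighted count loss (MAE, MSE, or DM loss) on the predicted density map.
    """
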
    def __init__(
        self,
        bins: List[Tuple[float, float]],  # (low, high) density range per class
        reduction: int,  # downsampling factor from input to density-map resolution
        weight_count_loss: float = 1.0,  # weight of the count loss term
        count_loss: str = "mae",  # one of "mae", "mse", "dmcount"
        **kwargs: Any  # forwarded to DMLoss when count_loss="dmcount"
    ) -> None:
        super().__init__()
        assert len(bins) > 0, f"Expected at least one bin, got {bins}"
        assert all(len(b) == 2 for b in bins), f"Expected every bin to be a (low, high) pair, got {bins}"
        assert all(b[0] <= b[1] for b in bins), f"Expected low <= high for every bin, got {bins}"
        self.bins = bins
        self.reduction = reduction
        self.cross_entropy_fn = nn.CrossEntropyLoss(reduction="none")

        count_loss = count_loss.lower()
        assert count_loss in ["mae", "mse", "dmcount"], f"Expected count_loss to be one of ['mae', 'mse', 'dmcount'], got {count_loss}"
        self.count_loss = count_loss
        if self.count_loss == "mae":
            self.use_dm_loss = False
            self.count_loss_fn = nn.L1Loss(reduction="none")
        elif self.count_loss == "mse":
            self.use_dm_loss = False
            self.count_loss_fn = nn.MSELoss(reduction="none")
        else:
            self.use_dm_loss = True
            assert "input_size" in kwargs, f"Expected input_size to be in kwargs when count_loss='dmcount', got {kwargs}"
            self.count_loss_fn = DMLoss(reduction=reduction, **kwargs)

        self.weight_count_loss = weight_count_loss

    def _bin_count(self, density_map: Tensor) -> Tensor:
        # Map each pixel's density value to the index of the bin containing it.
        # Pixels covered by no bin keep the default class 0; when bin boundaries
        # overlap, the last matching bin wins.
        class_map = torch.zeros_like(density_map, dtype=torch.long)
        for idx, (low, high) in enumerate(self.bins):
            mask = (density_map >= low) & (density_map <= high)
            class_map[mask] = idx
        return class_map.squeeze(1)  # drop the channel dimension: (B, 1, H, W) -> (B, H, W)

    def forward(self, pred_class: Tensor, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        # Bring the target density map to the prediction's spatial resolution if needed.
        if target_density.shape[-2:] != pred_density.shape[-2:]:
            target_density = _reshape_density(target_density, reduction=self.reduction)
        assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}"

        target_class = self._bin_count(target_density)

        # Per-pixel cross-entropy, summed over the spatial dims and averaged over the batch.
        cross_entropy_loss = self.cross_entropy_fn(pred_class, target_class).sum(dim=(-1, -2)).mean()

        if self.use_dm_loss:
            # DMLoss returns its loss together with a dict of per-term diagnostics.
            count_loss, loss_info = self.count_loss_fn(pred_density, target_density, target_points)
            loss_info["ce_loss"] = cross_entropy_loss.detach()
        else:
            # Per-pixel MAE/MSE, summed over (C, H, W) and averaged over the batch.
            count_loss = self.count_loss_fn(pred_density, target_density).sum(dim=(-1, -2, -3)).mean()
            loss_info = {
                "ce_loss": cross_entropy_loss.detach(),
                f"{self.count_loss}_loss": count_loss.detach(),
            }

        loss = cross_entropy_loss + self.weight_count_loss * count_loss
        loss_info["loss"] = loss.detach()

        return loss, loss_info
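

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test showing
# the tensor shapes DACELoss expects. All shapes and bin edges below are
# illustrative assumptions, not values prescribed by this codebase. Because of
# the relative imports above, this block only runs when the module is executed
# inside its package (e.g. `python -m <package>.<module>`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = DACELoss(
        bins=[(0.0, 0.5), (0.5, float("inf"))],  # two density classes (assumed edges)
        reduction=8,  # assumed input-to-density downsampling factor
        weight_count_loss=1.0,
        count_loss="mae",
    )
    batch_size, num_classes, h, w = 2, 2, 64, 64
    pred_class = torch.randn(batch_size, num_classes, h, w)  # per-pixel class logits
    pred_density = torch.rand(batch_size, 1, h, w)           # predicted density map
    target_density = torch.rand(batch_size, 1, h, w)         # target already at prediction resolution
    target_points = [torch.rand(5, 2), torch.rand(3, 2)]     # per-image point annotations (unused for "mae")
    loss, loss_info = criterion(pred_class, pred_density, target_density, target_points)
    print(loss, loss_info)  # scalar total loss plus detached per-term values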