Spaces:
Runtime error
Runtime error
File size: 3,216 Bytes
bb3e610 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
from torch import nn, Tensor
from typing import Any, List, Tuple, Dict
from .dm_loss import DMLoss
from .utils import _reshape_density
class DACELoss(nn.Module):
    """Cross-entropy over binned density values plus a weighted count loss.

    Every value of the target density map is converted to the index of the
    bin (from ``bins``) that contains it, and the predicted class logits are
    scored against those indices with cross-entropy. A count loss between the
    predicted and target density maps (``"mae"``, ``"mse"`` or ``"dmcount"``)
    is added, scaled by ``weight_count_loss``.
    """

    def __init__(
        self,
        bins: List[Tuple[float, float]],
        reduction: int,
        weight_count_loss: float = 1.0,
        count_loss: str = "mae",
        **kwargs: Any
    ) -> None:
        """Validate the configuration and build the component losses.

        Args:
            bins: Non-empty list of ``(low, high)`` pairs with ``low <= high``.
                Both ends are treated as inclusive when binning.
            reduction: Forwarded to ``_reshape_density`` and, for
                ``count_loss="dmcount"``, to ``DMLoss``.
                NOTE(review): presumably the spatial down-sampling factor of
                the model output -- confirm against ``_reshape_density``.
            weight_count_loss: Multiplier applied to the count loss before it
                is added to the cross-entropy term.
            count_loss: One of ``"mae"``, ``"mse"``, ``"dmcount"``
                (case-insensitive).
            **kwargs: Extra arguments forwarded to ``DMLoss``; must contain
                ``input_size`` when ``count_loss="dmcount"``.

        Raises:
            AssertionError: If ``bins`` is empty or malformed, ``count_loss``
                is not recognised, or ``input_size`` is missing for
                ``"dmcount"``.
        """
        super().__init__()
        assert len(bins) > 0, f"Expected at least one bin, got {bins}"
        # Generator expressions: no need to materialise a list just for all().
        assert all(len(b) == 2 for b in bins), f"Expected all bins to be of length 2, got {bins}"
        assert all(b[0] <= b[1] for b in bins), f"Expected all bins to be in increasing order, got {bins}"
        self.bins = bins
        self.reduction = reduction
        # Per-element losses; forward() performs the reduction itself.
        self.cross_entropy_fn = nn.CrossEntropyLoss(reduction="none")

        count_loss = count_loss.lower()
        assert count_loss in ["mae", "mse", "dmcount"], f"Expected count_loss to be one of ['mae', 'mse', 'dmcount'], got {count_loss}"
        self.count_loss = count_loss

        if self.count_loss == "mae":
            self.use_dm_loss = False
            self.count_loss_fn = nn.L1Loss(reduction="none")
        elif self.count_loss == "mse":
            self.use_dm_loss = False
            self.count_loss_fn = nn.MSELoss(reduction="none")
        else:
            self.use_dm_loss = True
            assert "input_size" in kwargs, f"Expected input_size to be in kwargs when count_loss='dmcount', got {kwargs}"
            self.count_loss_fn = DMLoss(reduction=reduction, **kwargs)

        self.weight_count_loss = weight_count_loss

    def _bin_count(self, density_map: Tensor) -> Tensor:
        """Map each density value to the index of the bin that contains it.

        Args:
            density_map: Density tensor whose dimension 1 is a singleton
                channel dimension.

        Returns:
            A ``torch.long`` tensor of bin indices with dimension 1 squeezed
            out.

        Note:
            Values that fall in no bin keep the initial class ``0``. When
            bins overlap, the bin listed later wins because bins are applied
            in order.
        """
        class_map = torch.zeros_like(density_map, dtype=torch.long)
        for idx, (low, high) in enumerate(self.bins):
            mask = (density_map >= low) & (density_map <= high)
            class_map[mask] = idx
        return class_map.squeeze(1)  # remove channel dimension

    def forward(
        self,
        pred_class: Tensor,
        pred_density: Tensor,
        target_density: Tensor,
        target_points: List[Tensor],
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the combined loss and a dict of detached components.

        Args:
            pred_class: Class logits, one channel per bin, scored against the
                binned target. Presumably ``(B, len(bins), H, W)`` -- confirm
                against the model.
            pred_density: Predicted density map.
            target_density: Ground-truth density map. It is passed through
                ``_reshape_density`` when its last two dimensions differ from
                those of ``pred_density``.
            target_points: Forwarded to ``DMLoss`` when
                ``count_loss="dmcount"``; unused otherwise.

        Returns:
            ``(loss, loss_info)`` where ``loss`` is
            ``ce_loss + weight_count_loss * count_loss`` and ``loss_info``
            holds detached tensors: ``"ce_loss"``, ``"loss"`` and either
            ``"<count_loss>_loss"`` or whatever ``DMLoss`` reports.

        Raises:
            AssertionError: If the (possibly reshaped) target density does
                not match the shape of ``pred_density``.
        """
        if target_density.shape[-2:] != pred_density.shape[-2:]:
            target_density = _reshape_density(target_density, reduction=self.reduction)
        assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}"

        target_class = self._bin_count(target_density)

        # Sum the per-pixel cross-entropy over the two spatial dimensions,
        # then average over what remains.
        cross_entropy_loss = self.cross_entropy_fn(pred_class, target_class).sum(dim=(-1, -2)).mean()

        if self.use_dm_loss:
            count_loss, dm_info = self.count_loss_fn(pred_density, target_density, target_points)
            # Copy so the dict owned by the count-loss module is not mutated.
            loss_info = dict(dm_info)
            loss_info["ce_loss"] = cross_entropy_loss.detach()
        else:
            # Sum the element-wise error over the last three dimensions,
            # then average over what remains.
            count_loss = self.count_loss_fn(pred_density, target_density).sum(dim=(-1, -2, -3)).mean()
            loss_info = {
                "ce_loss": cross_entropy_loss.detach(),
                f"{self.count_loss}_loss": count_loss.detach(),
            }

        loss = cross_entropy_loss + self.weight_count_loss * count_loss
        loss_info["loss"] = loss.detach()

        return loss, loss_info
|