File size: 2,670 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from dataclasses import dataclass

import torch
from einops import reduce
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
from typing import Generic, TypeVar
from dataclasses import fields
import torch.nn.functional as F
import sys
from pytorch3d.loss import chamfer_distance
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map

# Module-level type variables named for a loss-config type and a
# config-wrapper type respectively.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")


@dataclass
class LossChamferDistanceCfg:
    weight: float
    down_sample_ratio: float
    sigma_image: float | None


@dataclass
class LossChamferDistanceCfgWrapper:
    """Single-field wrapper around :class:`LossChamferDistanceCfg`.

    The field name (``chamfer_distance``) doubles as the loss's identifier.
    """

    # The wrapped configuration for the Chamfer-distance loss.
    chamfer_distance: LossChamferDistanceCfg

class LossChamferDistance(Loss[LossChamferDistanceCfg, LossChamferDistanceCfgWrapper]):
    """Weighted Chamfer distance between predicted points and Gaussian means.

    For every batch element, the confidence-masked points in
    ``depth_dict['distill_infos']['pts_all']`` are compared with the Gaussian
    means in ``gaussians.means`` using pytorch3d's ``chamfer_distance``. Both
    clouds are randomly subsampled by ``cfg.down_sample_ratio`` first.
    """

    # Gaussian means whose last coordinate has an absolute value at or above
    # this threshold are excluded from the loss.
    # NOTE(review): presumably filters degenerate / very distant Gaussians — confirm.
    _MAX_ABS_LAST_COORD = 1e2

    def __init__(self, cfg: LossChamferDistanceCfgWrapper) -> None:
        super().__init__(cfg)

        # The wrapper dataclass has exactly one field: its value is the actual
        # config and its name is used as this loss's name.
        (cfg_field,) = fields(type(cfg))
        self.cfg = getattr(cfg, cfg_field.name)
        self.name = cfg_field.name

    @staticmethod
    def _random_subsample(points: Tensor, ratio: float) -> Tensor:
        """Return ``int(len(points) * ratio)`` randomly chosen rows of ``points``."""
        num_points = points.shape[0]
        num_keep = int(num_points * ratio)
        # Build the permutation on the points' device so indexing needs no
        # host-to-device copy of the index tensor.
        keep = torch.randperm(num_points, device=points.device)[:num_keep]
        return points[keep]

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted Chamfer distance, averaged over the batch.

        Args:
            prediction: Not used by this loss.
            batch: Not used by this loss.
            gaussians: ``gaussians.means`` is indexed per batch element; each
                element's (filtered) means form one point cloud.
            depth_dict: Must provide ``['distill_infos']['pts_all']`` (a rank-5
                tensor unpacked as ``(b, v, h, w, C)``) and
                ``['distill_infos']['conf_mask']`` (reshaped to ``(b, v, h, w)``
                and used as a boolean selection mask).
            global_step: Not used by this loss.

        Returns:
            ``cfg.weight`` times the sum of per-element Chamfer distances
            divided by the batch size. Elements whose (subsampled) clouds are
            empty, or whose distance is NaN, contribute zero.
        """
        distill_infos = depth_dict['distill_infos']
        pred_pts = distill_infos['pts_all']
        b, v, h, w, _ = pred_pts.shape
        conf_mask = distill_infos['conf_mask'].reshape(b, v, h, w)
        gaussian_means = gaussians.means

        # Drop Gaussians whose last coordinate is out of range.
        pts_mask = torch.abs(gaussian_means[..., -1]) < self._MAX_ABS_LAST_COORD

        # Tensor accumulator (not a Python float) so the result is always a
        # tensor, even if every element is skipped below.
        cd_losses = gaussian_means.new_zeros(())
        for b_idx in range(b):
            # Boolean masking flattens the (v, h, w) grid into an (N, C) list.
            batch_pts = pred_pts[b_idx][conf_mask[b_idx]]
            batch_gaussian = gaussian_means[b_idx][pts_mask[b_idx]]
            batch_pts = self._random_subsample(batch_pts, self.cfg.down_sample_ratio)
            batch_gaussian = self._random_subsample(batch_gaussian, self.cfg.down_sample_ratio)

            # An empty cloud has no nearest neighbours; chamfer_distance would
            # yield NaN (or fail) for it, so treat it as contributing zero.
            if batch_pts.shape[0] == 0 or batch_gaussian.shape[0] == 0:
                continue

            # chamfer_distance expects (N, P, D) batches and returns
            # (loss, loss_normals); only the distance term is used.
            cd_loss = chamfer_distance(batch_pts.unsqueeze(0), batch_gaussian.unsqueeze(0))[0]
            # Sanitize per element so a single NaN does not wipe out the
            # contributions of the other batch elements.
            cd_losses = cd_losses + torch.nan_to_num(cd_loss, nan=0.0)

        return self.cfg.weight * (cd_losses / b)