File size: 2,670 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from dataclasses import dataclass
import torch
from einops import reduce
from jaxtyping import Float
from torch import Tensor
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
from typing import Generic, TypeVar
from dataclasses import fields
import torch.nn.functional as F
import sys
from pytorch3d.loss import chamfer_distance
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map
# Generic type variables, presumably for a loss config and its wrapper dataclass.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")
@dataclass
class LossChamferDistanceCfg:
weight: float
down_sample_ratio: float
sigma_image: float | None
@dataclass()
class LossChamferDistanceCfgWrapper:
    """Single-field wrapper around ``LossChamferDistanceCfg``.

    ``LossChamferDistance.__init__`` unpacks the fields of this dataclass into
    exactly one name and uses that field's name as the loss name.
    """

    # Must remain the only field: the loss unpacks ``fields(...)`` into one value.
    chamfer_distance: LossChamferDistanceCfg
class LossChamferDistance(Loss[LossChamferDistanceCfg, LossChamferDistanceCfgWrapper]):
    """Chamfer-distance loss between confident predicted points and Gaussian means.

    For each batch element, the predicted points selected by the confidence mask
    are compared with the Gaussian means whose last coordinate has magnitude below
    ``gaussian_coord_limit``. Both clouds are randomly subsampled with
    ``cfg.down_sample_ratio`` before ``pytorch3d.loss.chamfer_distance`` is applied.
    """

    # Gaussian means whose last coordinate (presumably depth — confirm) has an
    # absolute value at or above this limit are treated as outliers and ignored.
    gaussian_coord_limit: float = 1e2

    def __init__(self, cfg: LossChamferDistanceCfgWrapper) -> None:
        super().__init__(cfg)
        # The wrapper dataclass has exactly one field: its name is used as the
        # loss name and its value is the actual loss config.
        (wrapper_field,) = fields(type(cfg))
        self.cfg = getattr(cfg, wrapper_field.name)
        self.name = wrapper_field.name

    @staticmethod
    def _random_subsample(points: Tensor, ratio: float) -> Tensor:
        """Return ``int(len(points) * ratio)`` randomly chosen rows of ``points``."""
        count = points.shape[0]
        # Build the permutation on the points' device to avoid a host->device copy.
        keep = torch.randperm(count, device=points.device)[: int(count * ratio)]
        return points[keep]

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted, batch-averaged Chamfer distance.

        Args:
            prediction: Unused; kept for the shared loss interface.
            batch: Unused; kept for the shared loss interface.
            gaussians: Its ``means`` are indexed per batch element and by last
                coordinate (assumed ``(b, N, 3)`` — TODO confirm).
            depth_dict: Must provide ``["distill_infos"]["pts_all"]`` with shape
                ``(b, v, h, w, C)`` and ``["distill_infos"]["conf_mask"]``, a
                boolean mask reshapeable to ``(b, v, h, w)``.
            global_step: Unused; kept for the shared loss interface.

        Returns:
            ``cfg.weight`` times the sum of per-element Chamfer distances divided
            by ``b``. Batch elements where either subsampled cloud is empty
            contribute zero; NaNs are replaced by zero (and infinities clamped).
        """
        distill_infos = depth_dict["distill_infos"]
        pred_pts = distill_infos["pts_all"]
        b, v, h, w, _ = pred_pts.shape
        conf_mask = distill_infos["conf_mask"].reshape(b, v, h, w)
        gaussian_means = gaussians.means
        gaussian_mask = gaussian_means[..., -1].abs() < self.gaussian_coord_limit
        ratio = self.cfg.down_sample_ratio

        # Tensor accumulator (not a Python float) so the result is always a tensor
        # on the right device, even when every element is skipped.
        total = pred_pts.new_zeros(())
        for b_idx in range(b):
            batch_pts = self._random_subsample(pred_pts[b_idx][conf_mask[b_idx]], ratio)
            batch_gaussians = self._random_subsample(
                gaussian_means[b_idx][gaussian_mask[b_idx]], ratio
            )
            # An empty cloud can make Chamfer's mean reduction NaN (0/0). Skip the
            # element so it cannot zero out the loss of the whole batch.
            if batch_pts.shape[0] == 0 or batch_gaussians.shape[0] == 0:
                continue
            cd_loss, _ = chamfer_distance(
                batch_pts.unsqueeze(0), batch_gaussians.unsqueeze(0)
            )
            total = total + torch.nan_to_num(cd_loss, nan=0.0)
        # Divide by the full batch size (as before) to keep the loss scale;
        # max() guards against an empty batch.
        return self.cfg.weight * torch.nan_to_num(total / max(b, 1), nan=0.0)