|
from dataclasses import dataclass |
|
|
|
import torch |
|
from einops import reduce |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from src.dataset.types import BatchedExample |
|
from src.model.decoder.decoder import DecoderOutput |
|
from src.model.types import Gaussians |
|
from .loss import Loss |
|
from typing import Generic, TypeVar |
|
from dataclasses import fields |
|
import torch.nn.functional as F |
|
import sys |
|
import os |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
from src.misc.utils import vis_depth_map |
|
import open3d as o3d |
|
# Generic type parameters: T_wrapper annotates the wrapper config passed to
# LossNormalConsis.__init__; T_cfg is not referenced in the visible code.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")
|
|
|
@dataclass |
|
class LossNormalConsisCfg: |
|
normal_weight: float |
|
smooth_weight: float |
|
sigma_image: float | None |
|
use_second_derivative: bool |
|
detach: bool = False |
|
conf: bool = False |
|
not_use_valid_mask: bool = False |
|
|
|
@dataclass
class LossNormalConsisCfgWrapper:
    """One-field wrapper around LossNormalConsisCfg.

    LossNormalConsis.__init__ unwraps this single field and uses the field's
    name ("normal_consis") as the loss's ``name``.
    """

    normal_consis: LossNormalConsisCfg
|
|
|
class TVLoss(torch.nn.Module):
    """Total-variation (TV) smoothness loss over the two spatial axes.

    The input is laid out as ``[..., H, W, C]``. Absolute differences between
    neighbouring pixels are taken along W (dim -2) and along H (dim -3).
    """

    def forward(self, pred):
        """Return the mean absolute neighbour difference along W plus along H.

        Args:
            pred: Tensor of shape ``[..., H, W, C]``, e.g. ``[batch, H, W, 3]``.

        Returns:
            A scalar (0-dim) tensor. Each of the two terms is averaged over every
            element, including the batch dimension. This is not a per-batch
            ``[batch]`` tensor.
        """
        # Differences between horizontally adjacent pixels (along W, dim -2).
        diff_w = pred[..., :, :-1, :] - pred[..., :, 1:, :]
        # Differences between vertically adjacent pixels (along H, dim -3).
        diff_h = pred[..., :-1, :, :] - pred[..., 1:, :, :]
        return diff_w.abs().mean() + diff_h.abs().mean()
|
|
|
|
|
class LossNormalConsis(Loss[LossNormalConsisCfg, LossNormalConsisCfgWrapper]):
    """Normal-consistency loss between rendered depth and predicted depth.

    Both depth maps are converted to per-pixel surface normals with
    ``get_normal_map``. The loss averages a cosine term and an L1 term between
    the two normal maps, and adds a total-variation smoothness term on the
    rendered normals.
    """

    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)

        # Unwrap the single-field wrapper dataclass. ``self.cfg`` becomes the
        # inner config and ``self.name`` the wrapper's field name.
        # NOTE(review): the base Loss.__init__ may already do this; that cannot
        # be confirmed from this file.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted normal-consistency + smoothness loss.

        Args:
            prediction: Decoder output. ``prediction.depth.flatten(0, 1)`` must
                be 3-D ``(N, H, W)``; presumably the original is
                ``(batch, views, H, W)`` (TODO confirm).
            batch: Batched example. Reads ``batch["context"]["intrinsics"]``,
                ``batch["context"]["valid_mask"]`` and ``batch["using_index"]``.
            gaussians: Not used by this loss.
            depth_dict: Reads ``"depth"`` (flattened over the first two dims,
                then trailing singleton squeezed) and ``"conf_valid_mask"``.
            global_step: Not used by this loss.

        Returns:
            Scalar tensor:
            ``normal_weight * normal_loss + smooth_weight * tv_loss``. NaNs in
            either term are replaced by 0.
        """
        # NOTE(review): ``valid_mask`` (and therefore the ``conf`` and
        # ``not_use_valid_mask`` options) is computed here but is never
        # applied to any loss term below. Every pixel contributes, including
        # the zero-filled one-pixel border of the normal maps. Confirm whether
        # masking was intended before relying on these options.
        conf_valid_mask = depth_dict["conf_valid_mask"].flatten(0, 1)
        valid_mask = batch["context"]["valid_mask"][:, batch["using_index"]].flatten(0, 1)
        if self.cfg.conf:
            valid_mask = valid_mask & conf_valid_mask
        if self.cfg.not_use_valid_mask:
            # ones_like already keeps the source tensor's device and dtype.
            valid_mask = torch.ones_like(valid_mask)

        intrinsics = batch["context"]["intrinsics"].flatten(0, 1)
        render_normal = self.get_normal_map(prediction.depth.flatten(0, 1), intrinsics)
        pred_normal = self.get_normal_map(
            depth_dict["depth"].flatten(0, 1).squeeze(-1), intrinsics
        )
        if self.cfg.detach:
            # Stop gradients through the normals derived from the predicted depth.
            pred_normal = pred_normal.detach()

        # 1 - cos(angle) between unit normals, averaged over all pixels.
        cos_loss = (1 - (render_normal * pred_normal).sum(-1)).mean()
        l1_loss = F.l1_loss(render_normal, pred_normal, reduction="mean")
        normal_loss = (cos_loss + l1_loss) / 2

        normal_smooth_loss = TVLoss()(render_normal)

        return (
            self.cfg.normal_weight * torch.nan_to_num(normal_loss, nan=0.0)
            + self.cfg.smooth_weight * torch.nan_to_num(normal_smooth_loss, nan=0.0)
        )

    def get_normal_map(self, depth_map: torch.Tensor, intrinsic: torch.Tensor) -> torch.Tensor:
        """Compute per-pixel surface normals from a batch of depth maps.

        Each depth map is back-projected to camera-space points with a pinhole
        model. The normal at each interior pixel is the normalised cross product
        of the central differences along H (rows) and along W (columns).

        Args:
            depth_map: Depth maps of shape ``(B, H, W)``.
            intrinsic: Intrinsic matrices of shape ``(B, 3, 3)`` with zero skew.
                Focal lengths and principal point are multiplied by ``W``/``H``
                below, so they are expected in image-size-normalised units.

        Returns:
            Normal map of shape ``(B, H, W, 3)`` in float32. The one-pixel
            image border is left as zeros.

        Raises:
            AssertionError: If ``intrinsic`` is not ``(B, 3, 3)`` or has
                non-zero skew terms.
        """
        B, H, W = depth_map.shape
        assert intrinsic.shape == (B, 3, 3), "Intrinsic matrix must be Bx3x3"
        assert (intrinsic[:, 0, 1] == 0).all() and (intrinsic[:, 1, 0] == 0).all(), "Intrinsic matrix must have zero skew"

        # Normalised intrinsics -> pixel units.
        fu = intrinsic[:, 0, 0] * W
        fv = intrinsic[:, 1, 1] * H
        cu = intrinsic[:, 0, 2] * W
        cv = intrinsic[:, 1, 2] * H

        # Integer pixel grid: u indexes columns (W), v indexes rows (H).
        u = torch.arange(W, device=depth_map.device)[None, None, :].expand(B, H, W)
        v = torch.arange(H, device=depth_map.device)[None, :, None].expand(B, H, W)

        # Pinhole back-projection to camera coordinates.
        x_cam = (u - cu[:, None, None]) * depth_map / fu[:, None, None]
        y_cam = (v - cv[:, None, None]) * depth_map / fv[:, None, None]
        z_cam = depth_map
        cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).to(dtype=torch.float32)

        # Central differences on interior pixels, along rows (H) and columns (W).
        d_rows = cam_coords[:, 2:, 1:-1] - cam_coords[:, :-2, 1:-1]
        d_cols = cam_coords[:, 1:-1, 2:] - cam_coords[:, 1:-1, :-2]
        normal_map = F.normalize(torch.cross(d_rows, d_cols, dim=-1), dim=-1)

        # Border pixels have no central difference and stay zero.
        output = torch.zeros_like(cam_coords)
        output[:, 1:-1, 1:-1, :] = normal_map
        return output