import os
import sys
from dataclasses import dataclass, fields
from typing import Generic, TypeVar

import torch
import torch.nn.functional as F
import open3d as o3d
from einops import reduce
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss

# Extend sys.path so sibling packages resolve when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map

T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")


@dataclass
class LossNormalConsisCfg:
    normal_weight: float
    smooth_weight: float
    sigma_image: float | None
    use_second_derivative: bool
    detach: bool = False
    conf: bool = False
    not_use_valid_mask: bool = False


@dataclass
class LossNormalConsisCfgWrapper:
    normal_consis: LossNormalConsisCfg

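# A minimal sketch of how this loss might be configured; the values below are
# illustrative assumptions, not defaults shipped with this repository:
#
#   cfg = LossNormalConsisCfgWrapper(
#       normal_consis=LossNormalConsisCfg(
#           normal_weight=0.1,
#           smooth_weight=0.01,
#           sigma_image=None,
#           use_second_derivative=False,
#       )
#   )
#   loss_fn = LossNormalConsis(cfg)
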
class TVLoss(torch.nn.Module):
    """Total variation (TV) smoothness loss."""

    def __init__(self):
        super().__init__()

    def forward(self, pred):
        """
        Args:
            pred: [batch, H, W, 3]
        Returns:
            tv_loss: scalar mean absolute difference between neighboring
                pixels along the height and width axes.
        """
        w_diff = pred[..., :, :-1, :] - pred[..., :, 1:, :]  # neighbors along W
        h_diff = pred[..., :-1, :, :] - pred[..., 1:, :, :]  # neighbors along H
        return torch.mean(torch.abs(h_diff)) + torch.mean(torch.abs(w_diff))

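# Usage sketch (shapes are assumptions): TVLoss reduces a [batch, H, W, 3]
# map to a single scalar, so it can be weighted against the other loss terms.
#
#   tv = TVLoss()
#   normals = torch.randn(2, 64, 64, 3)
#   loss = tv(normals)  # scalar tensor
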
class LossNormalConsis(Loss[LossNormalConsisCfg, LossNormalConsisCfgWrapper]):
    def __init__(self, cfg: LossNormalConsisCfgWrapper) -> None:
        super().__init__(cfg)
        # Extract the configuration from the wrapper.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Build the per-pixel validity mask, optionally intersected with the
        # depth-confidence mask.
        conf_valid_mask = depth_dict["conf_valid_mask"].flatten(0, 1)
        valid_mask = batch["context"]["valid_mask"][:, batch["using_index"]].flatten(0, 1)
        if self.cfg.conf:
            valid_mask = valid_mask & conf_valid_mask
        if self.cfg.not_use_valid_mask:
            valid_mask = torch.ones_like(valid_mask)

        # Normals derived from the rendered depth and from the predicted depth.
        intrinsics = batch["context"]["intrinsics"].flatten(0, 1)
        render_normal = self.get_normal_map(prediction.depth.flatten(0, 1), intrinsics)
        pred_normal = self.get_normal_map(
            depth_dict["depth"].flatten(0, 1).squeeze(-1), intrinsics
        )
        if self.cfg.detach:
            pred_normal = pred_normal.detach()

        # Cosine and L1 consistency between the two normal maps, restricted to
        # valid pixels so that masked-out regions do not contribute.
        mask = valid_mask.bool()
        alpha1_loss = (1 - (render_normal * pred_normal).sum(-1))[mask].mean()
        alpha2_loss = (render_normal - pred_normal).abs().mean(-1)[mask].mean()
        normal_smooth_loss = TVLoss()(render_normal)
        normal_loss = (alpha1_loss + alpha2_loss) / 2
        return (
            self.cfg.normal_weight * torch.nan_to_num(normal_loss, nan=0.0)
            + self.cfg.smooth_weight * torch.nan_to_num(normal_smooth_loss, nan=0.0)
        )
    def get_normal_map(self, depth_map: torch.Tensor, intrinsic: torch.Tensor) -> torch.Tensor:
        """
        Compute per-pixel normal maps from depth maps by unprojecting to
        camera coordinates and crossing the image-space gradients.
        Args:
            depth_map (torch.Tensor): Depth maps of shape (B, H, W).
            intrinsic (torch.Tensor): Normalized camera intrinsics of shape
                (B, 3, 3), with focal lengths and principal point expressed
                as fractions of the image size.
        Returns:
            torch.Tensor: Normal maps of shape (B, H, W, 3); the one-pixel
                border is left as zeros.
        """
        B, H, W = depth_map.shape
        assert intrinsic.shape == (B, 3, 3), "Intrinsic matrix must be Bx3x3"
        assert (intrinsic[:, 0, 1] == 0).all() and (intrinsic[:, 1, 0] == 0).all(), "Intrinsic matrix must have zero skew"
        # Intrinsic parameters, rescaled from normalized units to pixels.
        fu = intrinsic[:, 0, 0] * W  # (B,)
        fv = intrinsic[:, 1, 1] * H  # (B,)
        cu = intrinsic[:, 0, 2] * W  # (B,)
        cv = intrinsic[:, 1, 2] * H  # (B,)

        # Generate a grid of pixel coordinates.
        u = torch.arange(W, device=depth_map.device)[None, None, :].expand(B, H, W)
        v = torch.arange(H, device=depth_map.device)[None, :, None].expand(B, H, W)

        # Unproject to camera coordinates (B, H, W).
        x_cam = (u - cu[:, None, None]) * depth_map / fu[:, None, None]
        y_cam = (v - cv[:, None, None]) * depth_map / fv[:, None, None]
        z_cam = depth_map

        # Stack to form camera coordinates (B, H, W, 3).
        cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).to(dtype=torch.float32)
        output = torch.zeros_like(cam_coords)

        # Central differences along the vertical (row) axis: (B, H-2, W-2, 3).
        dx = cam_coords[:, 2:, 1:-1] - cam_coords[:, :-2, 1:-1]
        # Central differences along the horizontal (column) axis: (B, H-2, W-2, 3).
        dy = cam_coords[:, 1:-1, 2:] - cam_coords[:, 1:-1, :-2]

        # Cross product and normalization: (B, H-2, W-2, 3).
        normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)

        # Write the interior normals into the output; the border stays zero.
        output[:, 1:-1, 1:-1, :] = normal_map
        return output
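
# A minimal self-check sketch, under two assumptions: intrinsics are
# normalized (focal lengths and principal point as fractions of the image
# size), and get_normal_map never touches `self`, so it can be called
# unbound. A constant-depth (fronto-parallel) plane should give interior
# normals of roughly (0, 0, -1) given the cross-product order above.
if __name__ == "__main__":
    B, H, W = 1, 32, 32
    depth = torch.ones(B, H, W)
    K = torch.tensor([[[1.0, 0.0, 0.5],
                       [0.0, 1.0, 0.5],
                       [0.0, 0.0, 1.0]]])
    normals = LossNormalConsis.get_normal_map(None, depth, K)
    print(normals.shape)               # torch.Size([1, 32, 32, 3])
    print(normals[0, H // 2, W // 2])  # ~ tensor([ 0.,  0., -1.])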