from dataclasses import dataclass, fields
from typing import Generic, TypeVar
import os
import sys

import torch
import torch.nn.functional as F
from einops import reduce
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.misc.utils import vis_depth_map
# Assumption: the Depth Anything (DPT) implementation is vendored under
# src/loss/depth_anything/, consistent with the checkpoint path loaded in __init__.
from .depth_anything.dpt import DepthAnything

T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")


@dataclass
class LossDepthCfg:
    weight: float
    sigma_image: float | None
    use_second_derivative: bool


@dataclass
class LossDepthCfgWrapper:
    depth: LossDepthCfg


class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]):
    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)

        # Unwrap the single config field and reuse its name as the loss name.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

        # Frozen Depth Anything model that provides monocular disparity targets.
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        encoder = 'vits'
        depth_anything = DepthAnything(model_configs[encoder])
        depth_anything.load_state_dict(
            torch.load(f'src/loss/depth_anything/depth_anything_{encoder}14.pth')
        )

        self.depth_anything = depth_anything
        for param in self.depth_anything.parameters():
            param.requires_grad = False

    def disp_rescale(self, disp: Float[Tensor, "B H W"]):
        # Normalize disparity per image so the loss is invariant to scale and shift:
        # subtract the median, then divide by the mean absolute deviation.
        disp = disp.flatten(1, 2)
        disp_median = torch.median(disp, dim=-1, keepdim=True)[0]
        disp_var = (disp - disp_median).abs().mean(dim=-1, keepdim=True)
        disp = (disp - disp_median) / (disp_var + 1e-6)
        return disp

    def smooth_l1_loss(self, pred, target, beta=1.0, reduction='none'):
        diff = pred - target
        abs_diff = torch.abs(diff)

        loss = torch.where(abs_diff < beta, 0.5 * diff ** 2 / beta, abs_diff - 0.5 * beta)

        if reduction == 'mean':
            return loss.mean()
        elif reduction == 'sum':
            return loss.sum()
        elif reduction == 'none':
            return loss
        else:
            raise ValueError("Invalid reduction type. Choose from 'mean', 'sum', or 'none'.")
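
    # Quick sanity check of the elementwise formula above (with beta = 1.0):
    # a residual of 0.5 falls in the quadratic branch, 0.5 * 0.5**2 / 1.0 = 0.125,
    # while a residual of 2.0 falls in the linear branch, 2.0 - 0.5 * 1.0 = 1.5.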

    def ctx_depth_loss(self,
                       depth_map: Float[Tensor, "B V H W C"],
                       depth_conf: Float[Tensor, "B V H W"],
                       batch: BatchedExample,
                       cxt_depth_weight: float = 0.01,
                       alpha: float = 0.2):
        B, V, _, H, W = batch["context"]["image"].shape
        ctx_imgs = batch["context"]["image"].view(B * V, 3, H, W).float()
        da_output = self.depth_anything(ctx_imgs)
        da_output = self.disp_rescale(da_output)

        # Convert the predicted context depth to disparity and normalize it the same way.
        disp_context = 1.0 / depth_map.flatten(0, 1).squeeze(-1).clamp(1e-3)
        context_output = self.disp_rescale(disp_context)

        depth_conf = depth_conf.flatten(0, 1).flatten(1, 2)

        # Confidence-weighted smooth-L1 between predicted and Depth Anything disparity,
        # with a log-confidence term that penalizes collapsing the confidences to zero.
        loss = self.smooth_l1_loss(
            context_output * depth_conf, da_output * depth_conf, reduction='none'
        )
        return cxt_depth_weight * (loss - alpha * torch.log(depth_conf)).mean()

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Depth Anything disparity for the target views (frozen teacher).
        target_imgs = batch["target"]["image"]
        B, V, _, H, W = target_imgs.shape
        target_imgs = target_imgs.view(B * V, 3, H, W)
        da_output = self.depth_anything(target_imgs.float())
        da_output = self.disp_rescale(da_output)

        # Rendered depth converted to normalized disparity.
        disp_gs = 1.0 / prediction.depth.flatten(0, 1).clamp(1e-3).float()
        gs_output = self.disp_rescale(disp_gs)

        return self.cfg.weight * torch.nan_to_num(F.smooth_l1_loss(da_output, gs_output), nan=0.0)
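

# Usage sketch (assumes, as with the other losses in this repo, that Loss subclasses
# torch.nn.Module so the instance is callable); the config values are illustrative only:
#
#   cfg = LossDepthCfgWrapper(depth=LossDepthCfg(weight=0.05, sigma_image=None,
#                                                use_second_derivative=False))
#   depth_loss = LossDepth(cfg)
#   value = depth_loss(prediction, batch, gaussians, global_step)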