# src/loss/loss_depth.py
from dataclasses import dataclass, fields
import os
import sys

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss

# Make the repository root importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.loss.depth_anything.dpt import DepthAnything


@dataclass
class LossDepthCfg:
    weight: float
    sigma_image: float | None
    use_second_derivative: bool


@dataclass
class LossDepthCfgWrapper:
    depth: LossDepthCfg


class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]):
    def __init__(self, cfg: LossDepthCfgWrapper) -> None:
        super().__init__(cfg)

        # Extract the configuration from the wrapper.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

        # DPT decoder settings for each available DepthAnything backbone.
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        encoder = 'vits'  # or 'vitb', 'vitl'

        # DepthAnything serves as a frozen monocular-disparity teacher.
        depth_anything = DepthAnything(model_configs[encoder])
        depth_anything.load_state_dict(torch.load(f'src/loss/depth_anything/depth_anything_{encoder}14.pth'))
        depth_anything.eval()  # keep the teacher in inference mode
        self.depth_anything = depth_anything
        for param in self.depth_anything.parameters():
            param.requires_grad = False
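
        # Expected teacher I/O, as used below (a note added here, not from
        # the original file): images of shape (N, 3, H, W) go in and a
        # relative disparity map of shape (N, H, W) comes out, e.g.
        #   disp = self.depth_anything(imgs)  # (N, H, W) relative disparity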

    def disp_rescale(self, disp: Float[Tensor, "B H W"]):
        # Normalize disparity per image so comparisons are invariant to scale
        # and shift: subtract the median, divide by the mean absolute deviation.
        disp = disp.flatten(1, 2)  # (B, H*W)
        disp_median = torch.median(disp, dim=-1, keepdim=True)[0]  # (B, 1)
        disp_var = (disp - disp_median).abs().mean(dim=-1, keepdim=True)  # (B, 1)
        disp = (disp - disp_median) / (disp_var + 1e-6)
        return disp
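
    # Worked example for disp_rescale (an illustration added here, not from
    # the original file): for a single 2x2 map [[1, 2], [3, 4]], torch.median
    # picks the lower middle value 2, the mean absolute deviation is
    # (1 + 0 + 1 + 2) / 4 = 1, so the result is roughly [-1, 0, 1, 2]
    # (up to the 1e-6 epsilon):
    #
    #   disp = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # (B=1, H=2, W=2)
    #   self.disp_rescale(disp)  # ~tensor([[-1., 0., 1., 2.]])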

    def smooth_l1_loss(self, pred, target, beta=1.0, reduction='none'):
        # Elementwise smooth-L1 (Huber) loss: quadratic for |diff| < beta,
        # linear beyond, matching F.smooth_l1_loss but defaulting to 'none'.
        diff = pred - target
        abs_diff = torch.abs(diff)
        loss = torch.where(abs_diff < beta, 0.5 * diff ** 2 / beta, abs_diff - 0.5 * beta)
        if reduction == 'mean':
            return loss.mean()
        elif reduction == 'sum':
            return loss.sum()
        elif reduction == 'none':
            return loss
        else:
            raise ValueError("Invalid reduction type. Choose from 'mean', 'sum', or 'none'.")
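
    # Quick check for smooth_l1_loss (an illustration added here, not from
    # the original file): with beta=1.0 a residual of 0.5 falls in the
    # quadratic region, 0.5 * 0.5**2 = 0.125, while a residual of 2.0 is in
    # the linear region, 2.0 - 0.5 = 1.5; both agree with
    # F.smooth_l1_loss(pred, target, beta=1.0, reduction='none'):
    #
    #   pred, target = torch.tensor([0.5, 2.0]), torch.zeros(2)
    #   self.smooth_l1_loss(pred, target)  # tensor([0.1250, 1.5000])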

    def ctx_depth_loss(self,
                       depth_map: Float[Tensor, "B V H W 1"],
                       depth_conf: Float[Tensor, "B V H W"],
                       batch: BatchedExample,
                       cxt_depth_weight: float = 0.01,
                       alpha: float = 0.2):
        B, V, _, H, W = batch["context"]["image"].shape
        ctx_imgs = batch["context"]["image"].view(B * V, 3, H, W).float()

        # Teacher disparity from DepthAnything, normalized per image.
        da_output = self.depth_anything(ctx_imgs)
        da_output = self.disp_rescale(da_output)

        # Predicted disparity (inverse depth), normalized the same way.
        disp_context = 1.0 / depth_map.flatten(0, 1).squeeze(-1).clamp(1e-3)  # (B*V, H, W)
        context_output = self.disp_rescale(disp_context)

        depth_conf = depth_conf.flatten(0, 1).flatten(1, 2)  # (B*V, H*W)

        # Confidence-weighted smooth-L1 with a log-confidence regularizer.
        loss = self.smooth_l1_loss(
            context_output * depth_conf, da_output * depth_conf, reduction='none'
        ) - alpha * torch.log(depth_conf)
        return cxt_depth_weight * loss.mean()
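
    # Note on the objective above (added here, not from the original file):
    # scaling both normalized disparities by depth_conf scales the residual
    # by the per-pixel confidence before the smooth-L1, while the
    # -alpha * log(depth_conf) term penalizes driving confidences toward
    # zero, so the network can only down-weight pixels it genuinely cannot
    # fit, in the spirit of confidence-weighted losses such as DUSt3R's.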

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Supervise the rendered depth with DepthAnything disparities on the
        # target views, comparing both after per-image normalization.
        target_imgs = batch["target"]["image"]
        B, V, _, H, W = target_imgs.shape
        target_imgs = target_imgs.view(B * V, 3, H, W)

        da_output = self.depth_anything(target_imgs.float())
        da_output = self.disp_rescale(da_output)

        # Rendered disparity (inverse depth), clamped to avoid division by ~0.
        disp_gs = 1.0 / prediction.depth.flatten(0, 1).clamp(1e-3).float()
        gs_output = self.disp_rescale(disp_gs)

        # nan_to_num guards against degenerate frames whose deviation is ~0.
        return self.cfg.weight * torch.nan_to_num(F.smooth_l1_loss(da_output, gs_output), nan=0.0)
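

if __name__ == "__main__":
    # Minimal smoke test (a sketch added here, not part of the original
    # file): it exercises only the stateless helpers, since constructing
    # LossDepth requires the DepthAnything checkpoint on disk; the inputs
    # are made-up toy tensors.
    disp = torch.rand(2, 8, 8) + 0.1  # (B, H, W) toy disparity
    normalized = LossDepth.disp_rescale(None, disp)  # self is unused by the helper
    print(normalized.shape)  # torch.Size([2, 64])

    pred, target = torch.tensor([0.5, 2.0]), torch.zeros(2)
    print(LossDepth.smooth_l1_loss(None, pred, target))  # tensor([0.1250, 1.5000])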