# AnySplat/src/loss/loss_lpips.py
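"""Masked LPIPS photometric loss for AnySplat.

Computes a perceptual (VGG-based LPIPS) distance between rendered views and the
context images, optionally restricted to valid pixels via an explicit validity
mask, the rendered alpha, or a depth-confidence mask, and only enabled after a
configurable warm-up step.
"""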
from dataclasses import dataclass

import torch
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss


@dataclass
class LossLpipsCfg:
    weight: float
    apply_after_step: int
    conf: bool = False
    alpha: bool = False
    mask: bool = False


@dataclass
class LossLpipsCfgWrapper:
    lpips: LossLpipsCfg


class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        super().__init__(cfg)
        self.lpips = LPIPS(net="vgg")

        # Register the LPIPS network's weights as non-persistent buffers so they
        # follow the module across devices but are neither checkpointed nor trained.
        convert_to_buffer(self.lpips, persistent=False)
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Map the context images from [-1, 1] back to [0, 1] so they match the
        # rendered colors (LPIPS with normalize=True expects inputs in [0, 1]).
        image = (batch["context"]["image"] + 1) / 2

        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0, dtype=torch.float32, device=image.device)

        if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
            # Select the masking source in priority order: explicit validity mask,
            # rendered alpha, then the confidence mask from the depth prediction.
            if self.cfg.mask:
                mask = batch["context"]["valid_mask"]
            elif self.cfg.alpha:
                mask = prediction.alpha
            elif self.cfg.conf:
                mask = depth_dict["conf_valid_mask"]

            # Broadcast the (b, v, h, w) mask across the channel dimension and zero
            # out invalid pixels in both the rendering and the target images.
            b, v, c, h, w = prediction.color.shape
            expanded_mask = mask.unsqueeze(2).expand(-1, -1, c, -1, -1)
            masked_pred = prediction.color * expanded_mask
            masked_img = image * expanded_mask
            loss = self.lpips.forward(
                rearrange(masked_pred, "b v c h w -> (b v) c h w"),
                rearrange(masked_img, "b v c h w -> (b v) c h w"),
                normalize=True,
            )
        else:
            # No masking: compare the full rendered views against the context images.
            loss = self.lpips.forward(
                rearrange(prediction.color, "b v c h w -> (b v) c h w"),
                rearrange(image, "b v c h w -> (b v) c h w"),
                normalize=True,
            )

        # Guard against NaN/Inf from degenerate renders before applying the weight.
        return self.cfg.weight * torch.nan_to_num(loss.mean(), nan=0.0, posinf=0.0, neginf=0.0)
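

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it assumes the
    # AnySplat package is importable and that DecoderOutput exposes `color` of
    # shape (b, v, 3, h, w) and `alpha` of shape (b, v, h, w); a SimpleNamespace
    # stands in for the real decoder output, and `gaussians` is unused by this loss.
    from types import SimpleNamespace

    b, v, h, w = 1, 2, 64, 64
    prediction = SimpleNamespace(
        color=torch.rand(b, v, 3, h, w),
        alpha=torch.ones(b, v, h, w),
    )
    batch = {
        "context": {
            "image": torch.rand(b, v, 3, h, w) * 2 - 1,  # context images stored in [-1, 1]
            "valid_mask": torch.ones(b, v, h, w),
        }
    }

    cfg = LossLpipsCfgWrapper(lpips=LossLpipsCfg(weight=0.05, apply_after_step=0))
    loss_fn = LossLpips(cfg)
    print(loss_fn.forward(prediction, batch, gaussians=None, depth_dict=None, global_step=1))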