|
from dataclasses import dataclass |
|
|
|
import torch |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from lpips import LPIPS |
|
from torch import Tensor |
|
|
|
from src.dataset.types import BatchedExample |
|
from src.misc.nn_module_tools import convert_to_buffer |
|
from src.model.decoder.decoder import DecoderOutput |
|
from src.model.types import Gaussians |
|
from .loss import Loss |
|
|
|
|
|
@dataclass
class LossLpipsCfg:
    """Configuration for the LPIPS perceptual loss term."""

    # Multiplier applied to the mean LPIPS score when computing the final loss.
    weight: float
    # The loss returns 0 until the global training step reaches this value.
    apply_after_step: int
    # The three flags below choose an optional mask applied to both the
    # prediction and the reference image before LPIPS is computed.
    # forward() checks them with priority: mask > alpha > conf.
    conf: bool = False  # mask with depth_dict['conf_valid_mask']
    alpha: bool = False  # mask with the rendered alpha map
    mask: bool = False  # mask with batch["context"]["valid_mask"]
|
|
|
|
|
@dataclass
class LossLpipsCfgWrapper:
    """Wrapper keying the LPIPS config under the name "lpips".

    Matches the generic Loss[..., CfgWrapper] pattern, where the wrapper's
    single field name identifies the loss in the overall config tree.
    """

    lpips: LossLpipsCfg
|
|
|
|
|
class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    """LPIPS perceptual loss between rendered and context images,
    optionally restricted to a validity / alpha / confidence mask."""

    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        """Build the VGG-backed LPIPS network and freeze its weights."""
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # Store the LPIPS weights as non-persistent buffers: they follow
        # device moves but are neither trained nor written to checkpoints.
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return cfg.weight * mean LPIPS over all (batch, view) pairs.

        Returns a zero scalar before cfg.apply_after_step. NaN/inf scores
        are clamped to 0 so a degenerate render cannot poison training.
        """
        # NOTE(review): supervision targets the *context* views — confirm
        # this is intended rather than batch["target"].
        # Images arrive in [-1, 1]; map to [0, 1] for LPIPS(normalize=True).
        reference = (batch["context"]["image"] + 1) / 2

        # LPIPS stays disabled during early training.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0, dtype=torch.float32, device=reference.device)

        rendered = prediction.color
        if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
            # Mask source priority: valid_mask > alpha > confidence.
            if self.cfg.mask:
                weight_map = batch["context"]["valid_mask"]
            elif self.cfg.alpha:
                weight_map = prediction.alpha
            elif self.cfg.conf:
                # NOTE(review): raises TypeError if depth_dict is None while
                # cfg.conf is set — caller must supply it in that mode.
                weight_map = depth_dict['conf_valid_mask']
            channels = rendered.shape[2]
            # Replicate the (b, v, h, w) map across the channel axis so it
            # multiplies elementwise with the (b, v, c, h, w) images.
            weight_map = weight_map.unsqueeze(2).expand(-1, -1, channels, -1, -1)
            rendered = rendered * weight_map
            reference = reference * weight_map

        # LPIPS expects 4-D input, so fold batch and view into one axis.
        score = self.lpips.forward(
            rearrange(rendered, "b v c h w -> (b v) c h w"),
            rearrange(reference, "b v c h w -> (b v) c h w"),
            normalize=True,
        )
        # Clamp non-finite values so one bad render does not derail training.
        return self.cfg.weight * torch.nan_to_num(score.mean(), nan=0.0, posinf=0.0, neginf=0.0)
|
|