from dataclasses import dataclass

import torch
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss
@dataclass
class LossLpipsCfg:
    weight: float
    apply_after_step: int
    conf: bool = False
    alpha: bool = False
    mask: bool = False


@dataclass
class LossLpipsCfgWrapper:
    lpips: LossLpipsCfg
class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # Convert the LPIPS parameters to non-persistent buffers so they follow the
        # module's device/dtype but are neither trained nor saved in checkpoints.
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Map the context images from [-1, 1] to [0, 1]; LPIPS is called with
        # normalize=True, which expects inputs in [0, 1].
        image = (batch["context"]["image"] + 1) / 2

        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0, dtype=torch.float32, device=image.device)

        if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
            # Pick the mask source: an explicit validity mask, the rendered alpha,
            # or the depth-confidence mask.
            if self.cfg.mask:
                mask = batch["context"]["valid_mask"]
            elif self.cfg.alpha:
                mask = prediction.alpha
            elif self.cfg.conf:
                mask = depth_dict["conf_valid_mask"]

            # Broadcast the (b, v, h, w) mask across the channel dimension and zero
            # out invalid pixels in both the prediction and the ground truth.
            b, v, c, h, w = prediction.color.shape
            expanded_mask = mask.unsqueeze(2).expand(-1, -1, c, -1, -1)
            masked_pred = prediction.color * expanded_mask
            masked_img = image * expanded_mask

            loss = self.lpips.forward(
                rearrange(masked_pred, "b v c h w -> (b v) c h w"),
                rearrange(masked_img, "b v c h w -> (b v) c h w"),
                normalize=True,
            )
        else:
            loss = self.lpips.forward(
                rearrange(prediction.color, "b v c h w -> (b v) c h w"),
                rearrange(image, "b v c h w -> (b v) c h w"),
                normalize=True,
            )

        # Guard against NaN/Inf from the LPIPS network before applying the weight.
        return self.cfg.weight * torch.nan_to_num(loss.mean(), nan=0.0, posinf=0.0, neginf=0.0)
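# Example usage (a minimal sketch; the concrete construction of `prediction`,
# `batch`, `gaussians`, and `depth_dict` depends on the surrounding codebase and
# is assumed here, as are the example config values):
#
#   cfg = LossLpipsCfgWrapper(lpips=LossLpipsCfg(weight=0.05, apply_after_step=0))
#   loss_fn = LossLpips(cfg)
#   # prediction.color: (b, v, 3, h, w) in [0, 1]; batch["context"]["image"] in [-1, 1].
#   loss = loss_fn.forward(prediction, batch, gaussians, depth_dict, global_step=1000)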