File size: 2,199 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor
from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
@dataclass
class LossLODCfg:
    """Weights for the two terms of the level-of-detail (LOD) image loss."""

    # Multiplier applied to each level's mean-squared-error term.
    mse_weight: float
    # Multiplier applied to each level's LPIPS (perceptual) term.
    lpips_weight: float
@dataclass
class LossLODCfgWrapper:
    """Container that holds the LOD loss settings under the ``lod`` attribute."""

    # NOTE(review): presumably ``lod`` is the name under which this loss is
    # selected in the experiment config -- confirm against the Loss base class.
    lod: LossLODCfg
# Integer key -> float weight table.
# NOTE(review): presumably LOD level index -> per-level loss weight (going by
# the name); nothing in the code visible in this section reads it, so the
# levels are currently combined without these weights -- TODO confirm intent.
WEIGHT_LEVEL_MAPPING = {
    0: 0.1,
    1: 0.1,
    2: 0.2,
    3: 0.6,
}
class LossLOD(Loss[LossLODCfg, LossLODCfgWrapper]):
    """Image loss (MSE + LPIPS) averaged over level-of-detail renderings.

    For every entry of ``prediction.lod_rendering`` the target image is
    bilinearly resized to that entry's render resolution, an MSE term and an
    LPIPS term are computed against the rendering, and the two are combined
    with ``cfg.mse_weight`` / ``cfg.lpips_weight``. The per-level results are
    averaged with equal weight; ``WEIGHT_LEVEL_MAPPING`` is not applied.
    """

    lpips: LPIPS

    def __init__(self, cfg: LossLODCfgWrapper) -> None:
        """Create the loss and its VGG-based LPIPS network."""
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # NOTE(review): presumably this freezes the LPIPS weights and keeps
        # them out of checkpoints (persistent=False) -- confirm against
        # convert_to_buffer.
        convert_to_buffer(self.lpips, persistent=False)

    @staticmethod
    def _finite_or_zero(value: Tensor) -> Tensor:
        """Replace NaN and +/-inf entries of ``value`` with 0."""
        return torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the mean per-level MSE + LPIPS loss.

        Args:
            prediction: Decoder output; only ``lod_rendering`` is read. Each of
                its values must provide a ``"rendered_imgs"`` tensor.
            batch: Batch whose ``["target"]["image"]`` is the ground truth.
            gaussians: Unused by this loss.
            global_step: Unused by this loss.

        Returns:
            A scalar tensor. When ``lod_rendering`` has no entries a zero
            tensor on the target image's device is returned.
        """
        image = batch["target"]["image"]
        lod_rendering = prediction.lod_rendering

        num_levels = len(lod_rendering)
        if num_levels == 0:
            # Nothing was rendered: contribute no loss rather than dividing
            # by zero below.
            return image.new_zeros(())

        # Merge the two leading dims once; bilinear interpolation needs a 4-D
        # input. NOTE(review): presumably image is (batch, view, C, H, W).
        target = image.flatten(0, 1)

        loss = 0.0
        for level_output in lod_rendering.values():
            rendered = level_output["rendered_imgs"].flatten(0, 1)
            height, width = rendered.shape[2:]

            # Compare at the resolution this level was rendered at.
            resized_target = F.interpolate(
                target,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
            )

            level_mse = ((rendered - resized_target) ** 2).mean()
            # normalize=True: LPIPS treats the inputs as being in [0, 1].
            level_lpips = self.lpips(
                rendered, resized_target, normalize=True
            ).mean()

            # A non-finite term is zeroed so a single bad level cannot poison
            # the whole loss.
            loss = loss + (
                self._finite_or_zero(level_mse) * self.cfg.mse_weight
                + self._finite_or_zero(level_lpips) * self.cfg.lpips_weight
            )

        return loss / num_levels
|