File size: 2,199 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss


@dataclass
class LossLODCfg:
    """Weights applied to the per-level LOD loss terms."""

    # Multiplier on each level's mean squared error.
    mse_weight: float
    # Multiplier on each level's LPIPS perceptual distance.
    lpips_weight: float

@dataclass
class LossLODCfgWrapper:
    """Config wrapper that nests a ``LossLODCfg`` under the ``lod`` field."""

    # The actual loss weights for the LOD loss.
    lod: LossLODCfg

WEIGHT_LEVEL_MAPPING = {0: 0.1, 1: 0.1, 2: 0.2, 3: 0.6}

class LossLOD(Loss[LossLODCfg, LossLODCfgWrapper]):
    """Per-level photometric loss (weighted MSE + LPIPS) averaged over LOD levels.

    For every entry of ``prediction.lod_rendering`` the target images are
    bilinearly resized to that level's render resolution, a weighted sum of
    MSE and LPIPS is computed, and the per-level losses are averaged uniformly.
    """

    lpips: LPIPS

    def __init__(self, cfg: LossLODCfgWrapper) -> None:
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # NOTE(review): presumably re-registers the LPIPS network's parameters
        # as non-persistent buffers so they are neither optimized nor saved in
        # checkpoints -- confirm against src.misc.nn_module_tools.
        convert_to_buffer(self.lpips, persistent=False)

    @staticmethod
    def _zero_non_finite(value: Tensor) -> Tensor:
        """Replace NaN / +inf / -inf entries of ``value`` with 0."""
        return torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the mean over LOD levels of the weighted MSE + LPIPS loss.

        Args:
            prediction: Decoder output. ``prediction.lod_rendering`` maps each
                level key to a dict whose ``"rendered_imgs"`` entry has its
                first two dims flattened together here; presumably shaped
                (batch, view, channel, height, width) -- TODO confirm.
            batch: Batch whose ``batch["target"]["image"]`` holds the target
                images with the same leading layout as the renders.
            gaussians: Not used by this loss.
            global_step: Not used by this loss; no step-based gating is applied.

        Returns:
            A scalar tensor: the average over levels of
            ``mse_weight * MSE + lpips_weight * LPIPS``, where each non-finite
            term is replaced by 0. If ``lod_rendering`` is empty, a zero scalar
            is returned instead of dividing by zero.
        """
        lod_rendering = prediction.lod_rendering
        image = batch["target"]["image"]
        if len(lod_rendering) == 0:
            return image.new_zeros(())

        # Merge the leading two dims once, matching how the renders are flattened.
        target = image.flatten(0, 1)

        total = 0.0
        for level_output in lod_rendering.values():
            rendered = level_output["rendered_imgs"].flatten(0, 1)
            height, width = rendered.shape[2:]
            # interpolate returns a new tensor, so the target needs no clone.
            resized_target = F.interpolate(
                target, size=(height, width), mode="bilinear", align_corners=False
            )

            level_mse = ((rendered - resized_target) ** 2).mean()
            # normalize=True makes LPIPS map inputs from [0, 1] to [-1, 1].
            level_lpips = self.lpips(rendered, resized_target, normalize=True).mean()

            # NOTE(review): zeroing a non-finite loss value does not by itself
            # guarantee finite gradients -- NaN entries in the inputs can still
            # yield NaN grads upstream (0 * NaN is NaN in the backward pass).
            total += (
                self._zero_non_finite(level_mse) * self.cfg.mse_weight
                + self._zero_non_finite(level_lpips) * self.cfg.lpips_weight
            )
        return total / len(lod_rendering)