File size: 2,416 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from dataclasses import dataclass

import torch
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss


@dataclass
class LossLpipsCfg:
    """Settings for the LPIPS loss term.

    The fields are read in ``LossLpips.forward`` through ``self.cfg``
    (presumably populated by the ``Loss`` base class - TODO confirm).
    """

    # Scalar multiplied onto the mean LPIPS value to form the returned loss.
    weight: float
    # While global_step is below this value the loss returns a zero tensor.
    apply_after_step: int
    # Mask selectors. They are checked in the order mask, alpha, conf and the
    # first enabled one wins; with none enabled the images are compared as-is.
    # conf: multiply both images by depth_dict['conf_valid_mask'].
    conf: bool = False
    # alpha: multiply both images by prediction.alpha.
    alpha: bool = False
    # mask: multiply both images by batch["context"]["valid_mask"].
    mask: bool = False


@dataclass
class LossLpipsCfgWrapper:
    """Holds the LPIPS settings under a single ``lpips`` field.

    This is the type accepted by ``LossLpips.__init__``. How the ``Loss`` base
    class unwraps it is not visible in this file (presumably by field name -
    TODO confirm).
    """

    # The actual LPIPS loss settings.
    lpips: LossLpipsCfg

class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    """LPIPS (VGG backbone) perceptual loss between renders and context images.

    The rendered colors in ``prediction.color`` are compared against
    ``batch["context"]["image"]``. Depending on the config flags, both images
    can be multiplied by a mask before the comparison.
    """

    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        """Build the VGG-based LPIPS network used by ``forward``."""
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # NOTE(review): presumably re-registers the LPIPS weights as
        # non-persistent buffers so they are neither trained nor saved in
        # checkpoints - confirm against src.misc.nn_module_tools.
        convert_to_buffer(self.lpips, persistent=False)

    def _select_mask(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        depth_dict: dict | None,
    ) -> Tensor:
        """Return the mask for the first enabled flag (mask, alpha, conf).

        Must only be called when at least one of the three flags is set.

        Raises:
            TypeError: if ``cfg.conf`` is the selected flag but ``depth_dict``
                is ``None``.
        """
        if self.cfg.mask:
            return batch["context"]["valid_mask"]
        if self.cfg.alpha:
            return prediction.alpha
        # Only cfg.conf can be the enabled flag at this point.
        if depth_dict is None:
            # Same exception type as subscripting None, but with a message
            # that points at the actual misconfiguration.
            raise TypeError(
                "LossLpips: cfg.conf is enabled but depth_dict is None; "
                "a dict containing 'conf_valid_mask' is required"
            )
        return depth_dict["conf_valid_mask"]

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted LPIPS loss for one batch.

        Args:
            prediction: Decoder output; ``color`` is read (and ``alpha`` when
                ``cfg.alpha`` selects the mask).
            batch: Batch whose ``["context"]["image"]`` is the reference
                (and ``["context"]["valid_mask"]`` when ``cfg.mask`` is set).
            gaussians: Unused by this loss.
            depth_dict: Only read when ``cfg.conf`` selects the mask, in which
                case it must contain ``'conf_valid_mask'``.
            global_step: Current training step.

        Returns:
            A scalar tensor: zero before ``cfg.apply_after_step``; otherwise
            ``cfg.weight`` times the mean LPIPS value, with NaN and infinite
            values replaced by zero.

        Raises:
            TypeError: if ``cfg.conf`` selects the mask and ``depth_dict`` is
                ``None``.
        """
        context_image = batch["context"]["image"]

        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(
                0, dtype=torch.float32, device=context_image.device
            )

        # Assumes the context images arrive in [-1, 1] and maps them to
        # [0, 1] - TODO confirm against the dataset code.
        image = (context_image + 1) / 2
        color = prediction.color

        if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
            mask = self._select_mask(prediction, batch, depth_dict)
            # Insert a channel axis and repeat the mask over the color
            # channels so it lines up with the "b v c h w" layout.
            channels = color.shape[2]
            expanded_mask = mask.unsqueeze(2).expand(-1, -1, channels, -1, -1)
            color = color * expanded_mask
            image = image * expanded_mask

        # Fold the view axis into the batch axis; LPIPS takes 4-D inputs.
        loss = self.lpips.forward(
            rearrange(color, "b v c h w -> (b v) c h w"),
            rearrange(image, "b v c h w -> (b v) c h w"),
            normalize=True,
        )
        # Replace a non-finite mean with zero before applying the weight.
        return self.cfg.weight * torch.nan_to_num(
            loss.mean(), nan=0.0, posinf=0.0, neginf=0.0
        )