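"""Evaluation entry point for a trained BTS-style model.

Builds the test dataflow, loads the model config stored next to its
checkpoint, and wraps the renderer so that ``base_evaluation`` can drive
full-image rendering over the test set.
"""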
import logging
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from scenedino.common.io.configs import load_model_config
from scenedino.models import make_model


from scenedino.datasets import make_test_dataset
from scenedino.common.geometry import distance_to_z
from scenedino.renderer import NeRFRenderer
from scenedino.common.ray_sampler import ImageRaySampler, get_ray_sampler
from scenedino.evaluation.base_evaluator import base_evaluation

from scenedino.training.trainer_downstream import BTSDownstreamWrapper

IDX = 0

logger = logging.getLogger("evaluation")


class BTSWrapper(nn.Module):
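    """Evaluation wrapper around a ``NeRFRenderer``.

    Encodes the input views, renders all views with an ``ImageRaySampler``,
    and attaches the coarse/fine render results to the input batch.
    """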
    def __init__(
        self,
        renderer,
        config,
        # evaluation_fns
    ) -> None:
        super().__init__()

        self.renderer = renderer

        # TODO: have a consistent sampling range
        self.z_near = config.get("z_near", 3.0)
        self.z_far = config.get("z_far", 80.0)
        self.sampler = ImageRaySampler(self.z_near, self.z_far)

        # self.evaluation_fns = evaluation_fns

    @staticmethod
    def get_loss_metric_names():
        return ["loss", "loss_l2", "loss_mask", "loss_temporal"]

    def forward(self, data):
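        """Render every view in the batch and return ``data`` with results attached.

        Expects ``data["imgs"]``, ``data["poses"]`` (world-to-camera), and
        ``data["projs"]`` as per-view lists of tensors.
        """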
        data = dict(data)
        images = torch.stack(data["imgs"], dim=1)  # n, v, c, h, w
        poses = torch.stack(data["poses"], dim=1)  # n, v, 4, 4 w2c
        projs = torch.stack(data["projs"], dim=1)  # n, v, 4, 4 (-1, 1)

        B, n_frames, c, h, w = images.shape
        device = images.device

        # Use the first frame as the keyframe: re-express all poses relative to it
        to_base_pose = torch.inverse(poses[:, :1, :, :])
        poses = to_base_pose.expand(-1, n_frames, -1, -1) @ poses

        # TODO: make configurable
        ids_encoder = [0]

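        # Encode the input views into the scene representation. The
        # `images * 0.5 + 0.5` remap assumes the inputs are normalized to
        # [-1, 1] and converts them to [0, 1].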
        self.renderer.net.compute_grid_transforms(
            projs[:, ids_encoder], poses[:, ids_encoder]
        )
        self.renderer.net.encode(
            images,
            projs,
            poses,
            ids_encoder=ids_encoder,
            ids_render=ids_encoder,
            images_alt=images * 0.5 + 0.5,
        )

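        # Sample one ray per pixel of every view (full-image sampling).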
        all_rays, all_rgb_gt = self.sampler.sample(images * 0.5 + 0.5, poses, projs)

        data["fine"] = []
        data["coarse"] = []

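        # Render at prediction scale 0, keeping sample weights and alphas.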
        self.renderer.net.set_scale(0)
        render_dict = self.renderer(all_rays, want_weights=True, want_alphas=True)

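        # Fall back to the coarse pass when the renderer has no fine pass.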
        if "fine" not in render_dict:
            render_dict["fine"] = dict(render_dict["coarse"])

        render_dict["rgb_gt"] = all_rgb_gt
        render_dict["rays"] = all_rays

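        # Reshape the flat ray results back into image layout, then convert
        # ray distances to z-depth in the camera frame.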
        # TODO: check if distance to z is needed
        render_dict = self.sampler.reconstruct(render_dict)
        render_dict["coarse"]["depth"] = distance_to_z(
            render_dict["coarse"]["depth"], projs
        )
        render_dict["fine"]["depth"] = distance_to_z(
            render_dict["fine"]["depth"], projs
        )

        data["fine"].append(render_dict["fine"])
        data["coarse"].append(render_dict["coarse"])
        data["rgb_gt"] = render_dict["rgb_gt"]
        data["rays"] = render_dict["rays"]

        data["z_near"] = torch.tensor(self.z_near, device=device)
        data["z_far"] = torch.tensor(self.z_far, device=device)

        # for eval_fn in self.evaluation_fns:
        #     data["metrics"].update(eval_fn(data, model=self.renderer.net))

        return data


def evaluation(local_rank, config):
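    """Run the shared evaluation loop with this module's dataflow and initializer."""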
    return base_evaluation(local_rank, config, get_dataflow, initialize)


def get_dataflow(config):
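    """Build the test ``DataLoader`` (deterministic order, nothing dropped)."""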
    test_dataset = make_test_dataset(config["dataset"])
    test_loader = DataLoader(
        test_dataset,  # Subset(test_dataset, torch.randperm(test_dataset.length)[:1000]),
        batch_size=config.get("batch_size", 1),
        num_workers=config["num_workers"],
        shuffle=False,
        drop_last=False,
    )

    return test_loader


def initialize(config: dict):
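    """Load the model config stored next to the checkpoint and build the wrapped model."""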
    checkpoint = Path(config["checkpoint"])
    logger.info(f"Loading model config from {checkpoint.parent}")
    load_model_config(checkpoint.parent, config)

    net = make_model(config["model"], config["downstream"])
    # net = make_model(config["model"])

    renderer = NeRFRenderer.from_conf(config["renderer"])
    renderer = renderer.bind_parallel(net, gpus=None).eval()

    # TODO: attach evaluation functions rather than adding them to the wrapper
    # eval_fns = []
    # for eval_conf in config["evaluations"]:
    #     eval_fn = make_eval_fn(eval_conf)
    #     if eval_fn is not None:
    #         eval_fns.append(eval_fn)

    ray_sampler = get_ray_sampler(config["training"]["ray_sampler"])
    model = BTSDownstreamWrapper(renderer, ray_sampler, config["model"])
    # model = BTSWrapper(renderer, config["model"])
    # model = BTSWrapper(renderer, config["model"], eval_fns)

    return model