# scenedino/evaluation/unified_evaluator.py
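"""Unified evaluation entry point for SceneDINO.

Wraps the renderer in a module that runs a full render pass per batch and
exposes the ``get_dataflow``/``initialize`` hooks consumed by ``base_evaluation``.
"""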
import logging
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from scenedino.common.io.configs import load_model_config
from scenedino.models import make_model
from scenedino.datasets import make_test_dataset
from scenedino.common.geometry import distance_to_z
from scenedino.renderer import NeRFRenderer
from scenedino.common.ray_sampler import ImageRaySampler, get_ray_sampler
from scenedino.evaluation.base_evaluator import base_evaluation
from scenedino.training.trainer_downstream import BTSDownstreamWrapper


IDX = 0
logger = logging.getLogger("evaluation")


class BTSWrapper(nn.Module):
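    """Evaluation wrapper around a NeRFRenderer: renders every view of a batch
    and attaches the coarse/fine outputs to the returned data dict."""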
def __init__(
self,
renderer,
config,
# evaluation_fns
) -> None:
super().__init__()
self.renderer = renderer
        # TODO: have a consistent sampling range
self.z_near = config.get("z_near", 3.0)
self.z_far = config.get("z_far", 80.0)
self.sampler = ImageRaySampler(self.z_near, self.z_far)
        # self.evaluation_fns = evaluation_fns

    @staticmethod
def get_loss_metric_names():
        return ["loss", "loss_l2", "loss_mask", "loss_temporal"]

    def forward(self, data):
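        """Encode the input views, render all rays, and return ``data`` augmented
        with the render results, depth maps, and sampling range."""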
data = dict(data)
images = torch.stack(data["imgs"], dim=1) # n, v, c, h, w
poses = torch.stack(data["poses"], dim=1) # n, v, 4, 4 w2c
projs = torch.stack(data["projs"], dim=1) # n, v, 4, 4 (-1, 1)
B, n_frames, c, h, w = images.shape
device = images.device
# Use first frame as keyframe
to_base_pose = torch.inverse(poses[:, :1, :, :])
poses = to_base_pose.expand(-1, n_frames, -1, -1) @ poses
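        # Poses are now expressed relative to the first (key) frame.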
# TODO: make configurable
ids_encoder = [0]
self.renderer.net.compute_grid_transforms(
projs[:, ids_encoder], poses[:, ids_encoder]
)
self.renderer.net.encode(
images,
projs,
poses,
ids_encoder=ids_encoder,
ids_render=ids_encoder,
images_alt=images * 0.5 + 0.5,
)
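        # Sample rays for every pixel of every view; `* 0.5 + 0.5` rescales the images from [-1, 1] to [0, 1].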
all_rays, all_rgb_gt = self.sampler.sample(images * 0.5 + 0.5, poses, projs)
data["fine"] = []
data["coarse"] = []
self.renderer.net.set_scale(0)
render_dict = self.renderer(all_rays, want_weights=True, want_alphas=True)
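        # Fall back to the coarse output if the renderer produced no fine pass.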
if "fine" not in render_dict:
render_dict["fine"] = dict(render_dict["coarse"])
render_dict["rgb_gt"] = all_rgb_gt
render_dict["rays"] = all_rays
# TODO: check if distance to z is needed
render_dict = self.sampler.reconstruct(render_dict)
render_dict["coarse"]["depth"] = distance_to_z(
render_dict["coarse"]["depth"], projs
)
render_dict["fine"]["depth"] = distance_to_z(
render_dict["fine"]["depth"], projs
)
data["fine"].append(render_dict["fine"])
data["coarse"].append(render_dict["coarse"])
data["rgb_gt"] = render_dict["rgb_gt"]
data["rays"] = render_dict["rays"]
data["z_near"] = torch.tensor(self.z_near, device=images.device)
data["z_far"] = torch.tensor(self.z_far, device=images.device)
# for eval_fn in self.evaluation_fns:
# data["metrics"].update(eval_fn(data, model=self.renderer.net))
        return data


def evaluation(local_rank, config):
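    """Run the evaluation loop via ``base_evaluation`` with this module's dataflow and model setup."""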
    return base_evaluation(local_rank, config, get_dataflow, initialize)


def get_dataflow(config):
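    """Build the test DataLoader from the dataset config."""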
test_dataset = make_test_dataset(config["dataset"])
test_loader = DataLoader(
test_dataset, # Subset(test_dataset, torch.randperm(test_dataset.length)[:1000]),
batch_size=config.get("batch_size", 1),
num_workers=config["num_workers"],
shuffle=False,
drop_last=False,
)
    return test_loader


def initialize(config: dict):
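    """Load the model config stored next to the checkpoint and build the renderer wrapper."""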
checkpoint = Path(config["checkpoint"])
logger.info(f"Loading model config from {checkpoint.parent}")
load_model_config(checkpoint.parent, config)
net = make_model(config["model"], config["downstream"])
# net = make_model(config["model"])
renderer = NeRFRenderer.from_conf(config["renderer"])
renderer = renderer.bind_parallel(net, gpus=None).eval()
    # TODO: attach evaluation functions rather than add them to the wrapper
# eval_fns = []
# for eval_conf in config["evaluations"]:
# eval_fn = make_eval_fn(eval_conf)
# if eval_fn is not None:
# eval_fns.append(eval_fn)
ray_sampler = get_ray_sampler(config["training"]["ray_sampler"])
model = BTSDownstreamWrapper(renderer, ray_sampler, config["model"])
# model = BTSWrapper(renderer, config["model"])
# model = BTSWrapper(renderer, config["model"], eval_fns)
return model