# AnySplat: src/model/model_wrapper.py
from dataclasses import dataclass
from pathlib import Path
import gc
import random
from typing import Literal, Optional, Protocol, runtime_checkable, Any
import moviepy.editor as mpy
import torch
import torchvision
import wandb
from einops import pack, rearrange, repeat
from jaxtyping import Float
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from tabulate import tabulate
from torch import Tensor, nn, optim
import torch.nn.functional as F
from ..loss.loss_lpips import LossLpips
from ..loss.loss_mse import LossMse
from .encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from ..loss.loss_distill import DistillLoss
from ..utils.render import generate_path
from ..utils.point import get_normal_map
from ..loss.loss_huber import HuberLoss, extri_intri_to_pose_encoding
# from model.types import Gaussians
from ..dataset.data_module import get_data_shim
from ..dataset.types import BatchedExample
from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim, abs_relative_difference, delta1_acc
from ..global_cfg import get_cfg
from ..loss import Loss
from ..loss.loss_point import Regr3D
from ..loss.loss_ssim import ssim
from ..misc.benchmarker import Benchmarker
from ..misc.cam_utils import update_pose, get_pnp_pose, rotation_6d_to_matrix
from ..misc.image_io import prep_image, save_image, save_video
from ..misc.LocalLogger import LOG_PATH, LocalLogger
from ..misc.nn_module_tools import convert_to_buffer
from ..misc.step_tracker import StepTracker
from ..misc.utils import inverse_normalize, vis_depth_map, confidence_map, get_overlap_tag
from ..visualization.annotation import add_label
from ..visualization.camera_trajectory.interpolation import (
interpolate_extrinsics,
interpolate_intrinsics,
)
from ..visualization.camera_trajectory.wobble import (
generate_wobble,
generate_wobble_transformation,
)
from ..visualization.color_map import apply_color_map_to_image
from ..visualization.layout import add_border, hcat, vcat
# from ..visualization.validation_in_3d import render_cameras, render_projections
from .decoder.decoder import Decoder, DepthRenderingMode
from .encoder import Encoder
from .encoder.visualization.encoder_visualizer import EncoderVisualizer
from .ply_export import export_ply
@dataclass
class OptimizerCfg:
lr: float
warm_up_steps: int
backbone_lr_multiplier: float
@dataclass
class TestCfg:
output_path: Path
align_pose: bool
pose_align_steps: int
rot_opt_lr: float
trans_opt_lr: float
compute_scores: bool
save_image: bool
save_video: bool
save_compare: bool
generate_video: bool
mode: Literal["inference", "evaluation"]
image_folder: str
@dataclass
class TrainCfg:
output_path: Path
depth_mode: DepthRenderingMode | None
extended_visualization: bool
print_log_every_n_steps: int
distiller: str
distill_max_steps: int
pose_loss_alpha: float = 1.0
pose_loss_delta: float = 1.0
cxt_depth_weight: float = 0.01
weight_pose: float = 1.0
weight_depth: float = 1.0
weight_normal: float = 1.0
render_ba: bool = False
render_ba_after_step: int = 0
@runtime_checkable
class TrajectoryFn(Protocol):
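    """Maps normalized times t in [0, 1] to per-frame camera extrinsics and intrinsics."""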
def __call__(
self,
t: Float[Tensor, " t"],
) -> tuple[
Float[Tensor, "batch view 4 4"], # extrinsics
Float[Tensor, "batch view 3 3"], # intrinsics
]:
pass
class ModelWrapper(LightningModule):
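    """Lightning wrapper around the AnySplat encoder/decoder model.

    Handles multi-dataloader training with rendering, context-depth, and
    distillation losses, validation visualizations, test-time evaluation with
    optional pose alignment, and optimizer/scheduler setup.

    Minimal usage sketch (assumes the cfgs, model, losses, and a Lightning
    Trainer/DataModule are already constructed elsewhere from the project config):

        wrapper = ModelWrapper(optimizer_cfg, test_cfg, train_cfg,
                               model, losses, step_tracker)
        trainer.fit(wrapper, datamodule=data_module)
    """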
logger: Optional[WandbLogger]
model: nn.Module
losses: nn.ModuleList
optimizer_cfg: OptimizerCfg
test_cfg: TestCfg
train_cfg: TrainCfg
step_tracker: StepTracker | None
def __init__(
self,
optimizer_cfg: OptimizerCfg,
test_cfg: TestCfg,
train_cfg: TrainCfg,
model: nn.Module,
losses: list[Loss],
step_tracker: StepTracker | None
) -> None:
super().__init__()
self.optimizer_cfg = optimizer_cfg
self.test_cfg = test_cfg
self.train_cfg = train_cfg
self.step_tracker = step_tracker
# Set up the model.
self.encoder_visualizer = None
self.model = model
self.data_shim = get_data_shim(self.model.encoder)
self.losses = nn.ModuleList(losses)
if self.model.encoder.pred_pose:
self.loss_pose = HuberLoss(alpha=self.train_cfg.pose_loss_alpha, delta=self.train_cfg.pose_loss_delta)
if self.model.encoder.distill:
self.loss_distill = DistillLoss(
delta=self.train_cfg.pose_loss_delta,
weight_pose=self.train_cfg.weight_pose,
weight_depth=self.train_cfg.weight_depth,
weight_normal=self.train_cfg.weight_normal
)
# This is used for testing.
self.benchmarker = Benchmarker()
def on_train_epoch_start(self) -> None:
        # Our custom dataset and sampler need the epoch set via set_epoch().
if hasattr(self.trainer.datamodule.train_loader.dataset, "set_epoch"):
self.trainer.datamodule.train_loader.dataset.set_epoch(self.current_epoch)
if hasattr(self.trainer.datamodule.train_loader.sampler, "set_epoch"):
self.trainer.datamodule.train_loader.sampler.set_epoch(self.current_epoch)
def on_validation_epoch_start(self) -> None:
print(f"Validation epoch start on rank {self.trainer.global_rank}")
        # Our custom dataset and sampler need the epoch set via set_epoch().
if hasattr(self.trainer.datamodule.val_loader.dataset, "set_epoch"):
self.trainer.datamodule.val_loader.dataset.set_epoch(self.current_epoch)
if hasattr(self.trainer.datamodule.val_loader.sampler, "set_epoch"):
self.trainer.datamodule.val_loader.sampler.set_epoch(self.current_epoch)
def training_step(self, batch, batch_idx):
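        """Merge batches from multiple dataloaders, run the model on the context
        views, log PSNR and depth-consistency metrics, and accumulate the
        rendering, context-depth, and distillation losses. After SKIP_AFTER_STEP,
        batches whose total loss exceeds LOSS_THRESHOLD are effectively skipped."""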
        # Combine batches from multiple dataloaders into a single batch: list
        # fields are concatenated, tensor dicts are concatenated along the batch dim.
# torch.cuda.empty_cache()
if isinstance(batch, list):
batch_combined = None
for batch_per_dl in batch:
if batch_combined is None:
batch_combined = batch_per_dl
else:
for k in batch_combined.keys():
if isinstance(batch_combined[k], list):
batch_combined[k] += batch_per_dl[k]
elif isinstance(batch_combined[k], dict):
for kk in batch_combined[k].keys():
batch_combined[k][kk] = torch.cat([batch_combined[k][kk], batch_per_dl[k][kk]], dim=0)
else:
raise NotImplementedError
batch = batch_combined
batch: BatchedExample = self.data_shim(batch)
b, v, c, h, w = batch["context"]["image"].shape
context_image = (batch["context"]["image"] + 1) / 2
# Run the model.
visualization_dump = None
encoder_output, output = self.model(context_image, self.global_step, visualization_dump=visualization_dump)
gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
pred_context_pose = encoder_output.pred_context_pose
infos = encoder_output.infos
distill_infos = encoder_output.distill_infos
num_context_views = pred_context_pose['extrinsic'].shape[1]
using_index = torch.arange(num_context_views, device=gaussians.means.device)
batch["using_index"] = using_index
target_gt = (batch["context"]["image"] + 1) / 2
scene_scale = infos["scene_scale"]
self.log("train/scene_scale", infos["scene_scale"])
self.log("train/voxelize_ratio", infos["voxelize_ratio"])
# Compute metrics.
psnr_probabilistic = compute_psnr(
rearrange(target_gt, "b v c h w -> (b v) c h w"),
rearrange(output.color, "b v c h w -> (b v) c h w"),
)
self.log("train/psnr_probabilistic", psnr_probabilistic.mean())
consis_absrel = abs_relative_difference(
rearrange(output.depth, "b v h w -> (b v) h w"),
rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
)
self.log("train/consis_absrel", consis_absrel.mean())
consis_delta1 = delta1_acc(
rearrange(output.depth, "b v h w -> (b v) h w"),
rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
)
self.log("train/consis_delta1", consis_delta1.mean())
# Compute and log loss.
total_loss = 0
depth_dict['distill_infos'] = distill_infos
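        # Evaluate the loss terms with autocast disabled so they are computed in
        # full (fp32) precision even when training with mixed precision.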
with torch.amp.autocast('cuda', enabled=False):
for loss_fn in self.losses:
loss = loss_fn.forward(output, batch, gaussians, depth_dict, self.global_step)
self.log(f"loss/{loss_fn.name}", loss)
total_loss = total_loss + loss
if depth_dict is not None and "depth" in get_cfg()["loss"].keys() and self.train_cfg.cxt_depth_weight > 0:
depth_loss_idx = list(get_cfg()["loss"].keys()).index("depth")
depth_loss_fn = self.losses[depth_loss_idx].ctx_depth_loss
loss_depth = depth_loss_fn(depth_dict["depth_map"], depth_dict["depth_conf"], batch, cxt_depth_weight=self.train_cfg.cxt_depth_weight)
self.log("loss/ctx_depth", loss_depth)
total_loss = total_loss + loss_depth
if distill_infos is not None:
# distill ctx pred_pose & depth & normal
loss_distill_list = self.loss_distill(distill_infos, pred_pose_enc_list, output, batch)
self.log("loss/distill", loss_distill_list['loss_distill'])
self.log("loss/distill_pose", loss_distill_list['loss_pose'])
self.log("loss/distill_depth", loss_distill_list['loss_depth'])
self.log("loss/distill_normal", loss_distill_list['loss_normal'])
total_loss = total_loss + loss_distill_list['loss_distill']
self.log("loss/total", total_loss)
print(f"total_loss: {total_loss}")
# Skip batch if loss is too high after certain step
SKIP_AFTER_STEP = 1000
LOSS_THRESHOLD = 0.2
if self.global_step > SKIP_AFTER_STEP and total_loss > LOSS_THRESHOLD:
print(f"Skipping batch with high loss ({total_loss:.6f}) at step {self.global_step} on Rank {self.global_rank}")
            # Scale the loss down so this batch contributes a negligible gradient
            # while still returning a valid loss value to Lightning.
return total_loss * 1e-10
if (
self.global_rank == 0
and self.global_step % self.train_cfg.print_log_every_n_steps == 0
):
print(
f"train step {self.global_step}; "
f"scene = {[x[:20] for x in batch['scene']]}; "
f"context = {batch['context']['index'].tolist()}; "
f"loss = {total_loss:.6f}; "
)
self.log("info/global_step", self.global_step) # hack for ckpt monitor
# Tell the data loader processes about the current step.
if self.step_tracker is not None:
self.step_tracker.set_step(self.global_step)
del batch
if self.global_step % 50 == 0:
gc.collect()
torch.cuda.empty_cache()
return total_loss
def on_after_backward(self):
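        # Log the RMS of the per-parameter gradient L2 norms. Note that this is
        # an average over parameters, not the global gradient norm.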
total_norm = 0.0
counter = 0
for p in self.parameters():
if p.grad is not None:
param_norm = p.grad.detach().data.norm(2)
total_norm += param_norm.item() ** 2
counter += 1
total_norm = (total_norm / counter) ** 0.5
self.log("loss/grad_norm", total_norm)
def test_step(self, batch, batch_idx):
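        """Encode the context views into Gaussians, optionally align the target
        poses, render the target views, and compute/save metrics, images, videos,
        and comparison figures."""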
batch: BatchedExample = self.data_shim(batch)
b, v, _, h, w = batch["target"]["image"].shape
assert b == 1
if batch_idx % 100 == 0:
print(f"Test step {batch_idx:0>6}.")
# Render Gaussians.
with self.benchmarker.time("encoder"):
gaussians = self.model.encoder(
(batch["context"]["image"]+1)/2,
self.global_step,
)[0]
# export_ply(gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path("gaussians.ply"))
# align the target pose
if self.test_cfg.align_pose:
output = self.test_step_align(batch, gaussians)
else:
with self.benchmarker.time("decoder", num_calls=v):
output = self.model.decoder.forward(
gaussians,
batch["target"]["extrinsics"],
batch["target"]["intrinsics"],
batch["target"]["near"],
batch["target"]["far"],
(h, w),
)
# compute scores
if self.test_cfg.compute_scores:
overlap = batch["context"]["overlap"][0]
overlap_tag = get_overlap_tag(overlap)
rgb_pred = output.color[0]
rgb_gt = batch["target"]["image"][0]
            all_metrics = {
                "lpips_ours": compute_lpips(rgb_gt, rgb_pred).mean(),
                "ssim_ours": compute_ssim(rgb_gt, rgb_pred).mean(),
                "psnr_ours": compute_psnr(rgb_gt, rgb_pred).mean(),
            }
methods = ['ours']
self.log_dict(all_metrics)
self.print_preview_metrics(all_metrics, methods, overlap_tag=overlap_tag)
# Save images.
(scene,) = batch["scene"]
name = get_cfg()["wandb"]["name"]
path = self.test_cfg.output_path / name
if self.test_cfg.save_image:
for index, color in zip(batch["target"]["index"][0], output.color[0]):
save_image(color, path / scene / f"color/{index:0>6}.png")
if self.test_cfg.save_video:
frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]])
save_video(
[a for a in output.color[0]],
path / "video" / f"{scene}_frame_{frame_str}.mp4",
)
if self.test_cfg.save_compare:
# Construct comparison image.
context_img = inverse_normalize(batch["context"]["image"][0])
comparison = hcat(
add_label(vcat(*context_img), "Context"),
add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
add_label(vcat(*rgb_pred), "Target (Prediction)"),
)
save_image(comparison, path / f"{scene}.png")
def test_step_align(self, batch, gaussians):
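        """Refine the target camera poses against the rendering losses.

        With the encoder frozen, per-view rotation (6D representation) and
        translation deltas are optimized with Adam for `pose_align_steps`
        iterations, and the Gaussians are then re-rendered with the aligned
        extrinsics.
        """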
self.model.encoder.eval()
# freeze all parameters
for param in self.model.encoder.parameters():
param.requires_grad = False
b, v, _, h, w = batch["target"]["image"].shape
output_c2ws = batch["target"]["extrinsics"]
with torch.set_grad_enabled(True):
cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=output_c2ws.device))
cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=output_c2ws.device))
opt_params = []
self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).to(output_c2ws))
opt_params.append(
{
"params": [cam_rot_delta],
"lr": 0.005,
}
)
opt_params.append(
{
"params": [cam_trans_delta],
"lr": 0.005,
}
)
pose_optimizer = torch.optim.Adam(opt_params)
extrinsics = output_c2ws.clone()
with self.benchmarker.time("optimize"):
for i in range(self.test_cfg.pose_align_steps):
pose_optimizer.zero_grad()
dx, drot = cam_trans_delta, cam_rot_delta
rot = rotation_6d_to_matrix(
drot + self.identity.expand(b, v, -1)
) # (..., 3, 3)
transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
transform[..., :3, :3] = rot
transform[..., :3, 3] = dx
new_extrinsics = torch.matmul(extrinsics, transform)
output = self.model.decoder.forward(
gaussians,
new_extrinsics,
batch["target"]["intrinsics"],
batch["target"]["near"],
batch["target"]["far"],
(h, w),
# cam_rot_delta=cam_rot_delta,
# cam_trans_delta=cam_trans_delta,
)
# Compute and log loss.
total_loss = 0
for loss_fn in self.losses:
loss = loss_fn.forward(output, batch, gaussians, self.global_step)
total_loss = total_loss + loss
total_loss.backward()
pose_optimizer.step()
# Render Gaussians.
output = self.model.decoder.forward(
gaussians,
new_extrinsics,
batch["target"]["intrinsics"],
batch["target"]["near"],
batch["target"]["far"],
(h, w),
)
return output
def on_test_end(self) -> None:
name = get_cfg()["wandb"]["name"]
self.benchmarker.dump(self.test_cfg.output_path / name / "benchmark.json")
self.benchmarker.dump_memory(
self.test_cfg.output_path / name / "peak_memory.json"
)
self.benchmarker.summarize()
@rank_zero_only
def validation_step(self, batch, batch_idx, dataloader_idx=0):
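        """Render the context views from the predicted Gaussians, log image and
        depth-consistency metrics, build a comparison grid, and render preview
        videos."""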
batch: BatchedExample = self.data_shim(batch)
if self.global_rank == 0:
print(
f"validation step {self.global_step}; "
f"scene = {batch['scene']}; "
f"context = {batch['context']['index'].tolist()}"
)
# Render Gaussians.
b, v, _, h, w = batch["context"]["image"].shape
assert b == 1
visualization_dump = {}
encoder_output, output = self.model(batch["context"]["image"], self.global_step, visualization_dump=visualization_dump)
gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
pred_context_pose, distill_infos = encoder_output.pred_context_pose, encoder_output.distill_infos
infos = encoder_output.infos
GS_num = infos['voxelize_ratio'] * (h*w*v)
self.log("val/GS_num", GS_num)
num_context_views = pred_context_pose['extrinsic'].shape[1]
num_target_views = batch["target"]["extrinsics"].shape[1]
rgb_pred = output.color[0].float()
depth_pred = vis_depth_map(output.depth[0])
# direct depth from gaussian means (used for visualization only)
gaussian_means = visualization_dump["depth"][0].squeeze()
if gaussian_means.shape[-1] == 3:
gaussian_means = gaussian_means.mean(dim=-1)
# Compute validation metrics.
rgb_gt = (batch["context"]["image"][0].float() + 1) / 2
        psnr = compute_psnr(rgb_gt, rgb_pred).mean()
        self.log("val/psnr", psnr)
        lpips = compute_lpips(rgb_gt, rgb_pred).mean()
        self.log("val/lpips", lpips)
        ssim = compute_ssim(rgb_gt, rgb_pred).mean()
        self.log("val/ssim", ssim)
# depth metrics
consis_absrel = abs_relative_difference(
rearrange(output.depth, "b v h w -> (b v) h w"),
rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
)
self.log("val/consis_absrel", consis_absrel.mean())
consis_delta1 = delta1_acc(
rearrange(output.depth, "b v h w -> (b v) h w"),
rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
valid_mask=rearrange(torch.ones_like(output.depth, device=output.depth.device, dtype=torch.bool), "b v h w -> (b v) h w"),
)
self.log("val/consis_delta1", consis_delta1.mean())
diff_map = torch.abs(output.depth - depth_dict['depth'].squeeze(-1))
self.log("val/consis_mse", diff_map[distill_infos['conf_mask']].mean())
# Construct comparison image.
context_img = inverse_normalize(batch["context"]["image"][0])
# context_img_depth = vis_depth_map(gaussian_means)
context = []
for i in range(context_img.shape[0]):
context.append(context_img[i])
# context.append(context_img_depth[i])
colored_diff_map = vis_depth_map(diff_map[0], near=torch.tensor(1e-4, device=diff_map.device), far=torch.tensor(1.0, device=diff_map.device))
model_depth_pred = depth_dict["depth"].squeeze(-1)[0]
model_depth_pred = vis_depth_map(model_depth_pred)
render_normal = (get_normal_map(output.depth.flatten(0, 1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
pred_normal = (get_normal_map(depth_dict['depth'].flatten(0, 1).squeeze(-1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
comparison = hcat(
add_label(vcat(*context), "Context"),
add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
add_label(vcat(*rgb_pred), "Target (Prediction)"),
add_label(vcat(*depth_pred), "Depth (Prediction)"),
add_label(vcat(*model_depth_pred), "Depth (VGGT Prediction)"),
add_label(vcat(*render_normal), "Normal (Prediction)"),
add_label(vcat(*pred_normal), "Normal (VGGT Prediction)"),
add_label(vcat(*colored_diff_map), "Diff Map"),
)
comparison = torch.nn.functional.interpolate(
comparison.unsqueeze(0),
scale_factor=0.5,
mode='bicubic',
align_corners=False
).squeeze(0)
self.logger.log_image(
"comparison",
[prep_image(add_border(comparison))],
step=self.global_step,
caption=batch["scene"],
)
# self.logger.log_image(
# key="comparison",
# images=[wandb.Image(prep_image(add_border(comparison)), caption=batch["scene"], file_type="jpg")],
# step=self.global_step
# )
# Render projections and construct projection image.
# These are disabled for now, since RE10k scenes are effectively unbounded.
# if isinstance(gaussians, Gaussians):
# projections = hcat(
# *render_projections(
# gaussians,
# 256,
# extra_label="",
# )[0]
# )
# self.logger.log_image(
# "projection",
# [prep_image(add_border(projections))],
# step=self.global_step,
# )
# Draw cameras.
# cameras = hcat(*render_cameras(batch, 256))
# self.logger.log_image(
# "cameras", [prep_image(add_border(cameras))], step=self.global_step
# )
if self.encoder_visualizer is not None:
for k, image in self.encoder_visualizer.visualize(
batch["context"], self.global_step
).items():
self.logger.log_image(k, [prep_image(image)], step=self.global_step)
# Run video validation step.
self.render_video_interpolation(batch)
self.render_video_wobble(batch)
if self.train_cfg.extended_visualization:
self.render_video_interpolation_exaggerated(batch)
@rank_zero_only
def render_video_wobble(self, batch: BatchedExample) -> None:
# Two views are needed to get the wobble radius.
_, v, _, _ = batch["context"]["extrinsics"].shape
if v != 2:
return
def trajectory_fn(t):
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
delta = (origin_a - origin_b).norm(dim=-1)
extrinsics = generate_wobble(
batch["context"]["extrinsics"][:, 0],
delta * 0.25,
t,
)
intrinsics = repeat(
batch["context"]["intrinsics"][:, 0],
"b i j -> b v i j",
v=t.shape[0],
)
return extrinsics, intrinsics
return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60)
@rank_zero_only
def render_video_interpolation(self, batch: BatchedExample) -> None:
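        """Render a camera path interpolating from the first context view to the
        second context view (or to the first target view when there are not
        exactly two context views)."""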
_, v, _, _ = batch["context"]["extrinsics"].shape
def trajectory_fn(t):
extrinsics = interpolate_extrinsics(
batch["context"]["extrinsics"][0, 0],
(
batch["context"]["extrinsics"][0, 1]
if v == 2
else batch["target"]["extrinsics"][0, 0]
),
t,
)
intrinsics = interpolate_intrinsics(
batch["context"]["intrinsics"][0, 0],
(
batch["context"]["intrinsics"][0, 1]
if v == 2
else batch["target"]["intrinsics"][0, 0]
),
t,
)
return extrinsics[None], intrinsics[None]
return self.render_video_generic(batch, trajectory_fn, "rgb")
@rank_zero_only
def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None:
# Two views are needed to get the wobble radius.
_, v, _, _ = batch["context"]["extrinsics"].shape
if v != 2:
return
def trajectory_fn(t):
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
delta = (origin_a - origin_b).norm(dim=-1)
tf = generate_wobble_transformation(
delta * 0.5,
t,
5,
scale_radius_with_t=False,
)
extrinsics = interpolate_extrinsics(
batch["context"]["extrinsics"][0, 0],
(
batch["context"]["extrinsics"][0, 1]
if v == 2
else batch["target"]["extrinsics"][0, 0]
),
t * 5 - 2,
)
intrinsics = interpolate_intrinsics(
batch["context"]["intrinsics"][0, 0],
(
batch["context"]["intrinsics"][0, 1]
if v == 2
else batch["target"]["intrinsics"][0, 0]
),
t * 5 - 2,
)
return extrinsics @ tf, intrinsics[None]
return self.render_video_generic(
batch,
trajectory_fn,
"interpolation_exagerrated",
num_frames=300,
smooth=False,
loop_reverse=False,
)
@rank_zero_only
def render_video_generic(
self,
batch: BatchedExample,
trajectory_fn: TrajectoryFn,
name: str,
num_frames: int = 30,
smooth: bool = True,
loop_reverse: bool = True,
) -> None:
# Render probabilistic estimate of scene.
encoder_output = self.model.encoder((batch["context"]["image"]+1)/2, self.global_step)
gaussians, pred_pose_enc_list = encoder_output.gaussians, encoder_output.pred_pose_enc_list
t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
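        # Optionally remap t with a half-cosine ease so the camera slows down
        # near the endpoints of the trajectory.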
if smooth:
t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
extrinsics, intrinsics = trajectory_fn(t)
_, _, _, h, w = batch["context"]["image"].shape
# TODO: Interpolate near and far planes?
near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
output = self.model.decoder.forward(
gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
)
images = [
vcat(rgb, depth)
for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
]
video = torch.stack(images)
video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
if loop_reverse:
video = pack([video, video[::-1][1:-1]], "* c h w")[0]
visualizations = {
f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
}
        # PyTorch Lightning doesn't support video logging, so log to wandb directly.
try:
wandb.log(visualizations)
except Exception:
assert isinstance(self.logger, LocalLogger)
for key, value in visualizations.items():
tensor = value._prepare_video(value.data)
clip = mpy.ImageSequenceClip(list(tensor), fps=30)
                out_dir = LOG_PATH / key
                out_dir.mkdir(exist_ok=True, parents=True)
                clip.write_videofile(
                    str(out_dir / f"{self.global_step:0>6}.mp4"), logger=None
                )
def print_preview_metrics(self, metrics: dict[str, float | Tensor], methods: list[str] | None = None, overlap_tag: str | None = None) -> None:
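        """Accumulate running averages of the metrics (overall and per overlap
        tag) and print them as a table."""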
if getattr(self, "running_metrics", None) is None:
self.running_metrics = metrics
self.running_metric_steps = 1
else:
s = self.running_metric_steps
self.running_metrics = {
k: ((s * v) + metrics[k]) / (s + 1)
for k, v in self.running_metrics.items()
}
self.running_metric_steps += 1
if overlap_tag is not None:
if getattr(self, "running_metrics_sub", None) is None:
self.running_metrics_sub = {overlap_tag: metrics}
self.running_metric_steps_sub = {overlap_tag: 1}
elif overlap_tag not in self.running_metrics_sub:
self.running_metrics_sub[overlap_tag] = metrics
self.running_metric_steps_sub[overlap_tag] = 1
else:
s = self.running_metric_steps_sub[overlap_tag]
self.running_metrics_sub[overlap_tag] = {k: ((s * v) + metrics[k]) / (s + 1)
for k, v in self.running_metrics_sub[overlap_tag].items()}
self.running_metric_steps_sub[overlap_tag] += 1
metric_list = ["psnr", "lpips", "ssim"]
        def print_metrics(running_metric, methods=None):
table = []
if methods is None:
methods = ['ours']
for method in methods:
row = [
f"{runing_metric[f'{metric}_{method}']:.3f}"
for metric in metric_list
]
table.append((method, *row))
headers = ["Method"] + metric_list
table = tabulate(table, headers)
print(table)
print("All Pairs:")
print_metrics(self.running_metrics, methods)
if overlap_tag is not None:
for k, v in self.running_metrics_sub.items():
print(f"Overlap: {k}")
print_metrics(v, methods)
def configure_optimizers(self):
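        """Build AdamW with two parameter groups: new parameters (names containing
        "gaussian_param_head" or "interm") at the base LR and the pretrained
        backbone at lr * backbone_lr_multiplier, with a per-step schedule of
        linear warmup followed by cosine annealing."""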
new_params, new_param_names = [], []
pretrained_params, pretrained_param_names = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "gaussian_param_head" in name or "interm" in name:
new_params.append(param)
new_param_names.append(name)
else:
pretrained_params.append(param)
pretrained_param_names.append(name)
param_dicts = [
{
"params": new_params,
"lr": self.optimizer_cfg.lr,
},
{
"params": pretrained_params,
"lr": self.optimizer_cfg.lr * self.optimizer_cfg.backbone_lr_multiplier,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=self.optimizer_cfg.lr, weight_decay=0.05, betas=(0.9, 0.95))
warm_up_steps = self.optimizer_cfg.warm_up_steps
warm_up = torch.optim.lr_scheduler.LinearLR(
optimizer,
1 / warm_up_steps,
1,
total_iters=warm_up_steps,
)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=get_cfg()["trainer"]["max_steps"], eta_min=self.optimizer_cfg.lr * 0.1)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warm_up, lr_scheduler], milestones=[warm_up_steps])
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"frequency": 1,
},
}