|
import json |
|
import math |
|
import os |
|
import time |
|
from collections import defaultdict |
|
from dataclasses import dataclass, field |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import imageio |
|
import matplotlib |
|
import torchvision |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import tqdm |
|
import tyro |
|
import viser |
|
import yaml |
|
import sys |
|
from plyfile import PlyData, PlyElement |
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) |
|
|
|
from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
|
from src.model.types import Gaussians |
|
from src.post_opt.datasets.colmap import Dataset, Parser |
|
from src.post_opt.datasets.traj import ( |
|
generate_ellipse_path_z, |
|
generate_interpolated_path, |
|
generate_spiral_path, |
|
) |
|
from fused_ssim import fused_ssim |
|
|
|
from src.utils.image import process_image |
|
from src.post_opt.exporter import export_splats |
|
from src.post_opt.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss |
|
from torch import Tensor |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.utils.tensorboard import SummaryWriter |
|
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
|
from typing_extensions import Literal, assert_never |
|
from src.post_opt.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed |
|
|
|
|
|
from gsplat.compression import PngCompression |
|
from gsplat.distributed import cli
from gsplat.optimizers import SelectiveAdam
|
|
|
|
|
from gsplat import rasterization |
|
from gsplat.strategy import DefaultStrategy, MCMCStrategy |
|
from src.post_opt.gsplat_viewer import GsplatViewer, GsplatRenderTabState |
|
from nerfview import CameraState, RenderTabState, apply_float_colormap |
|
|
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from scipy.spatial.transform import Rotation as R |
|
|
|
from src.model.model.anysplat import AnySplat |
|
|
|
|
|
|
|
def quaternion_to_matrix( |
|
quaternions: Float[Tensor, "*batch 4"], |
|
eps: float = 1e-8, |
|
) -> Float[Tensor, "*batch 3 3"]: |
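    """Convert quaternions in (x, y, z, w) order (real part last) to rotation matrices."""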
|
|
|
i, j, k, r = torch.unbind(quaternions, dim=-1) |
|
two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) |
|
|
|
o = torch.stack( |
|
( |
|
1 - two_s * (j * j + k * k), |
|
two_s * (i * j - k * r), |
|
two_s * (i * k + j * r), |
|
two_s * (i * j + k * r), |
|
1 - two_s * (i * i + k * k), |
|
two_s * (j * k - i * r), |
|
two_s * (i * k - j * r), |
|
two_s * (j * k + i * r), |
|
1 - two_s * (i * i + j * j), |
|
), |
|
-1, |
|
) |
|
return rearrange(o, "... (i j) -> ... i j", i=3, j=3) |
|
|
|
def construct_list_of_attributes(num_rest: int) -> list[str]: |
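    """Return the ordered PLY attribute names for a 3DGS-style vertex."""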
|
attributes = ["x", "y", "z", "nx", "ny", "nz"] |
|
for i in range(3): |
|
attributes.append(f"f_dc_{i}") |
|
for i in range(num_rest): |
|
attributes.append(f"f_rest_{i}") |
|
attributes.append("opacity") |
|
for i in range(3): |
|
attributes.append(f"scale_{i}") |
|
for i in range(4): |
|
attributes.append(f"rot_{i}") |
|
return attributes |
|
|
|
def export_ply( |
|
means: Float[Tensor, "gaussian 3"], |
|
scales: Float[Tensor, "gaussian 3"], |
|
rotations: Float[Tensor, "gaussian 4"], |
|
harmonics: Float[Tensor, "gaussian 3 d_sh"], |
|
opacities: Float[Tensor, " gaussian"], |
|
path: Path, |
|
shift_and_scale: bool = False, |
|
save_sh_dc_only: bool = True, |
|
): |
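    """Write Gaussians to a 3DGS-style PLY file."""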
|
if shift_and_scale: |
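        # Recenter at the median and rescale by the 95th-percentile extent.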
|
|
|
means = means - means.median(dim=0).values |
|
|
|
|
|
scale_factor = means.abs().quantile(0.95, dim=0).max() |
|
means = means / scale_factor |
|
scales = scales / scale_factor |
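    # Normalize the quaternions via a scipy round-trip, then reorder from
    # xyzw to the wxyz convention used by 3DGS PLY files.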
|
|
|
|
|
rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() |
|
rotations = R.from_matrix(rotations).as_quat() |
|
x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") |
|
rotations = np.stack((w, x, y, z), axis=-1) |
|
|
|
|
|
|
|
f_dc = harmonics[..., 0] |
|
f_rest = harmonics[..., 1:].flatten(start_dim=1) |
|
|
|
    num_sh_rest = 0 if save_sh_dc_only else f_rest.shape[1]
    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(num_sh_rest)]
|
elements = np.empty(means.shape[0], dtype=dtype_full) |
|
attributes = [ |
|
means.detach().cpu().numpy(), |
|
torch.zeros_like(means).detach().cpu().numpy(), |
|
f_dc.detach().cpu().contiguous().numpy(), |
|
f_rest.detach().cpu().contiguous().numpy(), |
|
opacities[..., None].detach().cpu().numpy(), |
|
scales.detach().cpu().numpy(), |
|
rotations, |
|
] |
|
if save_sh_dc_only: |
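        # Drop the f_rest block; only the DC SH coefficients are exported.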
|
|
|
attributes.pop(3) |
|
|
|
attributes = np.concatenate(attributes, axis=1) |
|
elements[:] = list(map(tuple, attributes)) |
|
path.parent.mkdir(exist_ok=True, parents=True) |
|
PlyData([PlyElement.describe(elements, "vertex")]).write(path) |
|
|
|
def colorize_depth_maps(depth_map, min_depth=0.0, max_depth=1.0, cmap="Spectral", valid_mask=None): |
|
""" |
|
Colorize depth maps. |
|
""" |
|
assert len(depth_map.shape) >= 2, "Invalid dimension" |
|
|
|
if isinstance(depth_map, torch.Tensor): |
|
        depth = depth_map.detach().clone().squeeze().cpu().numpy()
|
elif isinstance(depth_map, np.ndarray): |
|
depth = depth_map.copy().squeeze() |
|
|
|
if depth.ndim < 3: |
|
depth = depth[np.newaxis, :, :] |
|
|
|
|
|
cm = matplotlib.colormaps[cmap] |
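    # Min-max normalize per batch; note that the min_depth/max_depth arguments
    # are not used by this normalization.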
|
|
|
depth = ((depth - depth.min()) / (depth.max() - depth.min())).clip(0, 1) |
|
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] |
|
img_colored_np = np.rollaxis(img_colored_np, 3, 1) |
|
|
|
if valid_mask is not None: |
|
if isinstance(depth_map, torch.Tensor): |
|
            valid_mask = valid_mask.detach().cpu().numpy()
|
valid_mask = valid_mask.squeeze() |
|
if valid_mask.ndim < 3: |
|
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] |
|
else: |
|
valid_mask = valid_mask[:, np.newaxis, :, :] |
|
valid_mask = np.repeat(valid_mask, 3, axis=1) |
|
img_colored_np[~valid_mask] = 0 |
|
|
|
if isinstance(depth_map, torch.Tensor): |
|
img_colored = torch.from_numpy(img_colored_np).float() |
|
elif isinstance(depth_map, np.ndarray): |
|
img_colored = img_colored_np |
|
|
|
return img_colored |
|
|
|
def build_covariance( |
|
scale: Float[Tensor, "*#batch 3"], |
|
rotation_xyzw: Float[Tensor, "*#batch 4"], |
|
) -> Float[Tensor, "*batch 3 3"]: |
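    """Build covariances Sigma = R S S^T R^T from scales and xyzw quaternions."""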
|
scale = scale.diag_embed() |
|
rotation = quaternion_to_matrix(rotation_xyzw) |
|
return ( |
|
rotation |
|
@ scale |
|
@ rearrange(scale, "... i j -> ... j i") |
|
@ rearrange(rotation, "... i j -> ... j i") |
|
) |
|
|
|
|
|
@dataclass |
|
class Config: |
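    """Configuration for post-optimizing feed-forward AnySplat Gaussians with gsplat."""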
|
|
|
disable_viewer: bool = True |
|
|
|
ckpt: Optional[List[str]] = None |
|
|
|
compression: Optional[Literal["png"]] = None |
|
|
|
render_traj_path: str = "interp" |
|
|
|
data_dir: str = "data/360_v2/garden" |
|
|
|
data_factor: int = 4 |
|
|
|
result_dir: str = "results/garden" |
|
|
|
test_every: int = 8 |
|
|
|
patch_size: Optional[int] = None |
|
|
|
global_scale: float = 1.0 |
|
|
|
normalize_world_space: bool = True |
|
|
|
camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" |
|
|
|
|
|
port: int = 8080 |
|
|
|
|
|
batch_size: int = 1 |
|
|
|
steps_scaler: float = 1.0 |
|
|
|
|
|
max_steps: int = 3_000 |
|
|
|
eval_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000]) |
|
|
|
save_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000]) |
|
|
|
save_ply: bool = False |
|
|
|
ply_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000]) |
|
|
|
disable_video: bool = False |
|
|
|
|
|
init_type: str = "sfm" |
|
|
|
init_num_pts: int = 100_000 |
|
|
|
init_extent: float = 3.0 |
|
|
|
sh_degree: int = 4 |
|
|
|
sh_degree_interval: int = 1000 |
|
|
|
init_opa: float = 0.1 |
|
|
|
init_scale: float = 1.0 |
|
|
|
ssim_lambda: float = 0.2 |
|
|
|
|
|
near_plane: float = 1e-10 |
|
|
|
far_plane: float = 1e10 |
|
|
|
|
|
strategy: Union[DefaultStrategy, MCMCStrategy] = field( |
|
default_factory=DefaultStrategy |
|
) |
|
|
|
packed: bool = False |
|
|
|
sparse_grad: bool = False |
|
|
|
visible_adam: bool = False |
|
|
|
antialiased: bool = False |
|
|
|
|
|
random_bkgd: bool = False |
|
|
|
|
|
opacity_reg: float = 0.0 |
|
|
|
scale_reg: float = 0.0 |
|
|
|
|
|
pose_opt: bool = True |
|
|
|
pose_opt_lr: float = 1e-5 |
|
|
|
pose_opt_reg: float = 1e-6 |
|
|
|
pose_noise: float = 0.0 |
|
|
|
|
|
app_opt: bool = False |
|
|
|
app_embed_dim: int = 16 |
|
|
|
app_opt_lr: float = 1e-3 |
|
|
|
app_opt_reg: float = 1e-6 |
|
|
|
|
|
use_bilateral_grid: bool = False |
|
|
|
bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) |
|
|
|
|
|
depth_loss: bool = False |
|
|
|
depth_lambda: float = 1e-2 |
|
|
|
|
|
tb_every: int = 100 |
|
|
|
tb_save_image: bool = False |
|
|
|
lpips_net: Literal["vgg", "alex"] = "vgg" |
|
|
|
lr_means: float = 1.6e-4 |
|
lr_scales: float = 5e-3 |
|
lr_quats: float = 1e-3 |
|
lr_opacities: float = 5e-2 |
|
lr_sh: float = 2.5e-3 |
|
|
|
def adjust_steps(self, factor: float): |
|
self.eval_steps = [int(i * factor) for i in self.eval_steps] |
|
self.save_steps = [int(i * factor) for i in self.save_steps] |
|
self.ply_steps = [int(i * factor) for i in self.ply_steps] |
|
self.max_steps = int(self.max_steps * factor) |
|
self.sh_degree_interval = int(self.sh_degree_interval * factor) |
|
|
|
strategy = self.strategy |
|
if isinstance(strategy, DefaultStrategy): |
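            # Densification is effectively disabled for post-optimization:
            # refinement never starts and opacities are never reset.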
|
|
|
|
|
|
|
|
|
|
|
strategy.refine_start_iter = 30000 |
|
strategy.refine_stop_iter = 0 |
|
strategy.reset_every = 30000 |
|
strategy.refine_every = 30000 |
|
|
|
elif isinstance(strategy, MCMCStrategy): |
|
strategy.refine_start_iter = int(strategy.refine_start_iter * factor) |
|
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) |
|
strategy.refine_every = int(strategy.refine_every * factor) |
|
else: |
|
assert_never(strategy) |
|
|
|
|
|
def create_splats_with_optimizers( |
|
gaussians: Gaussians, |
|
init_num_pts: int = 100_000, |
|
init_extent: float = 3.0, |
|
init_opacity: float = 0.1, |
|
init_scale: float = 1.0, |
|
sh_degree: int = 3, |
|
sparse_grad: bool = False, |
|
visible_adam: bool = False, |
|
batch_size: int = 1, |
|
feature_dim: Optional[int] = None, |
|
device: str = "cuda", |
|
world_rank: int = 0, |
|
world_size: int = 1, |
|
cfg: Config = None, |
|
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: |
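    """Create splat parameters and per-parameter optimizers.

    The splats are seeded directly from the feed-forward Gaussians rather than
    from random or SfM points; several init_* arguments are kept only for API
    compatibility and are unused here.
    """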
|
|
|
points = gaussians.means[0].detach().float() |
|
scales = torch.log(gaussians.scales[0].detach().float()) |
|
quats = gaussians.rotations[0].detach().float() |
|
opacities = torch.logit(gaussians.opacities[0].detach().float()) |
|
harmonics = gaussians.harmonics[0].detach().float().permute(0, 2, 1).contiguous() |
|
|
|
N = points.shape[0] |
|
|
|
scene_scale = 1.0 |
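    # Prune Gaussians whose predicted opacity is below 1%.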
|
masks = opacities.sigmoid() > 0.01 |
|
harmonics = harmonics[masks] |
|
params = [ |
|
|
|
("means", torch.nn.Parameter(points[masks]), cfg.lr_means * scene_scale), |
|
("scales", torch.nn.Parameter(scales[masks]), cfg.lr_scales), |
|
("quats", torch.nn.Parameter(quats[masks]), cfg.lr_quats), |
|
("opacities", torch.nn.Parameter(opacities[masks]), cfg.lr_opacities), |
|
] |
|
|
|
params.append(("sh0", torch.nn.Parameter(harmonics[:, :1, :]), cfg.lr_sh)) |
|
params.append(("shN", torch.nn.Parameter(harmonics[:, 1:, :]), cfg.lr_sh/20)) |
|
|
|
splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) |
|
|
|
|
|
|
|
|
|
BS = batch_size * world_size |
|
optimizer_class = None |
|
if sparse_grad: |
|
optimizer_class = torch.optim.SparseAdam |
|
elif visible_adam: |
|
optimizer_class = SelectiveAdam |
|
else: |
|
optimizer_class = torch.optim.Adam |
|
optimizers = { |
|
name: optimizer_class( |
|
[{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], |
|
eps=1e-15 / math.sqrt(BS), |
|
|
|
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), |
|
) |
|
for name, _, lr in params |
|
} |
|
return splats, optimizers |
|
|
|
|
|
class Runner: |
|
"""Engine for training and testing.""" |
|
|
|
def __init__( |
|
        self, local_rank: int, world_rank: int, world_size: int, cfg: Config
|
) -> None: |
|
set_random_seed(42 + local_rank) |
|
|
|
self.cfg = cfg |
|
self.world_rank = world_rank |
|
self.local_rank = local_rank |
|
self.world_size = world_size |
|
self.device = f"cuda:{local_rank}" |
|
|
|
|
|
os.makedirs(cfg.result_dir, exist_ok=True) |
|
|
|
|
|
self.ckpt_dir = f"{cfg.result_dir}/ckpts" |
|
os.makedirs(self.ckpt_dir, exist_ok=True) |
|
self.stats_dir = f"{cfg.result_dir}/stats" |
|
os.makedirs(self.stats_dir, exist_ok=True) |
|
self.render_dir = f"{cfg.result_dir}/renders" |
|
os.makedirs(self.render_dir, exist_ok=True) |
|
self.ply_dir = f"{cfg.result_dir}/ply" |
|
os.makedirs(self.ply_dir, exist_ok=True) |
|
|
|
|
|
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") |
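        # Run the pretrained feed-forward AnySplat model once (frozen) to get
        # initial Gaussians and camera poses for the scene.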
|
|
|
|
|
model = AnySplat.from_pretrained("lhjiang/anysplat") |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model.to(device) |
|
model.eval() |
|
for param in model.parameters(): |
|
param.requires_grad = False |
|
|
|
image_folder = cfg.data_dir |
|
image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]) |
|
images = [process_image(img_path) for img_path in image_names] |
|
ctx_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every != 0] |
|
tgt_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every == 0] |
|
|
|
ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device) |
|
tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device) |
|
        # Map images from [-1, 1] back to [0, 1].
        ctx_images = (ctx_images + 1) * 0.5
        tgt_images = (tgt_images + 1) * 0.5
|
b, v, _, h, w = tgt_images.shape |
|
|
|
|
|
encoder_output = model.encoder( |
|
ctx_images, |
|
global_step=0, |
|
visualization_dump={}, |
|
) |
|
gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose |
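        # Estimate extrinsics and intrinsics for all views (context + target)
        # with the VGGT aggregator and camera head.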
|
|
|
num_context_view = ctx_images.shape[1] |
|
vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16) |
|
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16): |
|
aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx) |
|
with torch.cuda.amp.autocast(enabled=False): |
|
fp32_tokens = [token.float() for token in aggregated_tokens_list] |
|
pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1] |
|
pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:]) |
|
|
|
extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1) |
|
pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse() |
|
|
|
pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w |
|
pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h |
|
pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:] |
|
pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:] |
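        # Align the scale of the VGGT poses with the encoder's predicted
        # context poses by matching mean translation magnitudes.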
|
|
|
scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean() |
|
pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor |
|
pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor |
|
print("scale_factor:", scale_factor) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.trainset = Dataset( |
|
split="train", |
|
images=ctx_images[0].detach().cpu().numpy(), |
|
camtoworlds=pred_all_context_extrinsic[0].detach().cpu().numpy(), |
|
Ks=pred_all_context_intrinsic[0].detach().cpu().numpy(), |
|
patch_size=cfg.patch_size, |
|
load_depths=cfg.depth_loss, |
|
) |
|
self.valset = Dataset( |
|
images=tgt_images[0].detach().cpu().numpy(), |
|
camtoworlds=pred_all_target_extrinsic[0].detach().cpu().numpy(), |
|
Ks=pred_all_target_intrinsic[0].detach().cpu().numpy(), |
|
split="val" |
|
) |
|
|
|
|
|
feature_dim = 32 if cfg.app_opt else None |
|
self.splats, self.optimizers = create_splats_with_optimizers( |
|
gaussians=gaussians, |
|
init_num_pts=cfg.init_num_pts, |
|
init_extent=cfg.init_extent, |
|
init_opacity=cfg.init_opa, |
|
init_scale=cfg.init_scale, |
|
sh_degree=cfg.sh_degree, |
|
sparse_grad=cfg.sparse_grad, |
|
visible_adam=cfg.visible_adam, |
|
batch_size=cfg.batch_size, |
|
feature_dim=feature_dim, |
|
device=self.device, |
|
world_rank=world_rank, |
|
world_size=world_size, |
|
cfg=cfg, |
|
) |
|
print("Model initialized. Number of GS:", len(self.splats["means"])) |
|
|
|
|
|
self.cfg.strategy.check_sanity(self.splats, self.optimizers) |
|
|
|
if isinstance(self.cfg.strategy, DefaultStrategy): |
|
self.strategy_state = self.cfg.strategy.initialize_state( |
|
scene_scale=1.0 |
|
) |
|
elif isinstance(self.cfg.strategy, MCMCStrategy): |
|
self.strategy_state = self.cfg.strategy.initialize_state() |
|
else: |
|
assert_never(self.cfg.strategy) |
|
|
|
|
|
self.compression_method = None |
|
if cfg.compression is not None: |
|
if cfg.compression == "png": |
|
self.compression_method = PngCompression() |
|
else: |
|
raise ValueError(f"Unknown compression strategy: {cfg.compression}") |
|
|
|
self.pose_optimizers = [] |
|
if cfg.pose_opt: |
|
self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) |
|
self.pose_adjust.zero_init() |
|
self.pose_optimizers = [ |
|
torch.optim.Adam( |
|
self.pose_adjust.parameters(), |
|
lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), |
|
weight_decay=cfg.pose_opt_reg, |
|
) |
|
] |
|
if world_size > 1: |
|
self.pose_adjust = DDP(self.pose_adjust) |
|
|
|
if cfg.pose_noise > 0.0: |
|
self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) |
|
self.pose_perturb.random_init(cfg.pose_noise) |
|
if world_size > 1: |
|
self.pose_perturb = DDP(self.pose_perturb) |
|
|
|
self.app_optimizers = [] |
|
if cfg.app_opt: |
|
assert feature_dim is not None |
|
self.app_module = AppearanceOptModule( |
|
len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree |
|
).to(self.device) |
|
|
|
torch.nn.init.zeros_(self.app_module.color_head[-1].weight) |
|
torch.nn.init.zeros_(self.app_module.color_head[-1].bias) |
|
self.app_optimizers = [ |
|
torch.optim.Adam( |
|
self.app_module.embeds.parameters(), |
|
lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, |
|
weight_decay=cfg.app_opt_reg, |
|
), |
|
torch.optim.Adam( |
|
self.app_module.color_head.parameters(), |
|
lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), |
|
), |
|
] |
|
if world_size > 1: |
|
self.app_module = DDP(self.app_module) |
|
|
|
self.bil_grid_optimizers = [] |
|
if cfg.use_bilateral_grid: |
|
self.bil_grids = BilateralGrid( |
|
len(self.trainset), |
|
grid_X=cfg.bilateral_grid_shape[0], |
|
grid_Y=cfg.bilateral_grid_shape[1], |
|
grid_W=cfg.bilateral_grid_shape[2], |
|
).to(self.device) |
|
self.bil_grid_optimizers = [ |
|
torch.optim.Adam( |
|
self.bil_grids.parameters(), |
|
lr=2e-3 * math.sqrt(cfg.batch_size), |
|
eps=1e-15, |
|
), |
|
] |
|
|
|
|
|
self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) |
|
self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) |
|
|
|
if cfg.lpips_net == "alex": |
|
self.lpips = LearnedPerceptualImagePatchSimilarity( |
|
net_type="alex", normalize=True |
|
).to(self.device) |
|
elif cfg.lpips_net == "vgg": |
|
|
|
self.lpips = LearnedPerceptualImagePatchSimilarity( |
|
net_type="vgg", normalize=False |
|
).to(self.device) |
|
else: |
|
raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") |
|
|
|
|
|
if not self.cfg.disable_viewer: |
|
self.server = viser.ViserServer(port=cfg.port, verbose=False) |
|
self.viewer = GsplatViewer( |
|
server=self.server, |
|
render_fn=self._viewer_render_fn, |
|
output_dir=Path(cfg.result_dir), |
|
mode="training", |
|
) |
|
|
|
def rasterize_splats( |
|
self, |
|
camtoworlds: Tensor, |
|
Ks: Tensor, |
|
width: int, |
|
height: int, |
|
masks: Optional[Tensor] = None, |
|
rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, |
|
camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None, |
|
**kwargs, |
|
) -> Tuple[Tensor, Tensor, Dict]: |
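        """Rasterize the current splats into color (and optionally depth) images.

        Returns (render_colors, render_alphas, info) from gsplat's rasterization.
        """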
|
means = self.splats["means"] |
|
|
|
|
|
quats = self.splats["quats"] |
|
scales = torch.exp(self.splats["scales"]) |
|
opacities = torch.sigmoid(self.splats["opacities"]) |
|
|
|
image_ids = kwargs.pop("image_ids", None) |
|
if self.cfg.app_opt: |
|
colors = self.app_module( |
|
features=self.splats["features"], |
|
embed_ids=image_ids, |
|
dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], |
|
sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), |
|
) |
|
colors = colors + self.splats["colors"] |
|
colors = torch.sigmoid(colors) |
|
else: |
|
colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) |
|
|
|
if rasterize_mode is None: |
|
rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" |
|
if camera_model is None: |
|
camera_model = self.cfg.camera_model |
|
|
|
|
|
render_colors, render_alphas, info = rasterization( |
|
means=means, |
|
quats=quats, |
|
scales=scales, |
|
opacities=opacities, |
|
colors=colors, |
|
|
|
viewmats=torch.linalg.inv(camtoworlds), |
|
Ks=Ks, |
|
width=width, |
|
height=height, |
|
packed=self.cfg.packed, |
|
absgrad=( |
|
self.cfg.strategy.absgrad |
|
if isinstance(self.cfg.strategy, DefaultStrategy) |
|
else False |
|
), |
|
sparse_grad=self.cfg.sparse_grad, |
|
rasterize_mode=rasterize_mode, |
|
distributed=self.world_size > 1, |
|
            camera_model=camera_model,
|
            radius_clip=kwargs.pop("radius_clip", 0.1),
            backgrounds=kwargs.pop(
                "backgrounds", torch.zeros(1, 3, device=self.device)
            ),
|
**kwargs, |
|
) |
|
if masks is not None: |
|
render_colors[~masks] = 0 |
|
return render_colors, render_alphas, info |
|
|
|
def train(self): |
|
cfg = self.cfg |
|
device = self.device |
|
world_rank = self.world_rank |
|
world_size = self.world_size |
|
|
|
|
|
if world_rank == 0: |
|
with open(f"{cfg.result_dir}/cfg.yml", "w") as f: |
|
yaml.dump(vars(cfg), f) |
|
|
|
max_steps = cfg.max_steps |
|
init_step = 0 |
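        # Decay the means learning rate exponentially to 1% of its initial
        # value over the course of training.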
|
|
|
schedulers = [ |
|
|
|
torch.optim.lr_scheduler.ExponentialLR( |
|
self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) |
|
), |
|
] |
|
if cfg.pose_opt: |
|
|
|
schedulers.append( |
|
torch.optim.lr_scheduler.ExponentialLR( |
|
self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) |
|
) |
|
) |
|
if cfg.use_bilateral_grid: |
|
|
|
schedulers.append( |
|
torch.optim.lr_scheduler.ChainedScheduler( |
|
[ |
|
torch.optim.lr_scheduler.LinearLR( |
|
self.bil_grid_optimizers[0], |
|
start_factor=0.01, |
|
total_iters=1000, |
|
), |
|
torch.optim.lr_scheduler.ExponentialLR( |
|
self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) |
|
), |
|
] |
|
) |
|
) |
|
|
|
trainloader = torch.utils.data.DataLoader( |
|
self.trainset, |
|
batch_size=cfg.batch_size, |
|
shuffle=True, |
|
num_workers=4, |
|
persistent_workers=True, |
|
pin_memory=True, |
|
) |
|
trainloader_iter = iter(trainloader) |
|
|
|
|
|
global_tic = time.time() |
|
pbar = tqdm.tqdm(range(init_step, max_steps)) |
|
for step in pbar: |
|
if not cfg.disable_viewer: |
|
while self.viewer.state == "paused": |
|
time.sleep(0.01) |
|
self.viewer.lock.acquire() |
|
tic = time.time() |
|
|
|
try: |
|
data = next(trainloader_iter) |
|
except StopIteration: |
|
trainloader_iter = iter(trainloader) |
|
data = next(trainloader_iter) |
|
|
|
camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) |
|
Ks = data["K"].to(device) |
|
pixels = data["image"].to(device) / 255.0 |
|
num_train_rays_per_step = ( |
|
pixels.shape[0] * pixels.shape[1] * pixels.shape[2] |
|
) |
|
image_ids = data["image_id"].to(device) |
|
masks = data["mask"].to(device) if "mask" in data else None |
|
if cfg.depth_loss: |
|
points = data["points"].to(device) |
|
depths_gt = data["depths"].to(device) |
|
|
|
height, width = pixels.shape[1:3] |
|
|
|
if cfg.pose_noise: |
|
camtoworlds = self.pose_perturb(camtoworlds, image_ids) |
|
|
|
if cfg.pose_opt: |
|
camtoworlds = self.pose_adjust(camtoworlds, image_ids) |
|
|
|
|
|
|
|
            # Use the full SH degree from the start: the splats are initialized
            # from a pretrained feed-forward model rather than from scratch.
            sh_degree_to_use = cfg.sh_degree
|
|
|
|
|
renders, alphas, info = self.rasterize_splats( |
|
camtoworlds=camtoworlds, |
|
Ks=Ks, |
|
width=width, |
|
height=height, |
|
sh_degree=sh_degree_to_use, |
|
near_plane=cfg.near_plane, |
|
far_plane=cfg.far_plane, |
|
image_ids=image_ids, |
|
render_mode="RGB+ED" if cfg.depth_loss else "RGB", |
|
masks=masks, |
|
) |
|
if renders.shape[-1] == 4: |
|
colors, depths = renders[..., 0:3], renders[..., 3:4] |
|
else: |
|
colors, depths = renders, None |
|
|
|
if cfg.use_bilateral_grid: |
|
grid_y, grid_x = torch.meshgrid( |
|
(torch.arange(height, device=self.device) + 0.5) / height, |
|
(torch.arange(width, device=self.device) + 0.5) / width, |
|
indexing="ij", |
|
) |
|
grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) |
|
colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] |
|
|
|
if cfg.random_bkgd: |
|
bkgd = torch.rand(1, 3, device=device) |
|
colors = colors + bkgd * (1.0 - alphas) |
|
|
|
self.cfg.strategy.step_pre_backward( |
|
params=self.splats, |
|
optimizers=self.optimizers, |
|
state=self.strategy_state, |
|
step=step, |
|
info=info, |
|
) |
|
|
|
|
|
l1loss = F.l1_loss(colors, pixels) |
|
ssimloss = 1.0 - fused_ssim( |
|
colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" |
|
) |
|
loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda |
|
if cfg.depth_loss: |
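                # Sample the rendered depth at the projections of the sparse
                # points and supervise in disparity space.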
|
|
|
points = torch.stack( |
|
[ |
|
points[:, :, 0] / (width - 1) * 2 - 1, |
|
points[:, :, 1] / (height - 1) * 2 - 1, |
|
], |
|
dim=-1, |
|
) |
|
grid = points.unsqueeze(2) |
|
depths = F.grid_sample( |
|
depths.permute(0, 3, 1, 2), grid, align_corners=True |
|
) |
|
depths = depths.squeeze(3).squeeze(1) |
|
|
|
disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) |
|
disp_gt = 1.0 / depths_gt |
|
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale |
|
loss += depthloss * cfg.depth_lambda |
|
if cfg.use_bilateral_grid: |
|
tvloss = 10 * total_variation_loss(self.bil_grids.grids) |
|
loss += tvloss |
|
|
|
|
|
if cfg.opacity_reg > 0.0: |
|
loss = ( |
|
loss |
|
+ cfg.opacity_reg |
|
* torch.abs(torch.sigmoid(self.splats["opacities"])).mean() |
|
) |
|
if cfg.scale_reg > 0.0: |
|
loss = ( |
|
loss |
|
+ cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() |
|
) |
|
|
|
loss.backward() |
|
|
|
desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " |
|
if cfg.depth_loss: |
|
desc += f"depth loss={depthloss.item():.6f}| " |
|
if cfg.pose_opt and cfg.pose_noise: |
|
|
|
pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) |
|
desc += f"pose err={pose_err.item():.6f}| " |
|
pbar.set_description(desc) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: |
|
mem = torch.cuda.max_memory_allocated() / 1024**3 |
|
self.writer.add_scalar("train/loss", loss.item(), step) |
|
self.writer.add_scalar("train/l1loss", l1loss.item(), step) |
|
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) |
|
self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) |
|
self.writer.add_scalar("train/mem", mem, step) |
|
if cfg.depth_loss: |
|
self.writer.add_scalar("train/depthloss", depthloss.item(), step) |
|
if cfg.use_bilateral_grid: |
|
self.writer.add_scalar("train/tvloss", tvloss.item(), step) |
|
if cfg.tb_save_image: |
|
canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() |
|
canvas = canvas.reshape(-1, *canvas.shape[2:]) |
|
self.writer.add_image("train/render", canvas, step) |
|
self.writer.flush() |
|
|
|
|
|
if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: |
|
mem = torch.cuda.max_memory_allocated() / 1024**3 |
|
stats = { |
|
"mem": mem, |
|
"ellipse_time": time.time() - global_tic, |
|
"num_GS": len(self.splats["means"]), |
|
} |
|
print("Step: ", step, stats) |
|
with open( |
|
f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", |
|
"w", |
|
) as f: |
|
json.dump(stats, f) |
|
data = {"step": step, "splats": self.splats.state_dict()} |
|
if cfg.pose_opt: |
|
if world_size > 1: |
|
data["pose_adjust"] = self.pose_adjust.module.state_dict() |
|
else: |
|
data["pose_adjust"] = self.pose_adjust.state_dict() |
|
if cfg.app_opt: |
|
if world_size > 1: |
|
data["app_module"] = self.app_module.module.state_dict() |
|
else: |
|
data["app_module"] = self.app_module.state_dict() |
|
torch.save( |
|
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" |
|
) |
|
if ( |
|
step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1 |
|
) and cfg.save_ply: |
|
|
|
if self.cfg.app_opt: |
|
|
|
rgb = self.app_module( |
|
features=self.splats["features"], |
|
embed_ids=None, |
|
dirs=torch.zeros_like(self.splats["means"][None, :, :]), |
|
sh_degree=sh_degree_to_use, |
|
) |
|
rgb = rgb + self.splats["colors"] |
|
rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1) |
|
sh0 = rgb_to_sh(rgb) |
|
shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device) |
|
else: |
|
sh0 = self.splats["sh0"] |
|
shN = self.splats["shN"] |
|
|
|
|
|
means = self.splats["means"] |
|
scales = self.splats["scales"] |
|
quats = self.splats["quats"] |
|
opacities = self.splats["opacities"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_ply( |
|
means=means, |
|
scales=scales, |
|
rotations=quats, |
|
harmonics=torch.cat([sh0, shN], dim=1).permute(0, 2, 1), |
|
opacities=opacities.sigmoid(), |
|
path=Path(f"{self.ply_dir}/point_cloud_{step}.ply"), |
|
) |
|
|
|
|
|
if cfg.sparse_grad: |
|
assert cfg.packed, "Sparse gradients only work with packed mode." |
|
gaussian_ids = info["gaussian_ids"] |
|
for k in self.splats.keys(): |
|
grad = self.splats[k].grad |
|
if grad is None or grad.is_sparse: |
|
continue |
|
self.splats[k].grad = torch.sparse_coo_tensor( |
|
indices=gaussian_ids[None], |
|
values=grad[gaussian_ids], |
|
size=self.splats[k].size(), |
|
is_coalesced=len(Ks) == 1, |
|
) |
|
|
|
if cfg.visible_adam: |
|
gaussian_cnt = self.splats.means.shape[0] |
|
if cfg.packed: |
|
visibility_mask = torch.zeros_like( |
|
self.splats["opacities"], dtype=bool |
|
) |
|
visibility_mask.scatter_(0, info["gaussian_ids"], 1) |
|
else: |
|
visibility_mask = (info["radii"] > 0).all(-1).any(0) |
|
|
|
|
|
for optimizer in self.optimizers.values(): |
|
if cfg.visible_adam: |
|
optimizer.step(visibility_mask) |
|
else: |
|
optimizer.step() |
|
optimizer.zero_grad(set_to_none=True) |
|
for optimizer in self.pose_optimizers: |
|
optimizer.step() |
|
optimizer.zero_grad(set_to_none=True) |
|
for optimizer in self.app_optimizers: |
|
optimizer.step() |
|
optimizer.zero_grad(set_to_none=True) |
|
for optimizer in self.bil_grid_optimizers: |
|
optimizer.step() |
|
optimizer.zero_grad(set_to_none=True) |
|
for scheduler in schedulers: |
|
scheduler.step() |
|
|
|
|
|
if isinstance(self.cfg.strategy, DefaultStrategy): |
|
self.cfg.strategy.step_post_backward( |
|
params=self.splats, |
|
optimizers=self.optimizers, |
|
state=self.strategy_state, |
|
step=step, |
|
info=info, |
|
packed=cfg.packed, |
|
) |
|
elif isinstance(self.cfg.strategy, MCMCStrategy): |
|
self.cfg.strategy.step_post_backward( |
|
params=self.splats, |
|
optimizers=self.optimizers, |
|
state=self.strategy_state, |
|
step=step, |
|
info=info, |
|
lr=schedulers[0].get_last_lr()[0], |
|
) |
|
else: |
|
assert_never(self.cfg.strategy) |
|
|
|
|
|
if step in [i - 1 for i in cfg.eval_steps]: |
|
self.eval(step) |
|
|
|
|
|
|
|
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: |
|
self.run_compression(step=step) |
|
|
|
if not cfg.disable_viewer: |
|
self.viewer.lock.release() |
|
num_train_steps_per_sec = 1.0 / (time.time() - tic) |
|
num_train_rays_per_sec = ( |
|
num_train_rays_per_step * num_train_steps_per_sec |
|
) |
|
|
|
self.viewer.render_tab_state.num_train_rays_per_sec = ( |
|
num_train_rays_per_sec |
|
) |
|
|
|
self.viewer.update(step, num_train_rays_per_step) |
|
|
|
@torch.no_grad() |
|
def eval(self, step: int, stage: str = "val"): |
|
"""Entry for evaluation.""" |
|
print("Running evaluation...") |
|
cfg = self.cfg |
|
device = self.device |
|
world_rank = self.world_rank |
|
world_size = self.world_size |
|
|
|
valloader = torch.utils.data.DataLoader( |
|
self.valset, batch_size=1, shuffle=False, num_workers=1 |
|
) |
|
ellipse_time = 0 |
|
metrics = defaultdict(list) |
|
for i, data in enumerate(valloader): |
|
camtoworlds = data["camtoworld"].to(device) |
|
Ks = data["K"].to(device) |
|
pixels = data["image"].to(device) / 255.0 |
|
masks = data["mask"].to(device) if "mask" in data else None |
|
height, width = pixels.shape[1:3] |
|
|
|
torch.cuda.synchronize() |
|
tic = time.time() |
|
render_colors, _, _ = self.rasterize_splats( |
|
camtoworlds=camtoworlds, |
|
Ks=Ks, |
|
width=width, |
|
height=height, |
|
sh_degree=cfg.sh_degree, |
|
near_plane=cfg.near_plane, |
|
far_plane=cfg.far_plane, |
|
|
|
render_mode="RGB+ED", |
|
masks=masks, |
|
) |
|
torch.cuda.synchronize() |
|
ellipse_time += time.time() - tic |
|
|
|
colors = render_colors[..., :3] |
|
depths = render_colors[..., 3] |
|
|
|
colors = torch.clamp(colors, 0.0, 1.0) |
|
canvas_list = [pixels, colors] |
|
|
|
if world_rank == 0: |
|
|
|
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() |
|
canvas = (canvas * 255).astype(np.uint8) |
|
imageio.imwrite( |
|
f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", |
|
canvas, |
|
) |
|
torchvision.utils.save_image(pixels.permute(0, 3, 1, 2), f"{self.render_dir}/gt_rgb_{stage}_step{step}_{i:04d}.png") |
|
torchvision.utils.save_image(colors.permute(0, 3, 1, 2), f"{self.render_dir}/render_rgb_{stage}_step{step}_{i:04d}.png") |
|
|
|
|
|
|
|
pixels_p = pixels.permute(0, 3, 1, 2) |
|
colors_p = colors.permute(0, 3, 1, 2) |
|
|
|
metrics["psnr"].append(self.psnr(colors_p, pixels_p)) |
|
metrics["ssim"].append(self.ssim(colors_p, pixels_p)) |
|
metrics["lpips"].append(self.lpips(colors_p, pixels_p)) |
|
if cfg.use_bilateral_grid: |
|
cc_colors = color_correct(colors, pixels) |
|
cc_colors_p = cc_colors.permute(0, 3, 1, 2) |
|
metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) |
|
|
|
if world_rank == 0: |
|
ellipse_time /= len(valloader) |
|
|
|
stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} |
|
stats.update( |
|
{ |
|
"ellipse_time": ellipse_time, |
|
"num_GS": len(self.splats["means"]), |
|
} |
|
) |
|
print( |
|
f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " |
|
f"Time: {stats['ellipse_time']:.3f}s/image " |
|
f"Number of GS: {stats['num_GS']}" |
|
) |
|
|
|
with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: |
|
json.dump(stats, f) |
|
|
|
for k, v in stats.items(): |
|
self.writer.add_scalar(f"{stage}/{k}", v, step) |
|
self.writer.flush() |
|
|
|
@torch.no_grad() |
|
def render_traj(self, step: int): |
|
"""Entry for trajectory rendering.""" |
|
if self.cfg.disable_video: |
|
return |
|
print("Running trajectory rendering...") |
|
cfg = self.cfg |
|
device = self.device |
|
|
|
camtoworlds_all = self.parser.camtoworlds[5:-5] |
|
if cfg.render_traj_path == "interp": |
|
camtoworlds_all = generate_interpolated_path( |
|
camtoworlds_all, 1 |
|
) |
|
elif cfg.render_traj_path == "ellipse": |
|
height = camtoworlds_all[:, 2, 3].mean() |
|
camtoworlds_all = generate_ellipse_path_z( |
|
camtoworlds_all, height=height |
|
) |
|
elif cfg.render_traj_path == "spiral": |
|
camtoworlds_all = generate_spiral_path( |
|
camtoworlds_all, |
|
bounds=self.parser.bounds * self.scene_scale, |
|
spiral_scale_r=self.parser.extconf["spiral_radius_scale"], |
|
) |
|
else: |
|
raise ValueError( |
|
f"Render trajectory type not supported: {cfg.render_traj_path}" |
|
) |
|
|
|
camtoworlds_all = np.concatenate( |
|
[ |
|
camtoworlds_all, |
|
np.repeat( |
|
np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 |
|
), |
|
], |
|
axis=1, |
|
) |
|
|
|
camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) |
|
K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) |
|
width, height = list(self.parser.imsize_dict.values())[0] |
|
|
|
|
|
video_dir = f"{cfg.result_dir}/videos" |
|
os.makedirs(video_dir, exist_ok=True) |
|
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) |
|
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): |
|
camtoworlds = camtoworlds_all[i : i + 1] |
|
Ks = K[None] |
|
|
|
renders, _, _ = self.rasterize_splats( |
|
camtoworlds=camtoworlds, |
|
Ks=Ks, |
|
width=width, |
|
height=height, |
|
sh_degree=cfg.sh_degree, |
|
near_plane=cfg.near_plane, |
|
far_plane=cfg.far_plane, |
|
render_mode="RGB+ED", |
|
) |
|
colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) |
|
depths = renders[..., 3:4] |
|
depths = (depths - depths.min()) / (depths.max() - depths.min()) |
|
canvas_list = [colors, depths.repeat(1, 1, 1, 3)] |
|
|
|
|
|
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() |
|
canvas = (canvas * 255).astype(np.uint8) |
|
writer.append_data(canvas) |
|
writer.close() |
|
print(f"Video saved to {video_dir}/traj_{step}.mp4") |
|
|
|
@torch.no_grad() |
|
def run_compression(self, step: int): |
|
"""Entry for running compression.""" |
|
print("Running compression...") |
|
world_rank = self.world_rank |
|
|
|
compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" |
|
os.makedirs(compress_dir, exist_ok=True) |
|
|
|
self.compression_method.compress(compress_dir, self.splats) |
|
|
|
|
|
splats_c = self.compression_method.decompress(compress_dir) |
|
for k in splats_c.keys(): |
|
self.splats[k].data = splats_c[k].to(self.device) |
|
self.eval(step=step, stage="compress") |
|
|
|
@torch.no_grad() |
|
def _viewer_render_fn( |
|
self, camera_state: CameraState, render_tab_state: RenderTabState |
|
): |
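        """Render callback for the interactive viser viewer."""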
|
assert isinstance(render_tab_state, GsplatRenderTabState) |
|
if render_tab_state.preview_render: |
|
width = render_tab_state.render_width |
|
height = render_tab_state.render_height |
|
else: |
|
width = render_tab_state.viewer_width |
|
height = render_tab_state.viewer_height |
|
c2w = camera_state.c2w |
|
K = camera_state.get_K((width, height)) |
|
c2w = torch.from_numpy(c2w).float().to(self.device) |
|
K = torch.from_numpy(K).float().to(self.device) |
|
|
|
RENDER_MODE_MAP = { |
|
"rgb": "RGB", |
|
"depth(accumulated)": "D", |
|
"depth(expected)": "ED", |
|
"alpha": "RGB", |
|
} |
|
|
|
render_colors, render_alphas, info = self.rasterize_splats( |
|
camtoworlds=c2w[None], |
|
Ks=K[None], |
|
width=width, |
|
height=height, |
|
sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree), |
|
near_plane=render_tab_state.near_plane, |
|
far_plane=render_tab_state.far_plane, |
|
radius_clip=render_tab_state.radius_clip, |
|
|
|
eps2d=render_tab_state.eps2d, |
|
backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device) |
|
/ 255.0, |
|
render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], |
|
rasterize_mode=render_tab_state.rasterize_mode, |
|
camera_model=render_tab_state.camera_model, |
|
) |
|
render_tab_state.total_gs_count = len(self.splats["means"]) |
|
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() |
|
|
|
if render_tab_state.render_mode == "rgb": |
|
|
|
render_colors = render_colors[0, ..., 0:3].clamp(0, 1) |
|
renders = render_colors.cpu().numpy() |
|
elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: |
|
|
|
depth = render_colors[0, ..., 0:1] |
|
if render_tab_state.normalize_nearfar: |
|
near_plane = render_tab_state.near_plane |
|
far_plane = render_tab_state.far_plane |
|
else: |
|
near_plane = depth.min() |
|
far_plane = depth.max() |
|
depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10) |
|
depth_norm = torch.clip(depth_norm, 0, 1) |
|
if render_tab_state.inverse: |
|
depth_norm = 1 - depth_norm |
|
renders = ( |
|
apply_float_colormap(depth_norm, render_tab_state.colormap) |
|
.cpu() |
|
.numpy() |
|
) |
|
elif render_tab_state.render_mode == "alpha": |
|
alpha = render_alphas[0, ..., 0:1] |
|
if render_tab_state.inverse: |
|
alpha = 1 - alpha |
|
renders = ( |
|
apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() |
|
) |
|
return renders |
|
|
|
|
|
def main(local_rank: int, world_rank: int, world_size: int, cfg: Config):
|
if world_size > 1 and not cfg.disable_viewer: |
|
cfg.disable_viewer = True |
|
if world_rank == 0: |
|
print("Viewer is disabled in distributed training.") |
|
|
|
runner = Runner(local_rank, world_rank, world_size, cfg) |
|
|
|
if cfg.ckpt is not None: |
|
|
|
ckpts = [ |
|
torch.load(file, map_location=runner.device, weights_only=True) |
|
for file in cfg.ckpt |
|
] |
|
for k in runner.splats.keys(): |
|
runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) |
|
step = ckpts[0]["step"] |
|
runner.eval(step=step) |
|
|
|
if cfg.compression is not None: |
|
runner.run_compression(step=step) |
|
else: |
|
runner.train() |
|
runner.eval(step=runner.cfg.max_steps) |
|
|
|
print("Training complete.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
""" |
|
Usage: |
|
|
|
```bash |
|
# Single GPU training |
|
CUDA_VISIBLE_DEVICES=9 python -m examples.simple_trainer default |
|
|
|
# Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. |
|
    CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25
    ```
    """
|
|
|
|
|
|
|
configs = { |
|
"default": ( |
|
"Gaussian splatting training using densification heuristics from the original paper.", |
|
Config( |
|
strategy=DefaultStrategy(verbose=True), |
|
), |
|
), |
|
"mcmc": ( |
|
"Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", |
|
Config( |
|
init_opa=0.5, |
|
init_scale=0.1, |
|
opacity_reg=0.01, |
|
scale_reg=0.01, |
|
strategy=MCMCStrategy(verbose=True), |
|
), |
|
), |
|
} |
|
cfg = tyro.extras.overridable_config_cli(configs) |
|
cfg.adjust_steps(cfg.steps_scaler) |
|
|
|
|
|
if cfg.compression == "png": |
|
try: |
|
import plas |
|
import torchpq |
|
    except ImportError:
|
raise ImportError( |
|
"To use PNG compression, you need to install " |
|
"torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " |
|
"and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " |
|
) |
|
|
|
cli(main, cfg, verbose=True) |
|
|