"""Camera/appearance optimization modules and small tensor/visualization helpers."""

import random
from typing import Optional

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import colormaps


class CameraOptModule(torch.nn.Module):
    """Camera pose optimization module.

    Holds one learnable 9-D delta per camera: a 3-D translation followed by a
    6-D rotation offset. The rotation offset is added to the 6-D identity
    rotation before being converted to a 3x3 matrix.
    """

    def __init__(self, n: int):
        """Create the module with ``n`` per-camera pose embeddings."""
        super().__init__()
        # Delta positions (3D) + Delta rotations (6D)
        self.embeds = torch.nn.Embedding(n, 9)
        # Identity rotation in 6D representation
        self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def zero_init(self):
        """Zero all pose deltas, so ``forward`` applies an identity transform."""
        torch.nn.init.zeros_(self.embeds.weight)

    def random_init(self, std: float):
        """Draw all pose deltas from a zero-mean normal with the given ``std``."""
        torch.nn.init.normal_(self.embeds.weight, std=std)

    def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
        """Adjust camera pose based on deltas.

        Args:
            camtoworlds: (..., 4, 4)
            embed_ids: (...,)

        Returns:
            updated camtoworlds: (..., 4, 4)
        """
        assert camtoworlds.shape[:-2] == embed_ids.shape
        batch_shape = camtoworlds.shape[:-2]
        pose_deltas = self.embeds(embed_ids)  # (..., 9)
        dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
        rot = rotation_6d_to_matrix(
            drot + self.identity.expand(*batch_shape, -1)
        )  # (..., 3, 3)
        # Build the delta transform in the dtype of the poses it multiplies, so
        # the final matmul does not fail on a dtype mismatch (e.g. float64 poses).
        transform = torch.eye(
            4, device=pose_deltas.device, dtype=camtoworlds.dtype
        ).repeat((*batch_shape, 1, 1))
        transform[..., :3, :3] = rot
        transform[..., :3, 3] = dx
        return torch.matmul(camtoworlds, transform)


class AppearanceOptModule(torch.nn.Module):
    """Appearance optimization module.

    An MLP that maps the concatenation of a per-camera embedding, per-point
    features and spherical-harmonics bases of the view direction to 3 colour
    channels.
    """

    def __init__(
        self,
        n: int,
        feature_dim: int,
        embed_dim: int = 16,
        sh_degree: int = 3,
        mlp_width: int = 64,
        mlp_depth: int = 2,
    ):
        """Build the embedding table and the colour MLP.

        Args:
            n: number of embeddings (rows of the embedding table).
            feature_dim: size of the per-point feature vector.
            embed_dim: size of each embedding vector.
            sh_degree: maximum SH degree; fixes the number of SH input
                channels to ``(sh_degree + 1) ** 2``.
            mlp_width: hidden width of the MLP.
            mlp_depth: number of hidden (Linear + ReLU) layers.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.sh_degree = sh_degree
        self.embeds = torch.nn.Embedding(n, embed_dim)
        layers = []
        layers.append(
            torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
        )
        layers.append(torch.nn.ReLU(inplace=True))
        for _ in range(mlp_depth - 1):
            layers.append(torch.nn.Linear(mlp_width, mlp_width))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Linear(mlp_width, 3))
        self.color_head = torch.nn.Sequential(*layers)

    def forward(
        self,
        features: Tensor,
        embed_ids: Optional[Tensor],
        dirs: Tensor,
        sh_degree: int,
    ) -> Tensor:
        """Adjust appearance based on embeddings.

        Args:
            features: (N, feature_dim)
            embed_ids: (C,) or None; when None an all-zero embedding is used.
            dirs: (C, N, 3)
            sh_degree: SH degree to evaluate; bases above it are left at zero.

        Returns:
            colors: (C, N, 3)
        """
        from gsplat.cuda._torch_impl import _eval_sh_bases_fast

        C, N = dirs.shape[:2]
        # Camera embeddings
        if embed_ids is None:
            embeds = torch.zeros(C, self.embed_dim, device=features.device)
        else:
            embeds = self.embeds(embed_ids)  # [C, D2]
        embeds = embeds[:, None, :].expand(-1, N, -1)  # [C, N, D2]
        # GS features
        features = features[None, :, :].expand(C, -1, -1)  # [C, N, D1]
        # View directions
        dirs = F.normalize(dirs, dim=-1)  # [C, N, 3]
        num_bases_to_use = (sh_degree + 1) ** 2
        num_bases = (self.sh_degree + 1) ** 2
        # The MLP input width is fixed by self.sh_degree; only the first
        # num_bases_to_use channels are filled, the rest stay zero.
        sh_bases = torch.zeros(C, N, num_bases, device=features.device)  # [C, N, K]
        sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
        # Get colors
        if self.embed_dim > 0:
            h = torch.cat([embeds, features, sh_bases], dim=-1)  # [C, N, D1 + D2 + K]
        else:
            h = torch.cat([features, sh_bases], dim=-1)
        colors = self.color_head(h)
        return colors


def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    # Remove the component of a2 along b1, then normalize.
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return Euclidean distances from every row of ``x`` to its K nearest rows.

    The query set is ``x`` itself, so each point is among its own neighbours.
    The search runs on CPU through scikit-learn and is not differentiable.

    Args:
        x: 2-D tensor of points, one per row.
        K: number of neighbours to query.

    Returns:
        Tensor of shape (len(x), K) with the dtype and device of ``x``.
    """
    # detach() so tensors that require grad can be converted to NumPy.
    x_np = x.detach().cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Map RGB values to SH coefficients via ``(rgb - 0.5) / C0``."""
    C0 = 0.28209479177387814
    return (rgb - 0.5) / C0


def set_random_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch with the same value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
def colormap(img, cmap="jet"):
    """Render ``img`` with matplotlib (including a colorbar) into an RGB tensor.

    Args:
        img: array-like with a ``.shape`` that ``Axes.imshow`` accepts; its
            first two dimensions (rows, columns) size the figure.
        cmap: matplotlib colormap name passed to ``imshow``.

    Returns:
        Float tensor of shape (3, rows, cols) of the rendered canvas, holding
        the 8-bit channel values converted to float.
    """
    height, width = img.shape[:2]
    dpi = 300
    # figsize is (width, height) in inches.
    fig, ax = plt.subplots(1, figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        im = ax.imshow(img, cmap=cmap)
        ax.set_axis_off()
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        # tostring_rgb() is deprecated/removed in recent matplotlib; read the
        # RGBA buffer instead and drop alpha. The buffer carries its own shape.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy so the result does not alias canvas memory after the figure closes.
        data = rgba[..., :3].copy()
    finally:
        # Close this specific figure (even on error) rather than whichever
        # figure happens to be current.
        plt.close(fig)
    return torch.from_numpy(data).float().permute(2, 0, 1)


def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert single channel to a color img.

    Args:
        img (torch.Tensor): (..., 1) float32 single channel image with values
            in [0, 1] (checked by assertion for non-gray colormaps). NaNs are
            replaced by 0.
        colormap (str): Colormap for img. ``"gray"`` repeats the channel three
            times; any other name is looked up in ``matplotlib.colormaps`` and
            must expose a ``.colors`` table.

    Returns:
        (..., 3) colored img with colors in [0, 1].
    """
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        # Tile only the last dimension so any leading batch shape is preserved.
        return img.repeat(*([1] * (img.dim() - 1)), 3)
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    # Index the colormap's lookup table with the quantized values.
    return torch.tensor(
        colormaps[colormap].colors,  # type: ignore
        device=img.device,
    )[img_long[..., 0]]


def apply_depth_colormap(
    depth: torch.Tensor,
    acc: Optional[torch.Tensor] = None,
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    # Compare against None explicitly: `x or default` would discard a
    # legitimate plane value of 0.0.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))
    # The epsilon guards against division by zero when both planes coincide.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        # Blend towards white where accumulation is low.
        img = img * acc + (1.0 - acc)
    return img