# AnySplat/src/post_opt/utils.py
import random
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import colormaps


class CameraOptModule(torch.nn.Module):
    """Camera pose optimization module."""

    def __init__(self, n: int):
        super().__init__()
        # Delta positions (3D) + Delta rotations (6D)
        self.embeds = torch.nn.Embedding(n, 9)
        # Identity rotation in 6D representation
        self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def zero_init(self):
        torch.nn.init.zeros_(self.embeds.weight)

    def random_init(self, std: float):
        torch.nn.init.normal_(self.embeds.weight, std=std)

    def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
        """Adjust camera pose based on deltas.

        Args:
            camtoworlds: (..., 4, 4)
            embed_ids: (...,)

        Returns:
            updated camtoworlds: (..., 4, 4)
        """
        assert camtoworlds.shape[:-2] == embed_ids.shape
        batch_shape = camtoworlds.shape[:-2]
        pose_deltas = self.embeds(embed_ids)  # (..., 9)
        dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
        rot = rotation_6d_to_matrix(
            drot + self.identity.expand(*batch_shape, -1)
        )  # (..., 3, 3)
        transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
        transform[..., :3, :3] = rot
        transform[..., :3, 3] = dx
        return torch.matmul(camtoworlds, transform)
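
# Usage sketch (illustrative, not part of the original module): optimize
# per-image pose deltas jointly with the scene. Names such as `num_cameras`,
# `camtoworlds`, and `image_ids` are assumptions for the example.
#
#   pose_adjust = CameraOptModule(num_cameras).to(device)
#   pose_adjust.zero_init()
#   pose_optimizer = torch.optim.Adam(pose_adjust.parameters(), lr=1e-4)
#   camtoworlds_adj = pose_adjust(camtoworlds, image_ids)  # (B, 4, 4)
#   # ...render with camtoworlds_adj, backprop the photometric loss,
#   # then step pose_optimizer alongside the Gaussian optimizers.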


class AppearanceOptModule(torch.nn.Module):
    """Appearance optimization module."""

    def __init__(
        self,
        n: int,
        feature_dim: int,
        embed_dim: int = 16,
        sh_degree: int = 3,
        mlp_width: int = 64,
        mlp_depth: int = 2,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.sh_degree = sh_degree
        self.embeds = torch.nn.Embedding(n, embed_dim)
        layers = []
        layers.append(
            torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
        )
        layers.append(torch.nn.ReLU(inplace=True))
        for _ in range(mlp_depth - 1):
            layers.append(torch.nn.Linear(mlp_width, mlp_width))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Linear(mlp_width, 3))
        self.color_head = torch.nn.Sequential(*layers)

    def forward(
        self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
    ) -> Tensor:
        """Adjust appearance based on embeddings.

        Args:
            features: (N, feature_dim)
            embed_ids: (C,)
            dirs: (C, N, 3)
            sh_degree: SH degree to evaluate (at most self.sh_degree)

        Returns:
            colors: (C, N, 3)
        """
        from gsplat.cuda._torch_impl import _eval_sh_bases_fast

        C, N = dirs.shape[:2]
        # Camera embeddings
        if embed_ids is None:
            embeds = torch.zeros(C, self.embed_dim, device=features.device)
        else:
            embeds = self.embeds(embed_ids)  # [C, D2]
        embeds = embeds[:, None, :].expand(-1, N, -1)  # [C, N, D2]
        # GS features
        features = features[None, :, :].expand(C, -1, -1)  # [C, N, D1]
        # View directions
        dirs = F.normalize(dirs, dim=-1)  # [C, N, 3]
        num_bases_to_use = (sh_degree + 1) ** 2
        num_bases = (self.sh_degree + 1) ** 2
        sh_bases = torch.zeros(C, N, num_bases, device=features.device)  # [C, N, K]
        sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
        # Get colors
        if self.embed_dim > 0:
            h = torch.cat([embeds, features, sh_bases], dim=-1)  # [C, N, D1 + D2 + K]
        else:
            h = torch.cat([features, sh_bases], dim=-1)
        colors = self.color_head(h)
        return colors
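
# Usage sketch (illustrative, not part of the original module): per-image
# appearance embeddings feeding the color MLP. `features`, `viewdirs`, and
# `image_ids` are assumed tensors with the shapes noted in the docstring.
#
#   app_module = AppearanceOptModule(
#       n=num_images, feature_dim=32, embed_dim=16, sh_degree=3
#   ).to(device)
#   colors = app_module(features, image_ids, viewdirs, sh_degree=3)  # (C, N, 3)
#   colors = torch.sigmoid(colors)  # if the head is trained to predict logits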


def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """Converts the 6D rotation representation of Zhou et al. [1] to a rotation
    matrix using Gram--Schmidt orthogonalization per Section B of [1].
    Adapted from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
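
# Sanity-check sketch (illustrative): the identity 6D representation
# [1, 0, 0, 0, 1, 0] maps to the identity matrix, and any output should be a
# valid rotation (orthonormal rows, det = +1).
#
#   d6 = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
#   R = rotation_6d_to_matrix(d6)
#   assert torch.allclose(R, torch.eye(3), atol=1e-6)
#   assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)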


def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return the distances from each point in `x` (N, D) to its K nearest
    neighbors, computed with scikit-learn. Note that each point's nearest
    neighbor is itself, so the first column of distances is zero."""
    x_np = x.cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)
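
# Usage sketch (illustrative): a common use of `knn` is initializing Gaussian
# scales from mean neighbor distances. `points` (N, 3) is an assumed input.
#
#   dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # skip self-distance
#   scales = torch.log(torch.sqrt(dist2_avg)).unsqueeze(-1).repeat(1, 3)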


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to 0th-order (DC) SH coefficients."""
    C0 = 0.28209479177387814  # Y_0^0 = 1 / (2 * sqrt(pi))
    return (rgb - 0.5) / C0
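
# Sketch (illustrative): the inverse mapping, handy when exporting colors from
# DC coefficients. `sh_to_rgb` is not part of the original file.
#
#   def sh_to_rgb(sh: Tensor) -> Tensor:
#       C0 = 0.28209479177387814
#       return sh * C0 + 0.5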


def set_random_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
def colormap(img, cmap="jet"):
    """Render a 2D array through a matplotlib colormap (with a colorbar) and
    return the rasterized figure as a (3, H', W') float tensor."""
    W, H = img.shape[:2]
    dpi = 300
    fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
    im = ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.canvas.draw()
    # Grab the rasterized figure as an RGB byte buffer (requires an Agg-style canvas).
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    img = torch.from_numpy(data).float().permute(2, 0, 1)
    plt.close()
    return img
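
# Usage sketch (illustrative): visualize a per-pixel error map for logging;
# `error` (H, W) and `writer` (a TensorBoard SummaryWriter) are assumptions.
#
#   heatmap = colormap(error.cpu().numpy(), cmap="jet")  # (3, H', W'), 0..255
#   writer.add_image("train/error", heatmap / 255.0, global_step)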


def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert a single-channel image to a color image.

    Args:
        img (torch.Tensor): (..., 1) float32 single channel image.
        colormap (str): Colormap for img.

    Returns:
        (..., 3) colored img with colors in [0, 1].
    """
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        return img.repeat(1, 1, 3)
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    return torch.tensor(
        colormaps[colormap].colors,  # type: ignore
        device=img.device,
    )[img_long[..., 0]]
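
# Usage sketch (illustrative): values must already be normalized to [0, 1]
# before calling; `confidence` (H, W, 1) is an assumed input.
#
#   conf_norm = (confidence - confidence.min()) / (confidence.max() - confidence.min() + 1e-10)
#   conf_rgb = apply_float_colormap(conf_norm, colormap="turbo")  # (H, W, 3)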


def apply_depth_colormap(
    depth: torch.Tensor,
    acc: torch.Tensor = None,
    near_plane: float = None,
    far_plane: float = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    near_plane = near_plane or float(torch.min(depth))
    far_plane = far_plane or float(torch.max(depth))
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        img = img * acc + (1.0 - acc)
    return img
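
# Usage sketch (illustrative): colorize a rendered depth map, blending empty
# space toward white via the accumulated alpha; `render_depth` (H, W, 1) and
# `render_alpha` (H, W, 1) are assumptions for the example.
#
#   depth_rgb = apply_depth_colormap(render_depth, acc=render_alpha)  # (H, W, 3)
#   depth_u8 = (depth_rgb * 255).to(torch.uint8).cpu().numpy()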