# src/model/encoder/common/gaussian_adapter.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from jaxtyping import Float
from torch import Tensor, nn
from src.geometry.projection import get_world_rays
from src.misc.sh_rotation import rotate_sh
from .gaussians import build_covariance
from ...types import Gaussians


@dataclass
class GaussianAdapterCfg:
gaussian_scale_min: float
gaussian_scale_max: float
sh_degree: int
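
# A typical configuration might look like the following (values are illustrative,
# not taken from this repository's config files):
#
#     GaussianAdapterCfg(gaussian_scale_min=0.5, gaussian_scale_max=15.0, sh_degree=4)

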
class GaussianAdapter(nn.Module):
    cfg: GaussianAdapterCfg

def __init__(self, cfg: GaussianAdapterCfg):
super().__init__()
self.cfg = cfg
# Create a mask for the spherical harmonics coefficients. This ensures that at
# initialization, the coefficients are biased towards having a large DC
# component and small view-dependent components.
self.register_buffer(
"sh_mask",
torch.ones((self.d_sh,), dtype=torch.float32),
persistent=False,
)
for degree in range(1, self.cfg.sh_degree + 1):
self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
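        # For example, with sh_degree = 4 the mask keeps the DC coefficient at 1.0 and
        # scales degrees 1 through 4 by 0.025, 0.00625, ~0.0016, and ~0.0004.
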
def forward(
self,
extrinsics: Float[Tensor, "*#batch 4 4"],
intrinsics: Float[Tensor, "*#batch 3 3"],
coordinates: Float[Tensor, "*#batch 2"],
depths: Float[Tensor, "*#batch"],
opacities: Float[Tensor, "*#batch"],
raw_gaussians: Float[Tensor, "*#batch _"],
image_shape: tuple[int, int],
eps: float = 1e-8,
) -> Gaussians:
device = extrinsics.device
scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
# Map scale features to valid scale range.
scale_min = self.cfg.gaussian_scale_min
scale_max = self.cfg.gaussian_scale_max
scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
        h, w = image_shape
        pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device)
        multiplier = self.get_scale_multiplier(intrinsics, pixel_size)

        # Convert the scales to world-space sizes: multiply by the per-pixel multiplier
        # and by each Gaussian's depth.
        scales = scales * depths[..., None] * multiplier[..., None]
# Normalize the quaternion features to yield a valid quaternion.
rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
        # Create world-space covariance matrices by rotating the camera-space covariances
        # with the camera-to-world rotation: Sigma_world = R @ Sigma_cam @ R^T.
        covariances = build_covariance(scales, rotations)
        c2w_rotations = extrinsics[..., :3, :3]
        covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2)
# Compute Gaussian means.
origins, directions = get_world_rays(coordinates, extrinsics, intrinsics)
means = origins + directions * depths[..., None]
return Gaussians(
means=means,
covariances=covariances,
# harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]),
harmonics=sh,
opacities=opacities,
# Note: These aren't yet rotated into world space, but they're only used for
# exporting Gaussians to ply files. This needs to be fixed...
scales=scales,
rotations=rotations.broadcast_to((*scales.shape[:-1], 4)),
)

    def get_scale_multiplier(
self,
intrinsics: Float[Tensor, "*#batch 3 3"],
pixel_size: Float[Tensor, "*#batch 2"],
multiplier: float = 0.1,
) -> Float[Tensor, " *batch"]:
xy_multipliers = multiplier * einsum(
intrinsics[..., :2, :2].inverse(),
pixel_size,
"... i j, j -> ... i",
)
return xy_multipliers.sum(dim=-1)
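
    # Note on get_scale_multiplier (an interpretation, assuming intrinsics normalized
    # to [0, 1]): for diagonal intrinsics diag(fx, fy) the einsum reduces to
    #     multiplier * (1 / (fx * w) + 1 / (fy * h)),
    # roughly the world-space extent of one pixel at unit depth; forward() multiplies
    # this by each Gaussian's depth to obtain its metric scale.
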
    @property
    def d_sh(self) -> int:
        # Number of spherical harmonics coefficients per color channel.
        return (self.cfg.sh_degree + 1) ** 2

    @property
    def d_in(self) -> int:
        # Channel count expected in raw_gaussians: 3 (scale) + 4 (quaternion) + 3 * d_sh (SH).
        return 7 + 3 * self.d_sh
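

# Sketch of how GaussianAdapter.forward is typically driven (an assumption about the
# surrounding pipeline, not code from this file): an encoder predicts per-pixel depths,
# opacities, and d_in-channel raw features, which are unprojected along camera rays
# into world-space Gaussians, e.g.
#
#     gaussians = adapter(
#         extrinsics,      # (..., 4, 4) camera-to-world poses
#         intrinsics,      # (..., 3, 3) normalized intrinsics
#         coordinates,     # (..., 2) normalized pixel coordinates for each ray
#         depths,          # (...,) per-pixel depths
#         opacities,       # (...,) per-pixel opacities
#         raw_gaussians,   # (..., adapter.d_in) raw scale/rotation/SH features
#         image_shape=(h, w),
#     )
#
# Only the trailing dimensions are fixed by the annotations; leading batch dimensions
# just need to broadcast against each other.

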
class UnifiedGaussianAdapter(GaussianAdapter):
def forward(
self,
means: Float[Tensor, "*#batch 3"],
# levels: Float[Tensor, "*#batch"],
depths: Float[Tensor, "*#batch"],
opacities: Float[Tensor, "*#batch"],
raw_gaussians: Float[Tensor, "*#batch _"],
eps: float = 1e-8,
intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
) -> Gaussians:
        scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)

        # Map raw scale features to small positive values and clamp them so individual
        # Gaussians cannot grow arbitrarily large.
        scales = 0.001 * F.softplus(scales)
        scales = scales.clamp_max(0.3)
# Normalize the quaternion features to yield a valid quaternion.
rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
# print(scales.max())
covariances = build_covariance(scales, rotations)
return Gaussians(
means=means.float(),
# levels=levels.int(),
covariances=covariances.float(),
harmonics=sh.float(),
opacities=opacities.float(),
scales=scales.float(),
rotations=rotations.float(),
)


class Unet3dGaussianAdapter(GaussianAdapter):
def forward(
self,
means: Float[Tensor, "*#batch 3"],
depths: Float[Tensor, "*#batch"],
opacities: Float[Tensor, "*#batch"],
raw_gaussians: Float[Tensor, "*#batch _"],
eps: float = 1e-8,
intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
) -> Gaussians:
        scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)

        # Same activation as UnifiedGaussianAdapter: small positive scales, clamped.
        scales = 0.001 * F.softplus(scales)
        scales = scales.clamp_max(0.3)
# Normalize the quaternion features to yield a valid quaternion.
rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
covariances = build_covariance(scales, rotations)
return Gaussians(
means=means,
covariances=covariances,
harmonics=sh,
opacities=opacities,
scales=scales,
rotations=rotations,
)
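

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # module): push random per-Gaussian features through UnifiedGaussianAdapter and
    # check the output shapes. Batch layout and config values are arbitrary.
    cfg = GaussianAdapterCfg(gaussian_scale_min=0.5, gaussian_scale_max=15.0, sh_degree=4)
    adapter = UnifiedGaussianAdapter(cfg)
    b, n = 2, 1024
    gaussians = adapter(
        means=torch.randn(b, n, 3),
        depths=torch.rand(b, n),
        opacities=torch.rand(b, n),
        raw_gaussians=torch.randn(b, n, adapter.d_in),
    )
    print(gaussians.means.shape)        # (2, 1024, 3)
    print(gaussians.covariances.shape)  # (2, 1024, 3, 3)
    print(gaussians.harmonics.shape)    # (2, 1024, 3, 25)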