# Source: AnySplat / src/model/decoder/cuda_splatting.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified).
from math import isqrt
from typing import Literal
import torch
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from einops import einsum, rearrange, repeat
from jaxtyping import Float, Bool
from torch import Tensor
from ...geometry.projection import get_fov, homogenize_points
def get_projection_matrix(
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    fov_x: Float[Tensor, " batch"],
    fov_y: Float[Tensor, " batch"],
) -> Float[Tensor, "batch 4 4"]:
    """Build a batch of symmetric perspective projection matrices.

    Points in the viewing frustum are mapped to (-1, 1) on the X/Y axes and to
    (0, 1) on the Z axis. This differs from the OpenGL version in that Z does not
    have range (-1, 1) after transformation and that Z is flipped.

    Args:
        near: Near-plane distance for each batch element.
        far: Far-plane distance for each batch element.
        fov_x: Full horizontal field of view in radians, per batch element.
        fov_y: Full vertical field of view in radians, per batch element.

    Returns:
        A float32 tensor of shape (batch, 4, 4) on ``near.device``.
    """
    (batch_size,) = near.shape

    # Half extents of the near plane; the frustum is symmetric about the axis.
    half_width = (0.5 * fov_x).tan() * near
    half_height = (0.5 * fov_y).tan() * near
    right, left = half_width, -half_width
    top, bottom = half_height, -half_height
    depth_range = far - near

    matrix = torch.zeros((batch_size, 4, 4), dtype=torch.float32, device=near.device)
    matrix[:, 0, 0] = 2 * near / (right - left)
    matrix[:, 0, 2] = (right + left) / (right - left)
    matrix[:, 1, 1] = 2 * near / (top - bottom)
    matrix[:, 1, 2] = (top + bottom) / (top - bottom)
    matrix[:, 2, 2] = far / depth_range
    matrix[:, 2, 3] = -(far * near) / depth_range
    # The homogeneous w coordinate receives the input z (perspective divide).
    matrix[:, 3, 2] = 1
    return matrix
def render_cuda(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    scale_invariant: bool = True,
    use_sh: bool = True,
    cam_rot_delta: Float[Tensor, "batch 3"] | None = None,
    cam_trans_delta: Float[Tensor, "batch 3"] | None = None,
    voxel_masks: Bool[Tensor, "batch gaussian"] | None = None,
) -> tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch height width"]]:
    """Rasterize Gaussians with one perspective camera per batch element.

    Each batch element is rendered independently with ``GaussianRasterizer``.

    Args:
        extrinsics: Per-view 4x4 matrices. Their inverse is used as the view
            matrix and their translation column as the camera position, so they
            are presumably camera-to-world (confirm against callers).
        intrinsics: Per-view 3x3 intrinsics, converted to fields of view via
            ``get_fov``.
        near: Per-view near clipping distance.
        far: Per-view far clipping distance.
        image_shape: Output ``(height, width)``.
        background_color: Per-view background color.
        gaussian_means: Per-view Gaussian centers.
        gaussian_covariances: Per-view 3x3 covariances. Only the upper-triangular
            entries are passed to the rasterizer.
        gaussian_sh_coefficients: Per-view SH coefficients. The SH degree is
            ``isqrt(d_sh) - 1``.
        gaussian_opacities: Per-view opacities.
        scale_invariant: If True, scale translations, means, covariances and
            clipping planes by ``1 / near`` to avoid numerical issues.
        use_sh: If False, ``d_sh`` must be 1. The single coefficient is then
            passed as precomputed colors instead of SH.
        cam_rot_delta: Optional per-view value forwarded as ``theta``.
        cam_trans_delta: Optional per-view value forwarded as ``rho``.
        voxel_masks: Optional per-view boolean mask selecting which Gaussians
            are rendered.

    Returns:
        ``(images, depths)``, each stacked over the batch. ``depths`` is the
        rasterizer's depth output with ``squeeze(0)`` applied.
        NOTE(review): when ``scale_invariant`` is True the depth comes from the
        rescaled scene (distances multiplied by ``1 / near``) and is not scaled
        back. Confirm that callers expect this.
    """
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    # Make sure everything is in a range where numerical issues don't appear.
    if scale_invariant:
        scale = 1 / near
        extrinsics = extrinsics.clone()
        extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
        gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
        gaussian_means = gaussian_means * scale[:, None, None]
        near = near * scale
        far = far * scale

    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    b, _, _ = extrinsics.shape
    h, w = image_shape

    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    # Convert to Python floats once for the whole batch instead of calling
    # .item() (a host sync) twice per view.
    tan_fov_x_values = (0.5 * fov_x).tan().tolist()
    tan_fov_y_values = (0.5 * fov_y).tan().tolist()

    projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    # Indices of the 6 upper-triangular entries of each 3x3 covariance, which
    # are passed as ``cov3D_precomp``. Loop-invariant, so built once.
    row, col = torch.triu_indices(3, 3)

    all_images = []
    all_depths = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            # Deliberate best-effort, kept from the original: failures are ignored.
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x_values[i],
            tanfovy=tan_fov_y_values[i],
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

        # Per-Gaussian inputs for this view. None means "not supplied".
        inputs = {
            "means3D": gaussian_means[i],
            "means2D": mean_gradients,
            "shs": shs[i] if use_sh else None,
            "colors_precomp": None if use_sh else shs[i, :, 0, :],
            "opacities": gaussian_opacities[i, ..., None],
            "cov3D_precomp": gaussian_covariances[i, :, row, col],
        }
        if voxel_masks is not None:
            # Keep only the Gaussians selected for this view.
            mask = voxel_masks[i]
            inputs = {
                key: None if value is None else value[mask]
                for key, value in inputs.items()
            }

        image, _radii, depth, _opacity, _n_touched = rasterizer(
            **inputs,
            theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
            rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
        )
        all_images.append(image)
        all_depths.append(depth.squeeze(0))

    return torch.stack(all_images), torch.stack(all_depths)
def render_cuda_orthographic(
    extrinsics: Float[Tensor, "batch 4 4"],
    width: Float[Tensor, " batch"],
    height: Float[Tensor, " batch"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    fov_degrees: float = 0.1,
    use_sh: bool = True,
    dump: dict | None = None,
) -> Float[Tensor, "batch 3 height width"]:
    """Approximate an orthographic render with a narrow-FOV perspective camera.

    Each camera is translated along its own z axis by ``distance_to_near``, which
    is the distance at which a frustum with horizontal field of view
    ``fov_degrees`` is ``width`` wide. The near/far planes are shifted by the same
    amount. The result is an almost parallel projection covering
    ``width`` x ``height`` at the original near plane.

    Args:
        extrinsics: Per-view 4x4 matrices. Their inverse is used as the view
            matrix and their translation column as the camera position.
        width: Per-view horizontal extent to cover, in the units of the
            extrinsics' translation.
        height: Per-view vertical extent to cover, in the same units.
        near: Per-view near clipping distance before the camera is moved back.
        far: Per-view far clipping distance before the camera is moved back.
        image_shape: Output ``(height, width)`` in pixels.
        background_color: Per-view background color.
        gaussian_means: Per-view Gaussian centers.
        gaussian_covariances: Per-view 3x3 covariances. Only the upper-triangular
            entries are passed to the rasterizer.
        gaussian_sh_coefficients: Per-view SH coefficients. The SH degree is
            ``isqrt(d_sh) - 1``.
        gaussian_opacities: Per-view opacities.
        fov_degrees: Horizontal field of view of the fake camera, in degrees.
        use_sh: If False, ``d_sh`` must be 1. The single coefficient is then
            passed as precomputed colors instead of SH.
        dump: If given, receives the adjusted ``extrinsics``, ``fov_x``,
            ``fov_y``, ``near`` and ``far``, for visualization.

    Returns:
        The rendered images stacked over the batch.
    """
    b, _, _ = extrinsics.shape
    h, w = image_shape
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    # Create fake "orthographic" projection by moving the camera back and picking a
    # small field of view.
    fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad()
    tan_fov_x = (0.5 * fov_x).tan()
    distance_to_near = (0.5 * width) / tan_fov_x
    tan_fov_y = 0.5 * height / distance_to_near
    # Full angle = 2 * atan(tan(half angle)). The previous atan(2 * tan_half) did
    # not match the ``tanfovy`` given to the rasterizer.
    fov_y = 2 * tan_fov_y.atan()
    near = near + distance_to_near
    far = far + distance_to_near

    # distance_to_near differs per view, so build one translation per view.
    # (A single 4x4 matrix only worked for a batch of one.)
    move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device).repeat(b, 1, 1)
    move_back[:, 2, 3] = -distance_to_near
    extrinsics = extrinsics @ move_back

    # Escape hatch for visualization/figures.
    if dump is not None:
        dump["extrinsics"] = extrinsics
        dump["fov_x"] = fov_x
        dump["fov_y"] = fov_y
        dump["near"] = near
        dump["far"] = far

    projection_matrix = get_projection_matrix(
        near, far, repeat(fov_x, "-> b", b=b), fov_y
    )
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    # The rasterizer settings take Python floats. tan_fov_y is per view, so each
    # view must receive its own value rather than the whole batch tensor.
    tan_fov_x_value = tan_fov_x.item()
    tan_fov_y_values = tan_fov_y.expand(b).tolist()

    # Indices of the 6 upper-triangular covariance entries for cov3D_precomp.
    row, col = torch.triu_indices(3, 3)

    all_images = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            # Deliberate best-effort, kept from the original: failures are ignored.
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x_value,
            tanfovy=tan_fov_y_values[i],
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)
        image, _radii, _depth, _opacity, _n_touched = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
        )
        all_images.append(image)

    return torch.stack(all_images)
# Type alias listing the accepted depth-rendering mode names.
DepthRenderingMode = Literal[
    "depth",
    "disparity",
    "relative_disparity",
    "log",
]