# AnySplat: src/model/decoder/decoder_splatting_cuda.py
from dataclasses import dataclass
from math import sqrt
from typing import Literal

import torch
from gsplat import rasterization
from jaxtyping import Float
from torch import Tensor

from ..types import Gaussians
# from .cuda_splatting import DepthRenderingMode, render_cuda
from .decoder import Decoder, DecoderOutput

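
# This module implements the splatting decoder: it unpacks predicted Gaussians
# and renders RGB, depth, and alpha maps for each view through gsplat's
# `rasterization` entry point.
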
DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]


@dataclass
class DecoderSplattingCUDACfg:
    name: Literal["splatting_cuda"]
    background_color: list[float]
    make_scale_invariant: bool
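
# A matching config entry might look like the following (illustrative only;
# the exact config schema lives outside this file and is assumed here):
#
#     decoder:
#       name: splatting_cuda
#       background_color: [0.0, 0.0, 0.0]
#       make_scale_invariant: true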


class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
    ) -> None:
        super().__init__(cfg)
        self.make_scale_invariant = cfg.make_scale_invariant
        # Registered as a non-persistent buffer so it follows the module across
        # devices/dtypes without being saved in (or loaded from) checkpoints.
        self.register_buffer(
            "background_color",
            torch.tensor(cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def rendering_fn(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        B, V, _, _ = intrinsics.shape
        H, W = image_shape
        rendered_imgs, rendered_depths, rendered_alphas = [], [], []
        # Unpack the Gaussian parameters; gsplat expects SH coefficients as
        # (..., K, 3), so move the RGB axis of the harmonics to the end.
        xyzs = gaussians.means
        opacitys = gaussians.opacities
        rotations = gaussians.rotations
        scales = gaussians.scales
        features = gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
        covariances = gaussians.covariances
        for i in range(B):
            xyz_i = xyzs[i].float()
            feature_i = features[i].float()
            covar_i = covariances[i].float()
            scale_i = scales[i].float()
            rotation_i = rotations[i].float()
            opacity_i = opacitys[i].squeeze().float()
            test_w2c_i = extrinsics[i].float().inverse()  # camera-to-world -> world-to-camera, (V, 4, 4)
            test_intr_i_normalized = intrinsics[i].float()
            # Denormalize the intrinsics into standard pixel units: the first
            # row (fx, 0, cx) is expressed relative to the image width and the
            # second row (0, fy, cy) relative to the image height.
            test_intr_i = test_intr_i_normalized.clone()
            test_intr_i[:, 0] = test_intr_i_normalized[:, 0] * W
            test_intr_i[:, 1] = test_intr_i_normalized[:, 1] * H
            # The SH axis holds K = (degree + 1)^2 coefficients per channel.
            sh_degree = int(sqrt(feature_i.shape[-2])) - 1
            rendering_list = []
            rendering_depth_list = []
            rendering_alpha_list = []
            for j in range(V):
                # Rasterize one view at a time; "RGB+D" returns a 4-channel
                # image whose last channel is the accumulated depth.
                rendering, alpha, _ = rasterization(
                    xyz_i,
                    rotation_i,
                    scale_i,
                    opacity_i,
                    feature_i,
                    test_w2c_i[j : j + 1],
                    test_intr_i[j : j + 1],
                    W,
                    H,
                    sh_degree=sh_degree,
                    # near_plane=near[i].mean(), far_plane=far[i].mean(),
                    render_mode="RGB+D",
                    packed=False,
                    near_plane=1e-10,
                    backgrounds=self.background_color.unsqueeze(0),
                    radius_clip=0.1,
                    covars=covar_i,
                    rasterize_mode="classic",
                )  # rendering: (1, H, W, 4), alpha: (1, H, W, 1)
                rendering_img, rendering_depth = torch.split(rendering, [3, 1], dim=-1)
                rendering_img = rendering_img.clamp(0.0, 1.0)
                rendering_list.append(rendering_img.permute(0, 3, 1, 2))  # to (1, 3, H, W)
                rendering_depth_list.append(rendering_depth)
                rendering_alpha_list.append(alpha)
            rendered_depths.append(torch.cat(rendering_depth_list, dim=0).squeeze())
            rendered_imgs.append(torch.cat(rendering_list, dim=0))
            rendered_alphas.append(torch.cat(rendering_alpha_list, dim=0).squeeze())
        return DecoderOutput(
            torch.stack(rendered_imgs),
            torch.stack(rendered_depths),
            torch.stack(rendered_alphas),
            lod_rendering=None,
        )

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        return self.rendering_fn(
            gaussians,
            extrinsics,
            intrinsics,
            near,
            far,
            image_shape,
            depth_mode,
            cam_rot_delta,
            cam_trans_delta,
        )
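

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal smoke test, assuming
# a CUDA device with gsplat installed and that `Gaussians` is a plain
# attribute container; since `rendering_fn` only reads attributes, a
# `SimpleNamespace` stands in for it here. All shapes and values below are
# made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda")
    B, V, G, H, W = 1, 2, 512, 64, 64

    cfg = DecoderSplattingCUDACfg(
        name="splatting_cuda",
        background_color=[0.0, 0.0, 0.0],
        make_scale_invariant=False,
    )
    decoder = DecoderSplattingCUDA(cfg).to(device)

    # Random Gaussians placed in front of the cameras (roughly z around 3).
    scales = torch.rand(B, G, 3, device=device) * 0.02 + 0.01
    gaussians = SimpleNamespace(
        means=torch.randn(B, G, 3, device=device) * 0.5
        + torch.tensor([0.0, 0.0, 3.0], device=device),
        opacities=torch.rand(B, G, device=device),
        rotations=torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        .expand(B, G, 4)
        .contiguous(),
        scales=scales,
        harmonics=torch.rand(B, G, 3, 1, device=device),  # K = 1 -> sh_degree = 0
        covariances=torch.diag_embed(scales**2),  # isotropic covariances
    )

    # Identity camera-to-world poses; normalized intrinsics with unit focal
    # length and the principal point at the image center.
    extrinsics = torch.eye(4, device=device).expand(B, V, 4, 4).contiguous()
    intrinsics = (
        torch.tensor(
            [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], device=device
        )
        .expand(B, V, 3, 3)
        .contiguous()
    )
    near = torch.full((B, V), 0.1, device=device)
    far = torch.full((B, V), 100.0, device=device)

    out = decoder(gaussians, extrinsics, intrinsics, near, far, (H, W))
    # Images come back stacked as (B, V, 3, H, W); the DecoderOutput field
    # names are defined in .decoder.
    print(type(out).__name__)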