# AnySplat — src/model/decoder/decoder_splatting_cuda.py
from dataclasses import dataclass
from typing import Literal
import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor
import torchvision
from ..types import Gaussians
# from .cuda_splatting import DepthRenderingMode, render_cuda
from .decoder import Decoder, DecoderOutput
from math import sqrt
from gsplat import rasterization
from ...misc.utils import vis_depth_map
# Supported post-processing modes for a rendered depth map (raw depth,
# inverse depth, normalized inverse depth, or log depth).
DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
@dataclass
class DecoderSplattingCUDACfg:
    """Configuration schema for `DecoderSplattingCUDA`."""

    # Discriminator used by the config system to select this decoder.
    name: Literal["splatting_cuda"]
    # RGB background color — presumably 3 floats in [0, 1]; TODO confirm against configs.
    background_color: list[float]
    # Stored on the decoder as `make_scale_invariant`; consumed elsewhere
    # (no use visible in this file).
    make_scale_invariant: bool
class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    """Decoder that rasterizes 3D Gaussians into color/depth/alpha images
    using gsplat's CUDA `rasterization` backend, one view at a time.
    """

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
    ) -> None:
        super().__init__(cfg)
        self.make_scale_invariant = cfg.make_scale_invariant
        # Non-persistent buffer: moves with the module across devices but is
        # not written into checkpoints.
        self.register_buffer(
            "background_color",
            torch.tensor(cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def rendering_fn(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Render every view of every batch element with gsplat.

        Args:
            gaussians: Per-batch Gaussian parameters (means, opacities,
                rotations, scales, covariances, SH harmonics).
            extrinsics: Camera-to-world transforms; inverted here to the
                world-to-camera matrices gsplat expects.
            intrinsics: Normalized intrinsics — row 0 scaled by W and row 1
                by H below to recover pixel-unit matrices.
            near, far: Currently unused; a fixed tiny near plane is passed
                instead (see the commented-out call site below).
            image_shape: (H, W) of the output renders.
            depth_mode, cam_rot_delta, cam_trans_delta: Accepted for
                interface compatibility; unused in this implementation.

        Returns:
            DecoderOutput with stacked color (B, V, 3, H, W), depth
            (B, V, H, W) and alpha (B, V, H, W) tensors.
        """
        B, V, _, _ = intrinsics.shape
        H, W = image_shape
        rendered_imgs, rendered_depths, rendered_alphas = [], [], []
        # gsplat expects SH coefficients with the channel axis last
        # ((..., K, 3)); harmonics are stored as (..., 3, K), hence permute.
        xyzs, opacitys, rotations, scales, features = (
            gaussians.means,
            gaussians.opacities,
            gaussians.rotations,
            gaussians.scales,
            gaussians.harmonics.permute(0, 1, 3, 2).contiguous(),
        )
        covariances = gaussians.covariances
        for i in range(B):
            xyz_i = xyzs[i].float()
            feature_i = features[i].float()
            covar_i = covariances[i].float()
            scale_i = scales[i].float()
            rotation_i = rotations[i].float()
            # squeeze(-1) rather than bare squeeze(): a bare squeeze() would
            # also drop the Gaussian dimension when there is exactly one
            # Gaussian, producing a 0-d tensor.
            opacity_i = opacitys[i].squeeze(-1).float()
            test_w2c_i = extrinsics[i].float().inverse()  # (V, 4, 4)
            test_intr_i_normalized = intrinsics[i].float()
            # Denormalize the intrinsics into standard pixel units.
            test_intr_i = test_intr_i_normalized.clone()
            test_intr_i[:, 0] = test_intr_i_normalized[:, 0] * W
            test_intr_i[:, 1] = test_intr_i_normalized[:, 1] * H
            # K SH coefficients correspond to degree sqrt(K) - 1.
            sh_degree = int(sqrt(feature_i.shape[-2])) - 1
            rendering_list = []
            rendering_depth_list = []
            rendering_alpha_list = []
            for j in range(V):
                rendering, alpha, _ = rasterization(
                    xyz_i, rotation_i, scale_i, opacity_i, feature_i,
                    test_w2c_i[j:j + 1], test_intr_i[j:j + 1], W, H,
                    sh_degree=sh_degree,
                    # near_plane=near[i].mean(), far_plane=far[i].mean(),
                    render_mode="RGB+D", packed=False,
                    near_plane=1e-10,
                    # (1, 3) background; the original `.repeat(1, 1)` was a no-op.
                    backgrounds=self.background_color.unsqueeze(0),
                    radius_clip=0.1,
                    # When covariances are given, gsplat uses them directly —
                    # presumably overriding rotation/scale; TODO confirm.
                    covars=covar_i,
                    rasterize_mode='classic',
                )  # rendering: (1, H, W, 4) in RGB+D mode
                rendering_img, rendering_depth = torch.split(rendering, [3, 1], dim=-1)
                rendering_img = rendering_img.clamp(0.0, 1.0)
                rendering_list.append(rendering_img.permute(0, 3, 1, 2))  # (1, 3, H, W)
                rendering_depth_list.append(rendering_depth)
                rendering_alpha_list.append(alpha)
            # squeeze(-1) instead of bare squeeze(): squeeze() would also
            # collapse the view dimension when V == 1 (and any H/W of size 1),
            # corrupting the shapes stacked below.
            rendered_depths.append(torch.cat(rendering_depth_list, dim=0).squeeze(-1))
            rendered_imgs.append(torch.cat(rendering_list, dim=0))
            rendered_alphas.append(torch.cat(rendering_alpha_list, dim=0).squeeze(-1))
        return DecoderOutput(
            torch.stack(rendered_imgs),
            torch.stack(rendered_depths),
            torch.stack(rendered_alphas),
            lod_rendering=None,
        )

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Module entry point; delegates directly to `rendering_fn`."""
        return self.rendering_fn(
            gaussians, extrinsics, intrinsics, near, far,
            image_shape, depth_mode, cam_rot_delta, cam_trans_delta,
        )