from dataclasses import dataclass
from math import sqrt
from typing import Literal

import torch
from gsplat import rasterization
from jaxtyping import Float
from torch import Tensor

from ..types import Gaussians
# from .cuda_splatting import DepthRenderingMode, render_cuda
from .decoder import Decoder, DecoderOutput
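# Depth rendering modes accepted through the `depth_mode` argument (currently unused here:
# gsplat's "RGB+D" mode returns raw depth directly).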
DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
@dataclass
class DecoderSplattingCUDACfg:
name: Literal["splatting_cuda"]
background_color: list[float]
make_scale_invariant: bool
class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
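    """Decoder that renders batched 3D Gaussians with gsplat's CUDA rasterizer,
    producing per-view RGB images, depth maps, and alpha maps."""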
background_color: Float[Tensor, "3"]
def __init__(
self,
cfg: DecoderSplattingCUDACfg,
) -> None:
super().__init__(cfg)
self.make_scale_invariant = cfg.make_scale_invariant
self.register_buffer(
"background_color",
torch.tensor(cfg.background_color, dtype=torch.float32),
persistent=False,
)
def rendering_fn(
self,
gaussians: Gaussians,
extrinsics: Float[Tensor, "batch view 4 4"],
intrinsics: Float[Tensor, "batch view 3 3"],
near: Float[Tensor, "batch view"],
far: Float[Tensor, "batch view"],
image_shape: tuple[int, int],
depth_mode: DepthRenderingMode | None = None,
cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
) -> DecoderOutput:
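        """Render each (batch, view) pair with gsplat and collect RGB, depth, and alpha maps.

        Intrinsics are expected in normalized image coordinates. `near`, `far`,
        `depth_mode`, `cam_rot_delta`, and `cam_trans_delta` are accepted for
        interface compatibility but are not used by this renderer (a fixed
        near_plane is passed to gsplat instead).
        """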
B, V, _, _ = intrinsics.shape
H, W = image_shape
rendered_imgs, rendered_depths, rendered_alphas = [], [], []
        xyzs = gaussians.means
        opacities = gaussians.opacities
        rotations = gaussians.rotations
        scales = gaussians.scales
        # Harmonics come in as (B, G, 3, d_sh); gsplat expects SH coefficients as (..., d_sh, 3).
        features = gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
        covariances = gaussians.covariances
for i in range(B):
xyz_i = xyzs[i].float()
feature_i = features[i].float()
covar_i = covariances[i].float()
scale_i = scales[i].float()
rotation_i = rotations[i].float()
            opacity_i = opacities[i].squeeze().float()  # (G,)
test_w2c_i = extrinsics[i].float().inverse() # (V, 4, 4)
test_intr_i_normalized = intrinsics[i].float()
            # Denormalize the intrinsics (given in normalized image coordinates) into pixel units.
            test_intr_i = test_intr_i_normalized.clone()
            test_intr_i[:, 0] = test_intr_i_normalized[:, 0] * W
            test_intr_i[:, 1] = test_intr_i_normalized[:, 1] * H
            # Each channel carries (degree + 1)^2 SH coefficients, so degree = sqrt(d_sh) - 1.
            sh_degree = int(sqrt(feature_i.shape[-2])) - 1
rendering_list = []
rendering_depth_list = []
rendering_alpha_list = []
for j in range(V):
                # Rasterize a single view; "RGB+D" returns color and depth in one pass.
                rendering, alpha, _ = rasterization(
                    xyz_i, rotation_i, scale_i, opacity_i, feature_i,
                    test_w2c_i[j:j + 1], test_intr_i[j:j + 1], W, H,
                    sh_degree=sh_degree,
                    render_mode="RGB+D",
                    packed=False,
                    near_plane=1e-10,
                    backgrounds=self.background_color.unsqueeze(0),  # (1, 3)
                    radius_clip=0.1,
                    covars=covar_i,  # precomputed 3D covariances
                    rasterize_mode="classic",
                )  # rendering: (1, H, W, 4), alpha: (1, H, W, 1)
rendering_img, rendering_depth = torch.split(rendering, [3, 1], dim=-1)
rendering_img = rendering_img.clamp(0.0, 1.0)
rendering_list.append(rendering_img.permute(0, 3, 1, 2))
rendering_depth_list.append(rendering_depth)
rendering_alpha_list.append(alpha)
            # Squeeze only the trailing channel dim so the view dim survives when V == 1.
            rendered_depths.append(torch.cat(rendering_depth_list, dim=0).squeeze(-1))  # (V, H, W)
            rendered_imgs.append(torch.cat(rendering_list, dim=0))  # (V, 3, H, W)
            rendered_alphas.append(torch.cat(rendering_alpha_list, dim=0).squeeze(-1))  # (V, H, W)
        return DecoderOutput(
            torch.stack(rendered_imgs),
            torch.stack(rendered_depths),
            torch.stack(rendered_alphas),
            lod_rendering=None,
        )
def forward(
self,
gaussians: Gaussians,
extrinsics: Float[Tensor, "batch view 4 4"],
intrinsics: Float[Tensor, "batch view 3 3"],
near: Float[Tensor, "batch view"],
far: Float[Tensor, "batch view"],
image_shape: tuple[int, int],
depth_mode: DepthRenderingMode | None = None,
cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
) -> DecoderOutput:
        return self.rendering_fn(
            gaussians, extrinsics, intrinsics, near, far, image_shape,
            depth_mode, cam_rot_delta, cam_trans_delta,
        )
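if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): builds random Gaussians
    # and simple cameras, then renders them once. It assumes a CUDA device with gsplat
    # installed, and that the `Gaussians` dataclass exposes the fields used above
    # (means, covariances, harmonics, opacities, scales, rotations) with the shapes below;
    # adjust field names/shapes to the actual definition in ..types if they differ.
    device = torch.device("cuda")
    b, v, g, d_sh = 1, 2, 1024, 9  # 9 SH coefficients per channel -> degree 2
    quats = torch.nn.functional.normalize(torch.randn(b, g, 4, device=device), dim=-1)
    gaussians = Gaussians(
        means=torch.randn(b, g, 3, device=device),
        covariances=torch.eye(3, device=device).expand(b, g, 3, 3) * 1e-4,
        harmonics=torch.randn(b, g, 3, d_sh, device=device),
        opacities=torch.rand(b, g, device=device),
        scales=torch.full((b, g, 3), 1e-2, device=device),
        rotations=quats,
    )
    extrinsics = torch.eye(4, device=device).repeat(b, v, 1, 1)  # camera-to-world poses
    extrinsics[..., 2, 3] = -2.0  # move the cameras back so they look at the origin
    intrinsics = torch.tensor(
        [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], device=device
    ).repeat(b, v, 1, 1)  # normalized intrinsics; denormalized inside rendering_fn
    near = torch.full((b, v), 0.1, device=device)
    far = torch.full((b, v), 100.0, device=device)
    cfg = DecoderSplattingCUDACfg(
        name="splatting_cuda", background_color=[0.0, 0.0, 0.0], make_scale_invariant=False
    )
    decoder = DecoderSplattingCUDA(cfg).to(device)
    output = decoder(gaussians, extrinsics, intrinsics, near, far, (128, 128))
    # The stacked outputs are (B, V, 3, H, W) images, (B, V, H, W) depths, and
    # (B, V, H, W) alphas; the field names are defined by DecoderOutput in .decoder.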