File size: 6,531 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from jaxtyping import Float
from torch import Tensor, nn
from src.geometry.projection import get_world_rays
from src.misc.sh_rotation import rotate_sh
from .gaussians import build_covariance
from ...types import Gaussians
@dataclass
class GaussianAdapterCfg:
    """Configuration for the Gaussian adapters defined in this module.

    Only ``GaussianAdapter.forward`` reads the scale bounds; the subclasses in
    this module use a fixed softplus-based scale mapping instead.
    """

    # Lower end of the range that sigmoid-activated scale features are mapped into.
    gaussian_scale_min: float
    # Upper end of that range.
    gaussian_scale_max: float
    # Maximum spherical-harmonics degree; yields (sh_degree + 1) ** 2 coefficients
    # per colour channel.
    sh_degree: int
class GaussianAdapter(nn.Module):
    """Decode per-pixel network predictions into world-space 3D Gaussians.

    Each raw feature vector is laid out as
    ``[scale (3) | rotation quaternion (4) | SH coefficients (3 * d_sh)]``,
    which is ``d_in`` channels in total.
    """

    cfg: GaussianAdapterCfg

    def __init__(self, cfg: GaussianAdapterCfg):
        super().__init__()
        self.cfg = cfg

        # Per-coefficient SH damping mask. The DC term keeps weight 1, and every
        # coefficient of degree l >= 1 is multiplied by 0.1 * 0.25**l. At
        # initialization the colour is therefore dominated by the DC component,
        # and view-dependent terms start out small.
        sh_mask = torch.ones((self.d_sh,), dtype=torch.float32)
        for degree in range(1, cfg.sh_degree + 1):
            # Degree ``degree`` occupies indices [degree**2, (degree + 1)**2).
            start, stop = degree**2, (degree + 1) ** 2
            sh_mask[start:stop] = 0.1 * 0.25**degree
        self.register_buffer("sh_mask", sh_mask, persistent=False)

    def forward(
        self,
        extrinsics: Float[Tensor, "*#batch 4 4"],
        intrinsics: Float[Tensor, "*#batch 3 3"],
        coordinates: Float[Tensor, "*#batch 2"],
        depths: Float[Tensor, "*#batch"],
        opacities: Float[Tensor, "*#batch"],
        raw_gaussians: Float[Tensor, "*#batch _"],
        image_shape: tuple[int, int],
        eps: float = 1e-8,
    ) -> Gaussians:
        """Build world-space Gaussians from raw per-pixel features.

        Args:
            extrinsics: Camera matrices. The upper-left 3x3 block is used as the
                camera-to-world rotation for the covariances.
            intrinsics: Camera intrinsics. The upper-left 2x2 block feeds the
                scale multiplier, and the full matrix is passed to ``get_world_rays``.
            coordinates: Per-Gaussian image coordinates used to cast rays.
            depths: Per-Gaussian depth along the ray
                (``means = origins + directions * depths``). Also scales the
                Gaussian extent.
            opacities: Passed through unchanged. Their shape sets the batch
                shape that the SH tensor is broadcast to.
            raw_gaussians: ``d_in`` features per Gaussian (see the class docstring).
            image_shape: ``(height, width)`` of the image.
            eps: Added to the quaternion norm to avoid division by zero.

        Returns:
            A ``Gaussians`` record. Means and covariances are in world space.
            ``scales`` and ``rotations`` are still the camera-space values, and
            the SH coefficients are returned unrotated.
        """
        scale_raw, quat_raw, sh_raw = raw_gaussians.split(
            (3, 4, 3 * self.d_sh), dim=-1
        )

        # Squash the scale features into [gaussian_scale_min, gaussian_scale_max].
        lo, hi = self.cfg.gaussian_scale_min, self.cfg.gaussian_scale_max
        scales = lo + (hi - lo) * scale_raw.sigmoid()

        # Make scales proportional to depth and to the back-projected pixel footprint.
        height, width = image_shape
        pixel_size = 1 / torch.tensor(
            (width, height), dtype=torch.float32, device=extrinsics.device
        )
        multiplier = self.get_scale_multiplier(intrinsics, pixel_size)
        scales = scales * depths[..., None] * multiplier[..., None]

        # Normalize the quaternion features so they form a valid rotation.
        rotations = quat_raw / (quat_raw.norm(dim=-1, keepdim=True) + eps)

        sh = rearrange(sh_raw, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
        sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask

        # Rotate the camera-space covariances into world space: R @ Sigma @ R^T.
        c2w = extrinsics[..., :3, :3]
        covariances = c2w @ build_covariance(scales, rotations) @ c2w.transpose(-1, -2)

        # Place each mean along its world-space ray at the predicted depth.
        origins, directions = get_world_rays(coordinates, extrinsics, intrinsics)
        means = origins + directions * depths[..., None]

        # NOTE(review): the SH coefficients are not rotated into world space; the
        # disabled call was ``rotate_sh(sh, c2w[..., None, :, :])``. Confirm this
        # is intended.
        return Gaussians(
            means=means,
            covariances=covariances,
            harmonics=sh,
            opacities=opacities,
            # Note: these aren't yet rotated into world space, but they're only used
            # for exporting Gaussians to ply files. This needs to be fixed...
            scales=scales,
            rotations=rotations.broadcast_to((*scales.shape[:-1], 4)),
        )

    def get_scale_multiplier(
        self,
        intrinsics: Float[Tensor, "*#batch 3 3"],
        pixel_size: Float[Tensor, "*#batch 2"],
        multiplier: float = 0.1,
    ) -> Float[Tensor, " *batch"]:
        """Return the per-camera factor that ``forward`` multiplies depth-scaled scales by.

        Computes ``multiplier * inverse(K[:2, :2]) @ pixel_size`` and sums its
        x and y components, where ``K`` is ``intrinsics``.
        """
        inv_block = intrinsics[..., :2, :2].inverse()
        per_axis = multiplier * einsum(inv_block, pixel_size, "... i j, j -> ... i")
        return per_axis.sum(dim=-1)

    @property
    def d_sh(self) -> int:
        """Number of SH coefficients per colour channel, ``(sh_degree + 1) ** 2``."""
        bands = self.cfg.sh_degree + 1
        return bands**2

    @property
    def d_in(self) -> int:
        """Channels per Gaussian in ``raw_gaussians``: 3 scale + 4 quaternion + 3 * d_sh."""
        return 3 + 4 + 3 * self.d_sh
class UnifiedGaussianAdapter(GaussianAdapter):
    """Adapter for predictions that already carry explicit 3D means.

    No rays are cast and no camera transform is applied. Scales come from the
    fixed mapping ``min(0.001 * softplus(x), 0.3)`` instead of the configured
    scale range, and every output tensor is cast to float32.
    """

    def forward(
        self,
        means: Float[Tensor, "*#batch 3"],
        depths: Float[Tensor, "*#batch"],
        opacities: Float[Tensor, "*#batch"],
        raw_gaussians: Float[Tensor, "*#batch _"],
        eps: float = 1e-8,
        intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
        coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
    ) -> Gaussians:
        """Decode ``raw_gaussians`` around the given ``means``.

        ``depths``, ``intrinsics`` and ``coordinates`` are accepted but not used
        by this implementation. ``eps`` guards the quaternion normalization.
        """
        n_sh = self.d_sh
        raw_scales, raw_quats, raw_sh = raw_gaussians.split((3, 4, 3 * n_sh), dim=-1)

        # Small positive scales, capped at 0.3.
        scales = (0.001 * F.softplus(raw_scales)).clamp_max(0.3)
        # Normalize the quaternion features to yield a valid quaternion.
        rotations = raw_quats / (raw_quats.norm(dim=-1, keepdim=True) + eps)

        harmonics = rearrange(raw_sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
        harmonics = harmonics.broadcast_to((*opacities.shape, 3, n_sh)) * self.sh_mask

        covariances = build_covariance(scales, rotations)

        return Gaussians(
            means=means.float(),
            covariances=covariances.float(),
            harmonics=harmonics.float(),
            opacities=opacities.float(),
            scales=scales.float(),
            rotations=rotations.float(),
        )
class Unet3dGaussianAdapter(GaussianAdapter):
    """Adapter for 3D-UNet predictions that already carry explicit 3D means.

    The decode matches ``UnifiedGaussianAdapter`` (fixed scale mapping
    ``min(0.001 * softplus(x), 0.3)``, no camera transform), but output tensors
    keep whatever dtype the computation produced instead of being cast to float32.
    """

    def forward(
        self,
        means: Float[Tensor, "*#batch 3"],
        depths: Float[Tensor, "*#batch"],
        opacities: Float[Tensor, "*#batch"],
        raw_gaussians: Float[Tensor, "*#batch _"],
        eps: float = 1e-8,
        intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
        coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
    ) -> Gaussians:
        """Decode ``raw_gaussians`` around the given ``means``.

        ``depths``, ``intrinsics`` and ``coordinates`` are accepted but not used
        by this implementation. ``eps`` guards the quaternion normalization.
        """
        sh_count = self.d_sh
        scale_feat, quat_feat, sh_feat = raw_gaussians.split(
            (3, 4, 3 * sh_count), dim=-1
        )

        # Small positive scales, capped at 0.3.
        scale_vals = (0.001 * F.softplus(scale_feat)).clamp_max(0.3)
        # Unit quaternions; eps avoids dividing by a zero norm.
        quats = quat_feat / (quat_feat.norm(dim=-1, keepdim=True) + eps)

        sh_coeffs = rearrange(sh_feat, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
        sh_coeffs = sh_coeffs.broadcast_to((*opacities.shape, 3, sh_count)) * self.sh_mask

        return Gaussians(
            means=means,
            covariances=build_covariance(scale_vals, quats),
            harmonics=sh_coeffs,
            opacities=opacities,
            scales=scale_vals,
            rotations=quats,
        )
|