# LSM / src/gaussian_head.py
# Last commit 57746f1 by kairunwen: "Update Code"
import torch
import torch.nn as nn
from einops import rearrange
from src.utils.gaussian_model import build_covariance
from simple_knn._C import distCUDA2
from src.utils.sh_utils import RGB2SH
class GaussianHead(nn.Module):
    """Linear head that turns per-point features into 3D Gaussian parameters.

    A single ``nn.Linear`` maps each point feature (width ``d_pt_feat``) to
    scales, rotations, opacities, spherical-harmonic colour coefficients and
    an optional extra feature vector. ``forward`` always handles exactly two
    views per batch element.

    Keyword options read from ``kwargs``:
        d_gs_feats: width of the extra per-Gaussian feature vector
            (treated as 0 when absent or falsy).
        rgb_residual: if truthy, the input point colours are converted to SH
            and added to the predicted DC coefficient.
    """

    def __init__(self, d_pt_feat=64, **kwargs):
        super().__init__()
        # args
        self.args = kwargs
        self.d_means = 3
        self.d_scales = 3
        self.d_rotations = 4
        self.d_opacities = 1
        self.sh_degree = 3
        self.d_view_dep_features = 3  # RGB
        self.d_sh = (self.sh_degree + 1) ** 2
        # Width of the optional extra feature vector; 0 means "not requested",
        # which keeps __init__ and forward consistent when the option is absent.
        self.d_gs_feats = self.args.get('d_gs_feats') or 0
        self.d_attr = (
            self.d_scales
            + self.d_rotations
            + self.d_opacities
            + self.d_view_dep_features * self.d_sh
            + self.d_gs_feats
        )
        # Per-coefficient mask for the spherical harmonics. It is multiplied
        # into the predicted coefficients on every forward pass: the DC term
        # keeps weight 1 and degree-l terms are damped by 0.5 * 0.25**l, biasing
        # the output towards a dominant DC component.
        self.register_buffer(
            "sh_mask",
            torch.ones((self.d_sh,), dtype=torch.float32),
            persistent=False,
        )
        for degree in range(1, self.sh_degree + 1):
            self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.5 * 0.25**degree
        self.gaussian_proj = nn.Linear(d_pt_feat, self.d_attr)
        # Activation functions
        self.scale_activation = torch.exp
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid

    def _split_attributes(self, gaussian_attr):
        """Split projected features along the last dim into
        (scales, rotations, opacities, sh_coeffs, gs_feats).

        ``gs_feats`` has width ``self.d_gs_feats`` and is an empty (..., 0)
        tensor when no extra features were configured.
        """
        return torch.split(
            gaussian_attr,
            [
                self.d_scales,
                self.d_rotations,
                self.d_opacities,
                self.d_view_dep_features * self.d_sh,
                self.d_gs_feats,
            ],
            dim=-1,
        )

    def forward(self, point_transformer_output, lseg_features=None):
        """Predict Gaussian parameters for the two views of each batch element.

        Args:
            point_transformer_output: dict read for the keys
                'scale' (B, 1, 1), 'center' (B, 1, 3), 'shape' (unpacked as
                B, H, W, _), 'coord' (B * V * H * W, 3, batch-major then view),
                'feat' (B * V * H * W, d_pt_feat) and, only when the
                ``rgb_residual`` option is set, 'color' (B * V * H * W, 3).
            lseg_features: optional tensor laid out as (V * B, C, H, W)
                (view-major). When given, it is added as a residual to the
                predicted ``gs_feats``, so C must equal ``d_gs_feats``.

        Returns:
            (pred1, pred2): one dict per view with keys 'scales', 'rotations',
            'covs', 'opacities', 'sh_coeffs', 'means', 'gs_feats', each with
            leading dims (B, H, W).
        """
        B, H, W, _ = point_transformer_output['shape']
        V = 2  # this head always splits its input into exactly two views

        def to_views(x):
            # (B * V * H * W, ...) -> (V, B, H, W, ...)
            return rearrange(x, '(b v h w) ... -> v b h w ...', v=V, b=B, h=H, w=W)

        # Means: undo the per-scene normalisation of the point coordinates.
        scene_scale = point_transformer_output['scale']  # B, 1, 1
        scene_center = point_transformer_output['center']  # B, 1, 3
        normalized_means = rearrange(
            point_transformer_output['coord'], '(b v h w) c -> v b (h w) c', v=V, b=B, h=H, w=W
        )
        means = normalized_means * scene_scale + scene_center  # V, B, H * W, 3
        means = rearrange(means, 'v b (h w) c -> b (v h w) c', b=B, v=V, h=H, w=W)

        gaussian_attr = self.gaussian_proj(point_transformer_output['feat'])
        scales, rotations, opacities, sh_coeffs, gs_feats = self._split_attributes(gaussian_attr)

        # Scales: exp-activated raw scales are multiplied by each point's
        # neighbour distance from distCUDA2 (square-rooted here, so presumably
        # a squared distance), then clamped to [0.1, 3.0] x the per-batch median.
        all_dist = torch.stack(
            [torch.sqrt(torch.clamp_min(distCUDA2(pts3d), 0.0000001)) for pts3d in means]
        )  # B, V * H * W
        median_dist = all_dist.median(dim=-1)[0][:, None, None]  # B, 1, 1
        scales = rearrange(
            self.scale_activation(scales), '(b v h w) c -> b (v h w) c', b=B, v=V, h=H, w=W
        )
        scales = scales * all_dist[..., None]
        scales = torch.clamp(scales, min=0.1 * median_dist, max=3.0 * median_dist)
        scales = rearrange(scales, 'b (v h w) c -> (b v h w) c', b=B, v=V, h=H, w=W)

        rotations = self.rotation_activation(rotations)
        opacities = self.opacity_activation(opacities)
        covs = build_covariance(scales, rotations)

        # SH coefficients: (N, d_sh * 3) -> (N, d_sh, 3), coefficient index major.
        sh_coeffs = rearrange(
            sh_coeffs,
            '(b v h w) (c d) -> (b v h w) c d',
            b=B, v=V, h=H, w=W, c=self.d_sh, d=self.d_view_dep_features,
        )
        sh_dc = sh_coeffs[..., 0, :]
        sh_rest = sh_coeffs[..., 1:, :]
        if self.args.get('rgb_residual'):
            # Denormalise colours (assumed to be in [-1, 1]; TODO confirm) and
            # add their SH DC equivalent, so the head predicts a colour residual.
            colors = point_transformer_output['color'] * 0.5 + 0.5
            sh_dc = sh_dc + RGB2SH(colors)  # (B * V * H * W, 3)
        sh_coeffs = torch.cat([sh_dc[..., None, :], sh_rest], dim=-2)
        sh_coeffs = sh_coeffs * self.sh_mask[None, :, None]

        # Optional LSeg residual on the extra features; skipped when not given
        # (previously the None default crashed inside rearrange).
        if lseg_features is not None:
            lseg_features = rearrange(
                lseg_features, '(v b) c h w -> (b v h w) c', b=B, v=V, h=H, w=W
            )
            gs_feats = gs_feats + lseg_features

        # Split every attribute into (V, B, H, W, ...); key order is part of
        # the returned dicts' iteration order, so keep it stable.
        per_view = {
            'scales': to_views(scales),
            'rotations': to_views(rotations),
            'covs': to_views(covs),
            'opacities': to_views(opacities),
            'sh_coeffs': to_views(sh_coeffs),
            'means': rearrange(means, 'b (v h w) ... -> v b h w ...', v=V, b=B, h=H, w=W),
            'gs_feats': to_views(gs_feats),
        }
        pred1 = {key: value[0] for key, value in per_view.items()}
        pred2 = {key: value[1] for key, value in per_view.items()}
        return pred1, pred2