import torch
import torch.nn as nn
from einops import rearrange
from src.utils.gaussian_model import build_covariance
from simple_knn._C import distCUDA2
from src.utils.sh_utils import RGB2SH


class GaussianHead(nn.Module):
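    """Decode per-point features into 3D Gaussian attributes (means, scales,
    rotations, opacities, SH color coefficients, and optional per-Gaussian
    feature vectors) for a two-view input."""
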
    def __init__(self, d_pt_feat=64, **kwargs):
        super().__init__()
        # args
        self.args = kwargs
        self.d_means = 3
        self.d_scales = 3
        self.d_rotations = 4
        self.d_opacities = 1
        self.sh_degree = 3
        self.d_view_dep_features = 3  # RGB
        self.d_sh = (self.sh_degree + 1) ** 2
        self.d_attr = (
            self.d_scales
            + self.d_rotations
            + self.d_opacities
            + self.d_view_dep_features * self.d_sh
        )
        # optional extra per-Gaussian feature channels (used for the LSeg
        # feature residual in forward)
        if self.args.get('d_gs_feats'):
            self.d_attr += self.args['d_gs_feats']
        # Create a mask for the spherical harmonics coefficients.
        # This ensures that at initialization, the coefficients are biased
        # towards having a large DC component and small view-dependent components.
        self.register_buffer(
            "sh_mask",
            torch.ones((self.d_sh,), dtype=torch.float32),
            persistent=False,
        )
        # For sh_degree=3 the per-degree factors are
        # 1.0 (DC), 0.125 (deg. 1), 0.03125 (deg. 2), 0.0078125 (deg. 3).
        for degree in range(1, self.sh_degree + 1):
            self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.5 * 0.25**degree
        self.gaussian_proj = nn.Linear(d_pt_feat, self.d_attr)
        # Activation functions
        self.scale_activation = torch.exp
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid

    def forward(self, point_transformer_output, lseg_features=None):
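        """Convert point-transformer outputs into per-view Gaussian parameter dicts.

        Args:
            point_transformer_output: dict with 'scale' (B, 1, 1), 'center' (B, 1, 3),
                'shape' (B, H, W, C), and flattened per-point 'coord', 'color', 'feat'
                of shape (B * V * H * W, ...), with V = 2 views.
            lseg_features: optional (V * B, C, H, W) feature maps added as a
                residual to the predicted Gaussian features.

        Returns:
            (pred1, pred2): one dict of Gaussian attributes per view.
        """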
        pred1 = {}
        pred2 = {}
        scene_scale = point_transformer_output['scale']    # B, 1, 1
        scene_center = point_transformer_output['center']  # B, 1, 3
        B, H, W, _ = point_transformer_output['shape']
        normalized_means = point_transformer_output['coord']  # B * V * H * W, 3
        colors = point_transformer_output['color']            # B * V * H * W, 3
        # split normalized_means into the 2 views
        normalized_means = rearrange(normalized_means, '(b v h w) c -> v b (h w) c', v=2, b=B, h=H, w=W)
        # undo the scene normalization to recover world-space means
        means = normalized_means * scene_scale + scene_center  # V, B, H * W, 3
        means = rearrange(means, 'v b (h w) c -> b (v h w) c', b=B, v=2, h=H, w=W)
        # project point features to Gaussian attributes
        feat = point_transformer_output['feat']
        gaussian_attr = self.gaussian_proj(feat)
        # split the projected vector into the individual attributes; the
        # gs_feats slice is empty when 'd_gs_feats' is not configured
        scales, rotations, opacities, sh_coeffs, gs_feats = torch.split(
            gaussian_attr,
            [
                self.d_scales,
                self.d_rotations,
                self.d_opacities,
                self.d_view_dep_features * self.d_sh,
                self.args.get('d_gs_feats', 0),
            ],
            dim=-1,
        )
        # scales: use the distance between each point and its nearest
        # neighbors as a per-point reference for the predicted scale
        all_dist = torch.stack([
            torch.sqrt(torch.clamp_min(distCUDA2(pts3d), 1e-7))
            for pts3d in means
        ])  # B, V * H * W
        median_dist = all_dist.median(dim=-1)[0][:, None, None]  # B, 1, 1
        scales = self.scale_activation(scales)
        scales = rearrange(scales, '(b v h w) c -> b (v h w) c', b=B, v=2, h=H, w=W)
        scales = scales * all_dist[..., None]
        # clip scales to a range around the median nearest-neighbor distance
        scales = torch.clamp(scales, min=0.1 * median_dist, max=3.0 * median_dist)
        scales = rearrange(scales, 'b (v h w) c -> (b v h w) c', b=B, v=2, h=H, w=W)
        # remaining activations
        rotations = self.rotation_activation(rotations)
        opacities = self.opacity_activation(opacities)
        # build covariance matrices from scales and rotations
        covs = build_covariance(scales, rotations)
        # spherical harmonics: separate the DC term from the view-dependent rest
        sh_coeffs = rearrange(sh_coeffs, '(b v h w) (c d) -> (b v h w) c d', b=B, v=2, h=H, w=W, c=self.d_sh, d=self.d_view_dep_features)
        sh_dc = sh_coeffs[..., 0, :]
        sh_rest = sh_coeffs[..., 1:, :]
        if self.args.get('rgb_residual'):
            # denormalize colors from [-1, 1] back to [0, 1]
            colors = colors * 0.5 + 0.5
            sh_rgb = RGB2SH(colors)  # (B * V * H * W, 3)
            # treat the predicted DC component as a residual on the input colors
            sh_dc = sh_dc + sh_rgb
        # concatenate dc and rest, then apply the initialization-bias mask
        sh_coeffs = torch.cat([sh_dc[..., None, :], sh_rest], dim=-2)
        sh_coeffs = sh_coeffs * self.sh_mask[None, :, None]
        # lseg_features: learn the Gaussian features as a residual on LSeg features
        if lseg_features is not None:
            lseg_features = rearrange(lseg_features, '(v b) c h w -> (b v h w) c', b=B, v=2, h=H, w=W)
            gs_feats = gs_feats + lseg_features
        # split all attributes into the 2 views
        scales = rearrange(scales, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        rotations = rearrange(rotations, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        opacities = rearrange(opacities, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        sh_coeffs = rearrange(sh_coeffs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        covs = rearrange(covs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        means = rearrange(means, 'b (v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        gs_feats = rearrange(gs_feats, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        pred1['scales'] = scales[0]
        pred1['rotations'] = rotations[0]
        pred1['covs'] = covs[0]
        pred1['opacities'] = opacities[0]
        pred1['sh_coeffs'] = sh_coeffs[0]
        pred1['means'] = means[0]
        pred1['gs_feats'] = gs_feats[0]
        pred2['scales'] = scales[1]
        pred2['rotations'] = rotations[1]
        pred2['covs'] = covs[1]
        pred2['opacities'] = opacities[1]
        pred2['sh_coeffs'] = sh_coeffs[1]
        pred2['means'] = means[1]
        pred2['gs_feats'] = gs_feats[1]
        return pred1, pred2
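

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes a
    # CUDA device (distCUDA2 requires one) and the repo's build_covariance and
    # RGB2SH utilities being importable. All sizes below are illustrative; the
    # input layout (flattened B * V * H * W points, V = 2 views) follows the
    # shapes expected by forward().
    B, V, H, W, d_pt_feat, d_gs_feats = 1, 2, 8, 8, 64, 16
    N = B * V * H * W
    device = 'cuda'
    head = GaussianHead(d_pt_feat=d_pt_feat, d_gs_feats=d_gs_feats, rgb_residual=True).to(device)
    point_transformer_output = {
        'scale': torch.ones(B, 1, 1, device=device),
        'center': torch.zeros(B, 1, 3, device=device),
        'shape': (B, H, W, 3),
        'coord': torch.randn(N, 3, device=device),
        'color': torch.rand(N, 3, device=device) * 2 - 1,  # colors normalized to [-1, 1]
        'feat': torch.randn(N, d_pt_feat, device=device),
    }
    lseg_features = torch.randn(V * B, d_gs_feats, H, W, device=device)
    pred1, pred2 = head(point_transformer_output, lseg_features)
    print({k: tuple(v.shape) for k, v in pred1.items()})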