|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from src.utils.gaussian_model import build_covariance |
|
from simple_knn._C import distCUDA2 |
|
from src.utils.sh_utils import RGB2SH |
|
|
|
class GaussianHead(nn.Module): |
|
def __init__(self, d_pt_feat=64, **kwargs): |
|
super().__init__() |
|
|
|
self.args = kwargs |
|
self.d_means = 3 |
|
self.d_scales = 3 |
|
self.d_rotations = 4 |
|
self.d_opacities = 1 |
|
self.sh_degree = 3 |
|
self.d_view_dep_features = 3 |
|
self.d_sh = (self.sh_degree + 1) ** 2 |
|
self.d_attr = (self.d_scales + self.d_rotations + self.d_opacities + self.d_view_dep_features * self.d_sh) |
|
if self.args.get('d_gs_feats'): |
|
self.d_attr += self.args['d_gs_feats'] |
|
|
|
|
|
|
|
|
|
self.register_buffer( |
|
"sh_mask", |
|
torch.ones((self.d_sh,), dtype=torch.float32), |
|
persistent=False, |
|
) |
|
for degree in range(1, self.sh_degree + 1): |
|
self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.5 * 0.25**degree |
|
|
|
self.gaussian_proj = nn.Linear(d_pt_feat, self.d_attr) |
|
|
|
|
|
self.scale_activation = torch.exp |
|
self.rotation_activation = torch.nn.functional.normalize |
|
self.opacity_activation = torch.sigmoid |
|
|
|
def forward(self, point_transformer_output, lseg_features=None): |
|
pred1 = {} |
|
pred2 = {} |
|
|
|
scene_scale = point_transformer_output['scale'] |
|
scene_center = point_transformer_output['center'] |
|
B, H, W, _ = point_transformer_output['shape'] |
|
normalized_means = point_transformer_output['coord'] |
|
colors = point_transformer_output['color'] |
|
|
|
|
|
normalized_means = rearrange(normalized_means, '(b v h w) c -> v b (h w) c', v=2, b=B, h=H, w=W) |
|
means = normalized_means * scene_scale + scene_center |
|
means = rearrange(means, 'v b (h w) c -> b (v h w) c', b=B, v=2, h=H, w=W) |
|
|
|
|
|
feat = point_transformer_output['feat'] |
|
gaussian_attr = self.gaussian_proj(feat) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scales, rotations, opacities, sh_coeffs, gs_feats = torch.split(gaussian_attr, |
|
[ |
|
self.d_scales, |
|
self.d_rotations, |
|
self.d_opacities, |
|
self.d_view_dep_features * self.d_sh, |
|
self.args['d_gs_feats'] |
|
], |
|
dim=-1) |
|
|
|
|
|
|
|
all_dist = torch.stack([torch.sqrt(torch.clamp_min(distCUDA2(pts3d), 0.0000001)) for pts3d in means]) |
|
median_dist = all_dist.median(dim=-1)[0][:, None, None] |
|
scales = self.scale_activation(scales) |
|
scales = rearrange(scales, '(b v h w) c -> b (v h w) c', b=B, v=2, h=H, w=W) |
|
scales = scales * all_dist[..., None] |
|
|
|
scales = torch.clamp(scales, min=0.1 * median_dist, max=3.0 * median_dist) |
|
scales = rearrange(scales, 'b (v h w) c -> (b v h w) c', b=B, v=2, h=H, w=W) |
|
|
|
|
|
rotations = self.rotation_activation(rotations) |
|
opacities = self.opacity_activation(opacities) |
|
|
|
|
|
covs = build_covariance(scales, rotations) |
|
|
|
|
|
sh_coeffs = rearrange(sh_coeffs, '(b v h w) (c d) -> (b v h w) c d', b=B, v=2, h=H, w=W, c=self.d_sh, d=self.d_view_dep_features) |
|
sh_dc = sh_coeffs[..., 0, :] |
|
sh_rest = sh_coeffs[..., 1:, :] |
|
if self.args.get('rgb_residual'): |
|
|
|
colors = colors * 0.5 + 0.5 |
|
sh_rgb = RGB2SH(colors) |
|
|
|
sh_dc = sh_dc + sh_rgb |
|
|
|
sh_coeffs = torch.cat([sh_dc[..., None, :], sh_rest], dim=-2) |
|
sh_coeffs = sh_coeffs * self.sh_mask[None, :, None] |
|
|
|
|
|
lseg_features = rearrange(lseg_features, '(v b) c h w -> (b v h w) c', b=B, v=2, h=H, w=W) |
|
gs_feats = gs_feats + lseg_features |
|
|
|
|
|
scales = rearrange(scales, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
rotations = rearrange(rotations, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
opacities = rearrange(opacities, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
sh_coeffs = rearrange(sh_coeffs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
covs = rearrange(covs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
means = rearrange(means, 'b (v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
gs_feats = rearrange(gs_feats, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W) |
|
|
|
pred1['scales'] = scales[0] |
|
pred1['rotations'] = rotations[0] |
|
pred1['covs'] = covs[0] |
|
pred1['opacities'] = opacities[0] |
|
pred1['sh_coeffs'] = sh_coeffs[0] |
|
pred1['means'] = means[0] |
|
pred1['gs_feats'] = gs_feats[0] |
|
|
|
pred2['scales'] = scales[1] |
|
pred2['rotations'] = rotations[1] |
|
pred2['covs'] = covs[1] |
|
pred2['opacities'] = opacities[1] |
|
pred2['sh_coeffs'] = sh_coeffs[1] |
|
pred2['means'] = means[1] |
|
pred2['gs_feats'] = gs_feats[1] |
|
|
|
return pred1, pred2 |
|
|