import torch
import torch.nn as nn
from einops import rearrange
from src.utils.gaussian_model import build_covariance
from simple_knn._C import distCUDA2
from src.utils.sh_utils import RGB2SH

class GaussianHead(nn.Module):
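    """Per-pixel 3D gaussian prediction head.

    Maps point transformer features to gaussian attributes (means, scales,
    rotations, opacities, SH coefficients, and optional per-gaussian feature
    vectors) for a two-view input.
    """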
    def __init__(self, d_pt_feat=64, **kwargs):
        super().__init__()
        # extra config (e.g. d_gs_feats, rgb_residual)
        self.args = kwargs
        self.d_means = 3
        self.d_scales = 3
        self.d_rotations = 4
        self.d_opacities = 1
        self.sh_degree = 3
        self.d_view_dep_features = 3 # RGB
        self.d_sh = (self.sh_degree + 1) ** 2
        self.d_attr = (self.d_scales + self.d_rotations + self.d_opacities + self.d_view_dep_features * self.d_sh)
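        # with sh_degree = 3: 3 + 4 + 1 + 3 * 16 = 56 dims per point, before optional gs_feats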
        if self.args.get('d_gs_feats'):
            self.d_attr += self.args['d_gs_feats']

        # Mask for the spherical harmonics coefficients that down-weights
        # higher-degree terms, biasing predictions towards a large DC component
        # and small view-dependent components.
        self.register_buffer(
            "sh_mask",
            torch.ones((self.d_sh,), dtype=torch.float32),
            persistent=False,
        )
        for degree in range(1, self.sh_degree + 1):
            self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.5 * 0.25**degree
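        # resulting per-degree weights: 1.0 (DC), 0.125, 0.03125, 0.0078125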
        
        self.gaussian_proj = nn.Linear(d_pt_feat, self.d_attr)
        
        # Activations: exp keeps scales positive, normalize yields unit
        # quaternions, sigmoid maps opacities to (0, 1)
        self.scale_activation = torch.exp
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid

    def forward(self, point_transformer_output, lseg_features=None):
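        """Predict gaussian attributes from point transformer output.

        Returns a pair of dicts (pred1, pred2), one per input view, each with
        keys: scales, rotations, covs, opacities, sh_coeffs, means, gs_feats.
        """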
        pred1 = {}
        pred2 = {}

        scene_scale = point_transformer_output['scale'] # B, 1, 1
        scene_center = point_transformer_output['center'] # B, 1, 3
        B, H, W, _ = point_transformer_output['shape']
        normalized_means = point_transformer_output['coord'] # B * V * H * W, 3
        colors = point_transformer_output['color'] # B * V * H * W, 3

        # split normalized means into the two views and denormalize to world coordinates
        normalized_means = rearrange(normalized_means, '(b v h w) c -> v b (h w) c', v=2, b=B, h=H, w=W)
        means = normalized_means * scene_scale + scene_center # V, B, H * W, 3
        means = rearrange(means, 'v b (h w) c -> b (v h w) c', b=B, v=2, h=H, w=W)

        # project per-point features to raw gaussian attributes
        feat = point_transformer_output['feat']
        gaussian_attr = self.gaussian_proj(feat)
        
        # split the projected attributes into their components
        scales, rotations, opacities, sh_coeffs, gs_feats = torch.split(
            gaussian_attr,
            [
                self.d_scales,
                self.d_rotations,
                self.d_opacities,
                self.d_view_dep_features * self.d_sh,
                self.args.get('d_gs_feats', 0),  # zero-width split if no feature head is configured
            ],
            dim=-1,
        )

        # scales: anchor each gaussian's scale to the distance between its
        # point and that point's nearest neighbor
        all_dist = torch.stack([torch.sqrt(torch.clamp_min(distCUDA2(pts3d), 1e-7)) for pts3d in means]) # B, V * H * W
        median_dist = all_dist.median(dim=-1)[0][:, None, None] # B, 1, 1
        scales = self.scale_activation(scales)
        scales = rearrange(scales, '(b v h w) c -> b (v h w) c', b=B, v=2, h=H, w=W)
        scales = scales * all_dist[..., None]
        # clamp scales to [0.1, 3] x median distance to avoid degenerate or oversized gaussians
        scales = torch.clamp(scales, min=0.1 * median_dist, max=3.0 * median_dist)
        scales = rearrange(scales, 'b (v h w) c -> (b v h w) c', b=B, v=2, h=H, w=W)
        
        # normalize quaternions and squash opacities to (0, 1)
        rotations = self.rotation_activation(rotations)
        opacities = self.opacity_activation(opacities)
        
        # build covariance matrix
        covs = build_covariance(scales, rotations)
        
        # reshape SH coefficients to (N, d_sh, 3) and split DC / higher-degree terms
        sh_coeffs = rearrange(sh_coeffs, '(b v h w) (c d) -> (b v h w) c d', b=B, v=2, h=H, w=W, c=self.d_sh, d=self.d_view_dep_features)
        sh_dc = sh_coeffs[..., 0, :]
        sh_rest = sh_coeffs[..., 1:, :]
        if self.args.get('rgb_residual'):
            # denormalize colors from [-1, 1] to [0, 1]
            colors = colors * 0.5 + 0.5
            sh_rgb = RGB2SH(colors) # (B * V * H * W, 3)
            # add rgb residual to dc component
            sh_dc = sh_dc + sh_rgb
            # concatenate dc and rest
            sh_coeffs = torch.cat([sh_dc[..., None, :], sh_rest], dim=-2)
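        # down-weight higher-degree SH coefficients (see sh_mask in __init__)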
        sh_coeffs = sh_coeffs * self.sh_mask[None, :, None]
        
        # lseg features (learned residual): add them to the predicted gaussian features
        if lseg_features is not None:
            lseg_features = rearrange(lseg_features, '(v b) c h w -> (b v h w) c', b=B, v=2, h=H, w=W)
            gs_feats = gs_feats + lseg_features

        # reshape to (view, batch, H, W, ...) and split into the two views
        scales = rearrange(scales, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        rotations = rearrange(rotations, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        opacities = rearrange(opacities, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        sh_coeffs = rearrange(sh_coeffs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        covs = rearrange(covs, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        means = rearrange(means, 'b (v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)
        gs_feats = rearrange(gs_feats, '(b v h w) ... -> v b h w ...', v=2, b=B, h=H, w=W)

        pred1['scales'] = scales[0]
        pred1['rotations'] = rotations[0]
        pred1['covs'] = covs[0]
        pred1['opacities'] = opacities[0]
        pred1['sh_coeffs'] = sh_coeffs[0]
        pred1['means'] = means[0]
        pred1['gs_feats'] = gs_feats[0]

        pred2['scales'] = scales[1]
        pred2['rotations'] = rotations[1]
        pred2['covs'] = covs[1]
        pred2['opacities'] = opacities[1]
        pred2['sh_coeffs'] = sh_coeffs[1]
        pred2['means'] = means[1]
        pred2['gs_feats'] = gs_feats[1]
        
        return pred1, pred2
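

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming a CUDA device, the compiled
    # simple_knn extension, and LSeg features whose channel width matches
    # d_gs_feats. Input shapes are inferred from the comments in forward().
    B, V, H, W, C = 2, 2, 8, 8, 64
    head = GaussianHead(d_pt_feat=C, d_gs_feats=16, rgb_residual=True).cuda()
    point_transformer_output = {
        'scale': torch.rand(B, 1, 1, device='cuda'),
        'center': torch.randn(B, 1, 3, device='cuda'),
        'shape': (B, H, W, 3),
        'coord': torch.randn(B * V * H * W, 3, device='cuda'),
        'color': torch.rand(B * V * H * W, 3, device='cuda') * 2 - 1,  # in [-1, 1]
        'feat': torch.randn(B * V * H * W, C, device='cuda'),
    }
    lseg_features = torch.randn(V * B, 16, H, W, device='cuda')
    pred1, pred2 = head(point_transformer_output, lseg_features)
    print(pred1['means'].shape)      # (B, H, W, 3)
    print(pred1['sh_coeffs'].shape)  # (B, H, W, 16, 3)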