# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import math

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_mean #, scatter_max

from .unet_3daware import setup_unet #UNetTriplane3dAware
from .conv_pointnet import ConvPointnet

from .pc_encoder import PVCNNEncoder #PointNet

import einops

from .dnnlib_util import ScopedTorchProfiler, printarr

def generate_plane_features(p, c, resolution, plane='xz'):
    """
    Args:
        p: (B,3,n_p)
        c: (B,C,n_p)
    """
    padding = 0.
    c_dim = c.size(1)
    # acquire indices of features in plane
    xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1)
    index = coordinate2index(xy, resolution)

    # scatter plane features from points
    fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2)
    fea_plane = scatter_mean(c, index, out=fea_plane)  # (B, C, reso^2)
    fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution)  # (B, C, reso, reso)
    return fea_plane
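
# A minimal usage sketch (hypothetical shapes): scatter 32-channel point
# features onto a 64x64 'xy' plane; shapes follow the docstring above.
#   p = torch.rand(2, 1024, 3) - 0.5                     # (B, n_p, 3) in [-0.5, 0.5]
#   c = torch.rand(2, 32, 1024)                          # (B, C, n_p)
#   plane = generate_plane_features(p, c, resolution=64, plane='xy')
#   assert plane.shape == (2, 32, 64, 64)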

def normalize_coordinate(p, padding=0.1, plane='xz'):
    ''' Normalize a coordinate to [0, 1] for unit-cube experiments.

    Args:
        p (tensor): (B, n_p, 3) points
        padding (float): conventional ONet padding parameter for the unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        plane (str): plane feature type, one of ['xz', 'xy', 'yz']
    '''
    if plane == 'xz':
        xy = p[:, :, [0, 2]]
    elif plane =='xy':
        xy = p[:, :, [0, 1]]
    else:
        xy = p[:, :, [1, 2]]

    xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
    xy_new = xy_new + 0.5 # range (0, 1)

    # if there are outliers out of the range
    if xy_new.max() >= 1:
        xy_new[xy_new >= 1] = 1 - 10e-6
    if xy_new.min() < 0:
        xy_new[xy_new < 0] = 0.0
    return xy_new
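
# Worked example (values for illustration only): with the default padding=0.1,
# x = -0.5 maps to 0.5 - 0.5/1.1 ~= 0.045 and x = 0.5 maps to ~0.955, so the
# padded cube lands strictly inside [0, 1] before the outlier clamp.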


def coordinate2index(x, resolution):
    ''' Convert normalized 2D coordinates in [0, 1] to flat plane indices.

    Args:
        x (tensor): (B, n_p, 2) coordinates normalized to [0, 1]
        resolution (int): plane resolution
    Returns:
        (B, 1, n_p) indices into the flattened resolution x resolution grid
    '''
    x = (x * resolution).long()
    index = x[:, :, 0] + resolution * x[:, :, 1]
    index = index[:, None, :]
    return index
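
# Example: at resolution=4, a normalized coordinate (0.3, 0.6) falls in cell
# (1, 2) and flattens to 1 + 4 * 2 = 9; these indices address the
# resolution**2 buckets that scatter_mean fills in generate_plane_features.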

def softclip(x, min, max, hardness=5):
    # Softly clip x to [min, max] (used for logsigma); softplus replaces the
    # hard clamp so gradients remain nonzero near the bounds.
    x = min + F.softplus(hardness*(x - min))/hardness
    x = max - F.softplus(-hardness*(x - max))/hardness
    return x
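
# Sketch of the behaviour (assumed bounds, default hardness=5): values well
# inside [min, max] pass through nearly unchanged, values outside are smoothly
# squashed toward the nearer bound rather than hard-clamped.
#   x = torch.linspace(-10, 10, 5)                       # [-10, -5, 0, 5, 10]
#   softclip(x, min=-3.0, max=3.0)                       # ~[-3, -3, 0, 3, 3]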


def sample_triplane_feat(feature_triplane, normalized_pos):
    '''
    Args:
        feature_triplane: (B, 3, C, H, W) stacked plane features
        normalized_pos: (B, n, 3) query positions in [-1, 1]
    Returns:
        (B, n, C) per-point features summed over the three planes
    '''
    tri_plane = torch.unbind(feature_triplane, dim=1)
    
    x_feat = F.grid_sample(
        tri_plane[0],
        torch.cat(
            [normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
            dim=-1).unsqueeze(dim=1), padding_mode='border',
        align_corners=True)
    y_feat = F.grid_sample(
        tri_plane[1],
        torch.cat(
            [normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
            dim=-1).unsqueeze(dim=1), padding_mode='border',
        align_corners=True)

    z_feat = F.grid_sample(
        tri_plane[2],
        torch.cat(
            [normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
            dim=-1).unsqueeze(dim=1), padding_mode='border',
        align_corners=True)
    final_feat = x_feat + y_feat + z_feat  # sum the three plane contributions
    final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1)  # (B, n, C)
    return final_feat
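
# A minimal usage sketch (hypothetical shapes): query a (B, 3, C, H, W)
# triplane at n points in [-1, 1]^3 and get one fused feature per point.
#   planes = torch.rand(2, 3, 32, 64, 64)
#   pos = torch.rand(2, 100, 3) * 2 - 1                  # (B, n, 3) in [-1, 1]
#   feat = sample_triplane_feat(planes, pos)             # (2, 100, 32)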


# @persistence.persistent_class
class TriPlanePC2Encoder(torch.nn.Module):
    # Encoder that encodes a point cloud into a triplane feature map, similar to ConvOccNet
    def __init__(
            self,
            cfg,
            device='cuda',
            shape_min=-1.0,
            shape_length=2.0,
            use_2d_feat=False,
            # point_encoder='pvcnn',
            # use_point_scatter=False
    ):
        """
        Outputs latent triplane from PC input
        Configs:
            max_logsigma: (float) Soft clip upper range for logsigm
            min_logsigma: (float)
            point_encoder_type: (str) one of ['pvcnn', 'pointnet']
            pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel 
                features (instead of scattering point features)
            unet_cfg: (dict)
            z_triplane_channels: (int) output latent triplane
            z_triplane_resolution: (int)
        Args:

        """
        # assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.device = device

        self.cfg = cfg

        self.shape_min = shape_min
        self.shape_length = shape_length

        self.z_triplane_resolution = cfg.z_triplane_resolution
        z_triplane_channels = cfg.z_triplane_channels

        point_encoder_out_dim = z_triplane_channels #* 2

        in_channels = 6
        # self.resample_filter=[1, 3, 3, 1]
        if cfg.point_encoder_type == 'pvcnn':
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim, device=self.device,
                in_channels=in_channels, use_2d_feat=use_2d_feat)  # encodes points into a 3D feature volume
        elif cfg.point_encoder_type == 'pointnet':
            # TODO the pointnet was buggy, investigate
            self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim, 
                                           dim=in_channels, hidden_dim=32, 
                                           plane_resolution=self.z_triplane_resolution, 
                                           padding=0)
        else:
            raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")

        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim, 
                input_channels=point_encoder_out_dim, 
                unet_cfg=cfg.unet_cfg)
        else:
            self.unet_encoder = None

    # @ScopedTorchProfiler('encode')
    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> torch.Tensor:
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length  # [0, 1]
        point_cloud_xyz = point_cloud_xyz - 0.5  # [-0.5, 0.5]
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                pc_feat, points_feat = self.pc_encoder(point_cloud)  # 3D feature volume: BxDx32x32x32
            if self.cfg.use_point_scatter:
                # Scattering from PVCNN point features
                points_feat_ = points_feat[0]
                # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
                pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_, 
                                                    resolution=self.z_triplane_resolution, plane='xy') 
                pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='yz') 
                pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='xz') 
                pc_feat = pc_feat[0]

            else:
                pc_feat = pc_feat[0]
                sf = self.z_triplane_resolution // 32  # 32 is PVCNN's voxel grid resolution

                pc_feat_1 = torch.mean(pc_feat, dim=-1)  # xy plane: average over the z axis
                pc_feat_2 = torch.mean(pc_feat, dim=-3)  # yz plane: average over the x axis
                pc_feat_3 = torch.mean(pc_feat, dim=-2)  # xz plane: average over the y axis

                # nearest upsample
                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter
            # Run ConvPointnet
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        if self.unet_encoder is not None:
            # TODO eval adding a skip connection
            # Unet expects B, 3, C, H, W
            pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
            # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
            pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)

        return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
        
    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
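
# A minimal end-to-end sketch (the cfg fields below are assumptions inferred
# from the docstring and this class's code, not from this repo's config files):
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       'point_encoder_type': 'pvcnn', 'use_point_scatter': True,
#       'z_triplane_channels': 32, 'z_triplane_resolution': 64,
#       'unet_cfg': {'enabled': False}})
#   enc = TriPlanePC2Encoder(cfg, device='cuda').to('cuda')
#   xyz = torch.rand(2, 2048, 3, device='cuda')          # raw point positions
#   rgb = torch.rand(2, 2048, 3, device='cuda')          # per-point features (in_channels = 3 + 3)
#   triplane = enc(xyz, point_cloud_feature=rgb)         # (2, 3, 32, 64, 64)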