# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import math

import einops
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_mean  # , scatter_max

from .unet_3daware import setup_unet  # UNetTriplane3dAware
from .conv_pointnet import ConvPointnet
from .pc_encoder import PVCNNEncoder  # PointNet
from .dnnlib_util import ScopedTorchProfiler, printarr


def generate_plane_features(p, c, resolution, plane='xz'):
    """Scatter per-point features onto an axis-aligned feature plane.

    Args:
        p: (B, n_p, 3) point coordinates in [-0.5, 0.5]
        c: (B, C, n_p) per-point features
        resolution (int): output plane resolution
        plane (str): plane feature type, one of ['xz', 'xy', 'yz']
    Returns:
        (B, C, resolution, resolution) plane features, mean-pooled per cell
    """
    padding = 0.
    c_dim = c.size(1)

    # acquire indices of features in plane
    xy = normalize_coordinate(p.clone(), plane=plane, padding=padding)  # normalize to the range of (0, 1)
    index = coordinate2index(xy, resolution)

    # scatter plane features from points
    fea_plane = c.new_zeros(p.size(0), c_dim, resolution ** 2)
    fea_plane = scatter_mean(c, index, out=fea_plane)  # B x C x reso^2
    fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution)  # sparse matrix (B x C x reso x reso)
    return fea_plane


def normalize_coordinate(p, padding=0.1, plane='xz'):
    """Normalize a coordinate to [0, 1] for unit-cube experiments.

    Args:
        p (tensor): (B, n_p, 3) points
        padding (float): conventional padding parameter of ONet for the unit
            cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        plane (str): plane feature type, one of ['xz', 'xy', 'yz']
    """
    if plane == 'xz':
        xy = p[:, :, [0, 2]]
    elif plane == 'xy':
        xy = p[:, :, [0, 1]]
    else:
        xy = p[:, :, [1, 2]]

    xy_new = xy / (1 + padding + 10e-6)  # (-0.5, 0.5)
    xy_new = xy_new + 0.5  # range (0, 1)

    # clamp outliers that fall outside the range
    if xy_new.max() >= 1:
        xy_new[xy_new >= 1] = 1 - 10e-6
    if xy_new.min() < 0:
        xy_new[xy_new < 0] = 0.0
    return xy_new
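
# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original pipeline): a
# minimal, manually invoked sanity check of the scatter path above. Shapes
# follow the docstrings in this module; `_demo_plane_scatter` is a
# hypothetical helper and is never called at import time.
def _demo_plane_scatter():
    B, n_p, C, reso = 2, 128, 8, 4
    p = torch.rand(B, n_p, 3) - 0.5  # points assumed to lie in [-0.5, 0.5]^3
    c = torch.randn(B, C, n_p)       # one feature vector per point
    fea_plane = generate_plane_features(p, c, resolution=reso, plane='xz')
    # Each of the reso^2 cells holds the mean feature of the points that
    # landed in it (zeros where no point scattered).
    assert fea_plane.shape == (B, C, reso, reso)
    return fea_plane
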
def coordinate2index(x, resolution):
    """Convert normalized [0, 1] plane coordinates to flat cell indices.

    Args:
        x (tensor): (B, n_p, 2) plane coordinates, normalized to [0, 1]
        resolution (int): defined plane resolution
    Returns:
        (B, 1, n_p) flat indices into the resolution^2 plane cells
    """
    x = (x * resolution).long()
    index = x[:, :, 0] + resolution * x[:, :, 1]
    index = index[:, None, :]
    return index


def softclip(x, min, max, hardness=5):
    """Softly clip x to [min, max]; used for the logsigma outputs."""
    # Note: `min`/`max` intentionally shadow the builtins to match callers.
    x = min + F.softplus(hardness * (x - min)) / hardness
    x = max - F.softplus(-hardness * (x - max)) / hardness
    return x


def sample_triplane_feat(feature_triplane, normalized_pos):
    """Bilinearly sample triplane features at 3D query positions.

    Args:
        feature_triplane: (B, 3, C, H, W) stacked xy/yz/xz plane features
        normalized_pos: (B, n_q, 3) query positions in [-1, 1]
    Returns:
        (B, n_q, C) per-query features, summed over the three planes
    """
    tri_plane = torch.unbind(feature_triplane, dim=1)

    x_feat = F.grid_sample(
        tri_plane[0],
        torch.cat([normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
                  dim=-1).unsqueeze(dim=1),
        padding_mode='border', align_corners=True)
    y_feat = F.grid_sample(
        tri_plane[1],
        torch.cat([normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
                  dim=-1).unsqueeze(dim=1),
        padding_mode='border', align_corners=True)
    z_feat = F.grid_sample(
        tri_plane[2],
        torch.cat([normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
                  dim=-1).unsqueeze(dim=1),
        padding_mode='border', align_corners=True)
    final_feat = x_feat + y_feat + z_feat
    final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1)  # (B, n_q, C)
    return final_feat
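
# ---------------------------------------------------------------------------
# Illustrative sketch (an addition): querying a random triplane with
# `sample_triplane_feat`. The (B, 3, C, H, W) layout and the [-1, 1] query
# convention are taken from the function above; `_demo_triplane_query` is a
# hypothetical helper, not called anywhere in this module.
def _demo_triplane_query():
    B, C, H, n_q = 2, 32, 64, 500
    feature_triplane = torch.randn(B, 3, C, H, H)  # xy/yz/xz planes
    query_pos = torch.rand(B, n_q, 3) * 2.0 - 1.0  # positions in [-1, 1]^3
    feat = sample_triplane_feat(feature_triplane, query_pos)
    # One C-dim feature per query point: the sum of bilinear samples
    # from the three axis-aligned planes.
    assert feat.shape == (B, n_q, C)
    return feat
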
# @persistence.persistent_class
class TriPlanePC2Encoder(torch.nn.Module):
    """Encodes a point cloud into a triplane feature map, similar to ConvOccNet."""

    def __init__(
            self,
            cfg,
            device='cuda',
            shape_min=-1.0,
            shape_length=2.0,
            use_2d_feat=False,
            # point_encoder='pvcnn',
            # use_point_scatter=False
    ):
        """Outputs a latent triplane from a point-cloud input.

        Configs:
            max_logsigma: (float) soft-clip upper range for logsigma
            min_logsigma: (float) soft-clip lower range for logsigma
            point_encoder_type: (str) one of ['pvcnn', 'pointnet']
            pvcnn_flatten_voxels: (bool) for pvcnn, whether to reduce voxel
                features (instead of scattering point features)
            unet_cfg: (dict)
            z_triplane_channels: (int) channels of the output latent triplane
            z_triplane_resolution: (int) resolution of the output latent triplane
        """
        # assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.device = device
        self.cfg = cfg
        self.shape_min = shape_min
        self.shape_length = shape_length
        self.z_triplane_resolution = cfg.z_triplane_resolution
        z_triplane_channels = cfg.z_triplane_channels
        point_encoder_out_dim = z_triplane_channels  # * 2

        in_channels = 6
        # self.resample_filter = [1, 3, 3, 1]
        if cfg.point_encoder_type == 'pvcnn':
            # Encode the point cloud to a feature volume.
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim,
                device=self.device,
                in_channels=in_channels,
                use_2d_feat=use_2d_feat)
        elif cfg.point_encoder_type == 'pointnet':
            # TODO: the pointnet was buggy, investigate
            self.pc_encoder = ConvPointnet(
                c_dim=point_encoder_out_dim,
                dim=in_channels,
                hidden_dim=32,
                plane_resolution=self.z_triplane_resolution,
                padding=0)
        else:
            raise NotImplementedError(
                f"Point encoder {cfg.point_encoder_type} not implemented")

        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg)
        else:
            self.unet_encoder = None

    # @ScopedTorchProfiler('encode')
    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None,
               pc2pc_idx=None) -> torch.Tensor:
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length  # [0, 1]
        point_cloud_xyz = point_cloud_xyz - 0.5  # [-0.5, 0.5]
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                # 3D feature volume: B x D x 32 x 32 x 32
                pc_feat, points_feat = self.pc_encoder(point_cloud)
            if self.cfg.use_point_scatter:
                # Scatter the PVCNN per-point features onto the three planes.
                points_feat_ = points_feat[0]
                # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
                pc_feat_1 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='xy')
                pc_feat_2 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='yz')
                pc_feat_3 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='xz')
                pc_feat = pc_feat[0]
            else:
                pc_feat = pc_feat[0]
                sf = self.z_triplane_resolution // 32  # 32 is PVCNN's voxel dim
                pc_feat_1 = torch.mean(pc_feat, dim=-1)  # xy plane, averaged over z
                pc_feat_2 = torch.mean(pc_feat, dim=-3)  # yz plane, averaged over x
                pc_feat_3 = torch.mean(pc_feat, dim=-2)  # xz plane, averaged over y
                # nearest-neighbor upsample to the target triplane resolution
                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter
            # Run ConvPointnet; it returns a dict of per-plane features.
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        if self.unet_encoder is not None:
            # TODO: evaluate adding a skip connection
            # The U-Net expects (B, 3, C, H, W)
            pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
            # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
            pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)

        return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature,
                           mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
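
# ---------------------------------------------------------------------------
# Illustrative sketch (an addition): one way to drive TriPlanePC2Encoder,
# assuming the relative imports above resolve and that `cfg` is an
# attribute-style config exposing the fields read in __init__/encode. The
# field values and batch shapes below are placeholders, not the repo's
# defaults, and the PVCNN return format is assumed to match its use above.
def _demo_encoder_usage():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        point_encoder_type='pvcnn',
        use_point_scatter=True,
        z_triplane_channels=64,
        z_triplane_resolution=128,
        unet_cfg=SimpleNamespace(enabled=False),
    )
    encoder = TriPlanePC2Encoder(cfg, device='cuda').to('cuda')
    xyz = torch.rand(2, 2048, 3, device='cuda') * 2.0 - 1.0  # raw coordinates
    rgb = torch.rand(2, 2048, 3, device='cuda')              # per-point features
    # xyz (3) + rgb (3) matches in_channels = 6 in __init__.
    z = encoder(xyz, point_cloud_feature=rgb)  # expected: (2, 3, 64, 128, 128)
    return z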