omnipart's picture
init
491eded
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
from ast import Dict
import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_mean #, scatter_max
from .unet_3daware import setup_unet #UNetTriplane3dAware
from .conv_pointnet import ConvPointnet
from .pc_encoder import PVCNNEncoder #PointNet
import einops
from .dnnlib_util import ScopedTorchProfiler, printarr
def generate_plane_features(p, c, resolution, plane='xz'):
    """Average per-point features into a dense 2D feature plane.

    Every point is projected onto ``plane`` (without padding), mapped to a
    flat cell index on a ``resolution x resolution`` grid, and the features
    of the points sharing a cell are combined with ``scatter_mean`` into a
    zero-initialized buffer.

    Args:
        p: point coordinates. ``normalize_coordinate`` selects the plane axes
            from the last dimension, so this is indexed as (B, n_p, 3).
            With zero padding, coordinates in [-0.5, 0.5] map onto the grid;
            values outside are clipped to the border cells.
        c: per-point features, (B, C, n_p).
        resolution (int): side length of the output plane.
        plane (str): projection plane, one of 'xz', 'xy', 'yz'.

    Returns:
        Tensor of shape (B, C, resolution, resolution).
    """
    batch_size, feat_dim = p.size(0), c.size(1)

    # Project onto the plane and turn the normalized (0, 1) coordinates into
    # flat cell indices of shape (B, 1, n_p).
    plane_coords = normalize_coordinate(p.clone(), plane=plane, padding=0.)
    cell_index = coordinate2index(plane_coords, resolution)

    # Scatter the point features into their cells: B x C x reso^2.
    flat_plane = scatter_mean(
        c,
        cell_index,
        out=c.new_zeros(batch_size, feat_dim, resolution**2),
    )

    # Unflatten the cell axis into a (mostly sparse) B x C x reso x reso map.
    return flat_plane.reshape(batch_size, feat_dim, resolution, resolution)
def normalize_coordinate(p, padding=0.1, plane='xz'):
    """Project points onto a plane and normalize the result to [0, 1).

    Intended for unit-cube experiments: the two selected coordinates are
    divided by ``1 + padding + 10e-6`` and shifted by 0.5. Outliers are then
    clipped: values >= 1 become ``1 - 10e-6`` and values < 0 become 0.

    The clipping is done out-of-place with ``masked_fill``, so the function
    works on empty point sets and never needs a host-side ``max()``/``min()``
    check. The input tensor is not modified.

    Args:
        p (tensor): points, indexed as (B, N, 3) with xyz on the last axis.
        padding (float): conventional padding parameter of ONet for the unit
            cube; e.g. 0.1 maps [-0.55, 0.55] onto [0, 1].
        plane (str): plane feature type, one of 'xz', 'xy', 'yz'. Any other
            value is treated as 'yz'.

    Returns:
        tensor: the two selected coordinates on the last axis, i.e. (B, N, 2)
        for a (B, N, 3) input, with every value in [0, 1).
    """
    if plane == 'xz':
        axes = [0, 2]
    elif plane == 'xy':
        axes = [0, 1]
    else:
        axes = [1, 2]

    # Indexing with a list copies, so `p` itself is never written to.
    xy_new = p[:, :, axes] / (1 + padding + 10e-6)  # (-0.5, 0.5)
    xy_new = xy_new + 0.5  # range (0, 1)

    # Clip outliers that fall outside the range. Only offending entries are
    # replaced; in-range values (including those just below 1) are kept.
    xy_new = xy_new.masked_fill(xy_new >= 1, 1 - 10e-6)
    xy_new = xy_new.masked_fill(xy_new < 0, 0.0)
    return xy_new
def coordinate2index(x, resolution):
    """Convert normalized plane coordinates into flat grid-cell indices.

    Each coordinate is scaled by ``resolution`` and truncated to an integer
    cell; the two cell coordinates are then combined into a single index as
    ``first + resolution * second``.

    Args:
        x (tensor): plane coordinates, indexed as (B, N, 2). Values are
            expected in [0, 1) so that every index lands on the grid.
        resolution (int): side length of the grid.

    Returns:
        tensor: long indices of shape (B, 1, N).
    """
    cells = (x * resolution).long()
    first_axis_cell = cells.select(2, 0)
    second_axis_cell = cells.select(2, 1)
    flat_index = first_axis_cell + resolution * second_axis_cell
    # Add a singleton channel axis so the index can be used alongside
    # (B, C, N) tensors.
    return flat_index.unsqueeze(1)
def softclip(x, min, max, hardness=5):
    """Smoothly clip ``x`` to approximately [min, max] using softplus.

    Used as a soft clipping for the logsigma. A softplus-based lower bound is
    applied first, then a softplus-based upper bound on that result; values
    well inside the interval pass through almost unchanged.

    Args:
        x (tensor): values to clip.
        min: lower bound.
        max: upper bound.
        hardness: scale applied inside the softplus (and divided out again);
            larger values make the transition at the bounds sharper.

    Returns:
        tensor: the softly clipped values, same shape as ``x``.
    """
    # Soft lower bound: distance above `min`, smoothed to stay positive.
    above_min = F.softplus(hardness * (x - min)) / hardness
    lower_clipped = min + above_min

    # Soft upper bound, applied to the already lower-clipped values.
    below_max = F.softplus(-hardness * (lower_clipped - max)) / hardness
    return max - below_max
def sample_triplane_feat(feature_triplane, normalized_pos):
    """Sample the three feature planes at 3D positions and sum the results.

    Args:
        feature_triplane: stacked planes; unbound along dim 1 into three
            (B, C, H, W) maps, which are looked up with the (x, y), (y, z)
            and (x, z) coordinates respectively.
        normalized_pos: (B, N, 3) query positions in [-1, 1]. Positions
            outside that range are sampled with border padding.

    Returns:
        Tensor of shape (B, N, C): the sum of the three plane samples.
    """
    planes = torch.unbind(feature_triplane, dim=1)

    # Coordinate pair used to look up plane 0, 1 and 2 respectively.
    axis_pairs = ((0, 1), (1, 2), (0, 2))

    sampled = []
    for plane_idx, (u, v) in enumerate(axis_pairs):
        # grid_sample wants a (B, H_out, W_out, 2) grid; use H_out = 1.
        grid = normalized_pos[:, :, [u, v]].unsqueeze(dim=1)
        sampled.append(
            F.grid_sample(
                planes[plane_idx],
                grid,
                padding_mode='border',
                align_corners=True,
            )
        )

    summed = sampled[0] + sampled[1] + sampled[2]  # (B, C, 1, N)
    # Drop the singleton H_out axis and put the feature axis last.
    return summed.squeeze(dim=2).permute(0, 2, 1)
# @persistence.persistent_class
class TriPlanePC2Encoder(torch.nn.Module):
    """Encode a point cloud into a latent triplane, similar to ConvOccNet.

    A point encoder (PVCNN or ConvPointnet) produces features that are turned
    into three 2D feature planes (xy, yz, xz); an optional UNet then refines
    the stacked planes.
    """

    def __init__(
            self,
            cfg,
            device='cuda',
            shape_min=-1.0,
            shape_length=2.0,
            use_2d_feat=False,
    ):
        """Build the point encoder and the optional triplane UNet.

        Args:
            cfg: config object. Fields read by this class:
                point_encoder_type (str): one of ['pvcnn', 'pointnet'].
                use_point_scatter (bool): read in ``encode``. For 'pvcnn' it
                    selects scattering point features onto the planes
                    instead of averaging the voxel features; 'pointnet'
                    asserts it is set.
                unet_cfg: passed to ``setup_unet``; ``unet_cfg.enabled``
                    toggles the UNet.
                z_triplane_channels (int): channels of the latent triplane.
                z_triplane_resolution (int): resolution of the latent
                    triplane.
            device: stored as ``self.device`` and forwarded to
                ``PVCNNEncoder``.
            shape_min (float): lower bound of the input coordinate range.
            shape_length (float): extent of the input coordinate range.
            use_2d_feat (bool): forwarded to ``PVCNNEncoder``.

        Raises:
            NotImplementedError: if ``cfg.point_encoder_type`` is unknown.
        """
        super().__init__()
        self.device = device
        self.cfg = cfg

        self.shape_min = shape_min
        self.shape_length = shape_length

        self.z_triplane_resolution = cfg.z_triplane_resolution
        z_triplane_channels = cfg.z_triplane_channels

        point_encoder_out_dim = z_triplane_channels
        # Encoder input width: 3 xyz coordinates concatenated with the
        # per-point features (presumably 3 channels — confirm with callers).
        in_channels = 6

        if cfg.point_encoder_type == 'pvcnn':
            # Encodes the point cloud into a feature volume.
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim,
                device=self.device,
                in_channels=in_channels,
                use_2d_feat=use_2d_feat)
        elif cfg.point_encoder_type == 'pointnet':
            # TODO the pointnet was buggy, investigate
            self.pc_encoder = ConvPointnet(
                c_dim=point_encoder_out_dim,
                dim=in_channels,
                hidden_dim=32,
                plane_resolution=self.z_triplane_resolution,
                padding=0)
        else:
            raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")

        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg)
        else:
            self.unet_encoder = None

    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> torch.Tensor:
        """Encode a point cloud into stacked xy / yz / xz feature planes.

        Args:
            point_cloud_xyz: point coordinates with xyz on the last axis,
                expected in [shape_min, shape_min + shape_length].
            point_cloud_feature: per-point features, concatenated to the
                coordinates along the last axis. Required.
            mv_feat: optional extra input for the PVCNN encoder; when given
                it is forwarded together with ``pc2pc_idx``.
            pc2pc_idx: only forwarded to the PVCNN encoder when ``mv_feat``
                is given.

        Returns:
            Tensor stacking the xy, yz and xz planes along dim 1, i.e.
            (B, 3, C, H, W); this is the UNet output when the UNet is
            enabled.

        Raises:
            TypeError: if ``point_cloud_feature`` is None.
            NotImplementedError: if ``cfg.point_encoder_type`` is unknown.
        """
        if point_cloud_feature is None:
            # forward() defaults this argument to None, but the encoder input
            # cannot be built without it; fail with a clear message instead
            # of an opaque error from torch.cat.
            raise TypeError(
                "point_cloud_feature is required: it is concatenated with "
                "point_cloud_xyz to form the point encoder input")

        # Map coordinates from [shape_min, shape_min + shape_length] to [0, 1]
        # and then to [-0.5, 0.5].
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length
        point_cloud_xyz = point_cloud_xyz - 0.5
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                # 3D feature volume: BxDx32x32x32
                pc_feat, points_feat = self.pc_encoder(point_cloud)

            if self.cfg.use_point_scatter:
                # Scattering from PVCNN point features.
                points_feat_ = points_feat[0]
                # shape: batch, latent size, resolution, resolution
                # (e.g. 16, 256, 64, 64)
                pc_feat_1 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='xy')
                pc_feat_2 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='yz')
                pc_feat_3 = generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane='xz')
            else:
                pc_feat = pc_feat[0]
                # 32 is PVCNN's voxel dim; assumes z_triplane_resolution is a
                # multiple of 32 — TODO confirm.
                sf = self.z_triplane_resolution // 32

                # Collapse one spatial axis of the volume per plane
                # (presumably laid out as B x C x X x Y x Z).
                pc_feat_1 = torch.mean(pc_feat, dim=-1)  # xy plane, averaged over z
                pc_feat_2 = torch.mean(pc_feat, dim=-3)  # yz plane, averaged over x
                pc_feat_3 = torch.mean(pc_feat, dim=-2)  # xz plane, averaged over y

                # Nearest-neighbour upsample to the triplane resolution.
                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter
            # Run ConvPointnet, which returns one feature map per plane.
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        # Unet expects B, 3, C, H, W
        triplane = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
        if self.unet_encoder is not None:
            # TODO eval adding a skip connection
            # The UNet output is returned as-is; unbinding and re-stacking it
            # would only add a copy.
            triplane = self.unet_encoder(triplane)
        return triplane

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        """Run ``encode``; see there for arguments and return value."""
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)