import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

from .pv_module import SharedMLP, PVConv


def create_pointnet_components(
        blocks, in_channels, with_se=False, normalize=True, eps=0,
        width_multiplier=1, voxel_resolution_multiplier=1,
        scale_pvcnn=False, device='cuda'):
    """Build the list of point/voxel layers described by ``blocks``.

    Args:
        blocks: Iterable of ``(out_channels, num_blocks, voxel_resolution)``
            tuples. A ``voxel_resolution`` of ``None`` selects ``SharedMLP``
            layers; any other value selects ``PVConv`` layers (kernel size 3)
            at that resolution scaled by ``voxel_resolution_multiplier``.
        in_channels: Channel count fed to the first layer.
        with_se, normalize, eps, scale_pvcnn: Forwarded unchanged to every
            ``PVConv`` layer.
        width_multiplier: Factor applied to every ``out_channels``; the
            product is truncated with ``int``.
        voxel_resolution_multiplier: Factor applied to every
            ``voxel_resolution``; the product is truncated with ``int``.
        device: Forwarded to every ``SharedMLP`` / ``PVConv`` constructor.

    Returns:
        ``(layers, in_channels, concat_channels)``:
        - ``layers`` is the list of constructed modules, in order.
        - ``in_channels`` is the ``out_channels`` of the last layer built, or
          the input value if no layers were built.
        - ``concat_channels`` is the sum of ``out_channels`` over every layer.
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    layers, concat_channels = [], 0
    for out_channels, num_blocks, voxel_resolution in blocks:
        out_channels = int(r * out_channels)
        if voxel_resolution is None:
            block = functools.partial(SharedMLP, device=device)
        else:
            block = functools.partial(
                PVConv, kernel_size=3,
                resolution=int(vr * voxel_resolution),
                with_se=with_se, normalize=normalize, eps=eps,
                scale_pvcnn=scale_pvcnn, device=device)
        for _ in range(num_blocks):
            layers.append(block(in_channels, out_channels))
            # The next layer consumes this layer's output width.
            in_channels = out_channels
            concat_channels += out_channels
    return layers, in_channels, concat_channels


class PCMerger(nn.Module):
    """Merge surface-sampled point features with back-projected multi-view features.

    The multi-view features are split into three channel groups: 0:3 ("normal"),
    3:6 ("rgb") and 6: ("sam"). Each group is lifted by its own ``SharedMLP``
    (hidden widths ``[128, 128]``). The three results are summed and added onto
    the surface-point features that have a correspondence.
    """

    def __init__(self, in_channels=204, device="cuda"):
        """
        Args:
            in_channels: Total channel count of the multi-view features. The
                first 6 channels feed the normal/RGB MLPs; the remaining
                ``in_channels - 6`` channels feed the SAM-feature MLP.
            device: Forwarded to each ``SharedMLP``.

        Raises:
            ValueError: If ``in_channels`` leaves no channels for the SAM MLP.
        """
        super().__init__()
        if in_channels <= 6:
            raise ValueError(
                f"PCMerger needs more than 6 input channels, got {in_channels}")
        self.mlp_normal = SharedMLP(3, [128, 128], device=device)
        self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
        # Size the SAM MLP from the configured width rather than a fixed 204.
        self.mlp_sam = SharedMLP(in_channels - 6, [128, 128], device=device)

    def forward(self, feat, mv_feat, pc2pc_idx):
        """Add encoded multi-view features onto the matching surface points.

        Args:
            feat: Surface point features, channels-first (indexed as
                ``(B, C, N)`` after the permute below). ``C`` must match the
                MLP output width (presumably 128 — TODO confirm against
                ``SharedMLP``) for the addition to succeed.
            mv_feat: Multi-view point features, channels-first
                ``(B, in_channels, M)``.
            pc2pc_idx: Per-batch-item index tensor. Entry ``j`` of
                ``pc2pc_idx[b]`` is the index into the multi-view points for
                surface point ``j``, or ``-1`` when there is no match.

        Returns:
            The merged features, channels-first.

        Note:
            The addition is performed in place through a permuted view of
            ``feat``, so the caller's ``feat`` tensor is modified too.
        """
        # Encode each channel group, then combine once for the whole batch
        # (same left-to-right addition order as summing per gathered row).
        mv_encoded = (
            self.mlp_normal(mv_feat[:, :3, :])
            + self.mlp_rgb(mv_feat[:, 3:6, :])
            + self.mlp_sam(mv_feat[:, 6:, :])
        ).permute(0, 2, 1)

        feat = feat.permute(0, 2, 1)
        for b in range(mv_feat.shape[0]):
            mask = (pc2pc_idx[b] != -1).reshape(-1)
            idx = pc2pc_idx[b][mask].reshape(-1)
            feat[b][mask] += mv_encoded[b][idx]
        return feat.permute(0, 2, 1)


class PVCNNEncoder(nn.Module):
    """Three-stage PVCNN point encoder with optional multi-view feature fusion.

    The stages are ``(pvcnn_feat_dim, 1 block, res 32)``,
    ``(128, 2 blocks, res 16)`` and ``(256, 1 block, res 8)``; all are PVConv
    layers built with ``scale_pvcnn=True`` and ``normalize=False``. When
    ``use_2d_feat`` is set, multi-view features are merged into the output of
    the first layer via ``PCMerger``.
    """

    def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
        """
        Args:
            pvcnn_feat_dim: Output width of the first stage. With
                ``use_2d_feat`` it presumably must equal the merger's MLP
                output width (128) — TODO confirm.
            device: Forwarded to every layer and to the merger.
            in_channels: Input point channels; only 3 or 6 are supported.
            use_2d_feat: Whether to build the multi-view ``PCMerger``.

        Raises:
            NotImplementedError: If ``in_channels`` is not 3 or 6.
        """
        super().__init__()
        self.device = device
        self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
        self.use_2d_feat = use_2d_feat
        # Zero channels appended to the input in forward(). Presumably chosen
        # to pad 3 -> 4 and 6 -> 8 channels — TODO confirm intent.
        if in_channels == 6:
            self.append_channel = 2
        elif in_channels == 3:
            self.append_channel = 1
        else:
            raise NotImplementedError(
                f"in_channels must be 3 or 6, got {in_channels}")
        layers, _, _ = create_pointnet_components(
            blocks=self.blocks,
            in_channels=in_channels + self.append_channel,
            with_se=False,
            normalize=False,
            width_multiplier=1,
            voxel_resolution_multiplier=1,
            scale_pvcnn=True,
            device=device,
        )
        self.encoder = nn.ModuleList(layers)
        if self.use_2d_feat:
            # Forward the device; PCMerger otherwise defaults to "cuda".
            self.merger = PCMerger(device=device)

    def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
        """Encode a point cloud, optionally fusing multi-view features.

        Args:
            input_pc: Points laid out ``(B, N, in_channels)``. The first three
                channels (after scaling) are used as coordinates. All channels
                are multiplied by 2 to map the cloud to [-1, 1] (presumably the
                input lies in [-0.5, 0.5] — TODO confirm).
            mv_feat: Optional multi-view features ``(B, M, C_mv)``, merged into
                the first layer's output. Requires ``use_2d_feat=True``.
            pc2pc_idx: Correspondence indices passed to ``PCMerger``; required
                whenever ``mv_feat`` is given.

        Returns:
            ``(voxel_feature_list, out_features_list)``: the voxel features and
            point features produced by each encoder layer, in order.

        Raises:
            ValueError: If ``mv_feat`` is given but the encoder was built
                without ``use_2d_feat``, or ``pc2pc_idx`` is missing.
        """
        if mv_feat is not None:
            if not self.use_2d_feat:
                raise ValueError(
                    "mv_feat was given but the encoder was built with use_2d_feat=False")
            if pc2pc_idx is None:
                raise ValueError("pc2pc_idx is required when mv_feat is given")

        features = input_pc.permute(0, 2, 1) * 2  # make point cloud [-1, 1]
        coords = features[:, :3, :]
        zero_padding = torch.zeros(
            features.shape[0], self.append_channel, features.shape[-1],
            device=features.device, dtype=features.dtype)
        features = torch.cat([features, zero_padding], dim=1)

        out_features_list = []
        voxel_feature_list = []
        for i, layer in enumerate(self.encoder):
            features, _, voxel_feature = layer((features, coords))
            if i == 0 and mv_feat is not None:
                features = self.merger(
                    features, mv_feat.permute(0, 2, 1), pc2pc_idx)
            out_features_list.append(features)
            voxel_feature_list.append(voxel_feature)
        return voxel_feature_list, out_features_list