import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

from .pv_module import SharedMLP, PVConv


def create_pointnet_components(
        blocks, in_channels, with_se=False, normalize=True, eps=0,
        width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'):
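    """Build a stack of point-feature layers from ``blocks`` specs.

    Each entry of ``blocks`` is ``(out_channels, num_blocks, voxel_resolution)``:
    a ``None`` voxel resolution yields a SharedMLP block, otherwise a PVConv block
    at resolution ``int(voxel_resolution_multiplier * voxel_resolution)``. Channel
    widths are scaled by ``width_multiplier``. Returns the layer list, the final
    channel count, and the sum of the output channels of all created blocks.
    """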
    r, vr = width_multiplier, voxel_resolution_multiplier
    layers, concat_channels = [], 0
    for out_channels, num_blocks, voxel_resolution in blocks:
        out_channels = int(r * out_channels)
        if voxel_resolution is None:
            block = functools.partial(SharedMLP, device=device)
        else:
            block = functools.partial(
                PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device)
        for _ in range(num_blocks):
            layers.append(block(in_channels, out_channels))
            in_channels = out_channels
            concat_channels += out_channels
    return layers, in_channels, concat_channels


class PCMerger(nn.Module):
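    """Fuse multi-view image features into per-point features.

    ``mv_feat`` is split channel-wise into three groups (normals, RGB, and the
    remaining channels, e.g. SAM features, as the attribute names suggest),
    each group is embedded by a SharedMLP, and the embeddings are added onto
    the point features at the positions given by ``pc2pc_idx``.
    """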

    def __init__(self, in_channels=204, device="cuda"):
        super(PCMerger, self).__init__()
        self.mlp_normal = SharedMLP(3, [128, 128], device=device)
        self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
        # First 3 channels are normals, next 3 are RGB; the rest go to mlp_sam.
        self.mlp_sam = SharedMLP(in_channels - 6, [128, 128], device=device)

    def forward(self, feat, mv_feat, pc2pc_idx):
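        # feat: (B, C, N) point features; mv_feat: (B, C_mv, N_mv) multi-view
        # features, channels-first; pc2pc_idx: per-point indices into mv_feat,
        # with -1 marking points that have no 2D correspondence.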
        mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :])
        mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :])
        mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :])

        mv_feat_normal = mv_feat_normal.permute(0, 2, 1)
        mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1)
        mv_feat_sam = mv_feat_sam.permute(0, 2, 1)
        feat = feat.permute(0, 2, 1)

        # Per sample, add the three embeddings onto the points that have a
        # valid correspondence (pc2pc_idx != -1).
        for i in range(mv_feat.shape[0]):
            mask = (pc2pc_idx[i] != -1).reshape(-1)
            idx = pc2pc_idx[i][mask].reshape(-1)
            feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx]

        return feat.permute(0, 2, 1)


class PVCNNEncoder(nn.Module):
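    """PVCNN-style point-cloud encoder.

    Stacks SharedMLP/PVConv blocks according to ``self.blocks`` (output channels,
    number of blocks, voxel resolution per stage) and returns the per-stage voxel
    features and point features. When ``use_2d_feat`` is set, projected multi-view
    features are merged into the point features after the first stage via PCMerger.
    """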

    def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
        super(PVCNNEncoder, self).__init__()
        self.device = device
        self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
        self.use_2d_feat = use_2d_feat
        # Number of zero-valued channels appended to the input features in forward().
        if in_channels == 6:
            self.append_channel = 2
        elif in_channels == 3:
            self.append_channel = 1
        else:
            raise NotImplementedError
        layers, channels_point, concat_channels_point = create_pointnet_components(
            blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False,
            width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True,
            device=device
        )
        self.encoder = nn.ModuleList(layers)
        if self.use_2d_feat:
            self.merger = PCMerger()

    def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
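        # input_pc: (B, N, in_channels) point cloud; the first three channels are
        # xyz coordinates. mv_feat / pc2pc_idx are only used when the encoder was
        # constructed with use_2d_feat=True (see PCMerger).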
        # Channels-first; the factor of 2 rescales the input, presumably to match
        # the coordinate range expected by the PVConv voxelization.
        features = input_pc.permute(0, 2, 1) * 2
        coords = features[:, :3, :]
        out_features_list = []
        voxel_feature_list = []
        # Append the extra zero channels expected by the first encoder block.
        zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1],
                                   device=features.device, dtype=features.dtype)
        features = torch.cat([features, zero_padding], dim=1)

        for i in range(len(self.encoder)):
            features, _, voxel_feature = self.encoder[i]((features, coords))
            # Merge projected multi-view features after the first stage.
            if i == 0 and mv_feat is not None:
                features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx)
            out_features_list.append(features)
            voxel_feature_list.append(voxel_feature)
        return voxel_feature_list, out_features_list