Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| import pointnet2_cuda | |
class KNN(nn.Module):
    """Brute-force k-nearest-neighbour search built on ``torch.cdist``.

    Args:
        neighbors (int): number of neighbours ``K`` returned per query point.
        transpose_mode (bool): accepted for API compatibility only; it does not
            change the computation (indices are always returned as ``[B, M, K]``).
    """

    def __init__(self, neighbors, transpose_mode=True):
        super(KNN, self).__init__()
        self.neighbors = neighbors
        # Recorded so the setting is inspectable; forward() does not use it.
        self.transpose_mode = transpose_mode

    def forward(self, support, query):
        """Find the ``K`` nearest support points for every query point.

        Args:
            support ([tensor]): [B, N, C] points that are searched.
            query ([tensor]): [B, M, C] points whose neighbours are requested.

        Returns:
            tuple:
                * distances ``[B, K, M]`` -- Euclidean distances, ascending along
                  dim 1 (note: NOT transposed, unlike the indices).
                * indices ``[B, M, K]`` (int32) into the ``N`` axis of
                  ``support``, nearest first.
        """
        dist = torch.cdist(support, query)  # [B, N, M]
        # K smallest along the support axis -> values / indices are [B, K, M].
        k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
        return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
class GroupingOperation(Function):
    """Autograd op that gathers neighbour features via the ``pointnet2_cuda`` kernels.

    Exposed below as ``grouping_operation = GroupingOperation.apply``.
    """

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context
        :param features: (B, C, N) contiguous tensor of features to group
        :param idx: (B, npoint, nsample) contiguous tensor of indices into the N axis
            of ``features`` (callers in this file pass int32 -- presumably what the
            CUDA kernel expects; TODO confirm)
        :return:
            output: (B, C, npoint, nsample) float32 tensor on ``features.device``
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        # Uninitialised storage, as in the original legacy allocation; the kernel is
        # presumably expected to write every element.
        output = torch.empty(
            B, C, nfeatures, nsample, dtype=torch.float, device=features.device
        )
        pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
        # idx and N are all that backward needs to scatter gradients back.
        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        :param ctx: autograd context
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features, and None for idx
        """
        idx, N = ctx.for_backwards
        B, C, npoint, nsample = grad_out.size()
        # Zero-initialised plain buffer (presumably the kernel accumulates into it);
        # it must not itself be a grad-requiring leaf.
        grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device)
        grad_out_data = grad_out.detach().contiguous()
        pointnet2_cuda.group_points_grad_wrapper(
            B, C, N, npoint, nsample, grad_out_data, idx, grad_features
        )
        return grad_features, None


grouping_operation = GroupingOperation.apply
class KNNGroup(nn.Module):
    """Gather the ``nsample`` nearest support points (and optional features) around each query point."""

    def __init__(self, nsample: int,
                 relative_xyz=True,
                 normalize_dp=False,
                 return_only_idx=False,
                 **kwargs
                 ):
        """
        Args:
            nsample (int): number of neighbours gathered per query point.
            relative_xyz (bool, optional): subtract each query point's coordinates
                from its grouped neighbour coordinates. Defaults to True.
            normalize_dp (bool, optional): divide the grouped coordinates by the
                largest Euclidean norm among them, computed per batch element.
                Defaults to False.
            return_only_idx (bool, optional): make ``forward`` return only the
                neighbour indices. Defaults to False.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.nsample = nsample
        self.knn = KNN(nsample, transpose_mode=True)
        self.relative_xyz = relative_xyz
        self.normalize_dp = normalize_dp
        self.return_only_idx = return_only_idx

    def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param query_xyz: (B, npoint, 3) centres whose neighbourhoods are gathered
        :param support_xyz: (B, N, 3) points searched for neighbours
        :param features: optional (B, C, N) contiguous descriptors aligned with support_xyz
        :return:
            if return_only_idx: idx (B, npoint, nsample) int32 neighbour indices
            otherwise: (grouped_xyz (B, 3, npoint, nsample),
                        grouped_features (B, C, npoint, nsample) or None)
        """
        _, idx = self.knn(support_xyz, query_xyz)
        if self.return_only_idx:
            return idx
        idx = idx.int()
        xyz_trans = support_xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Out-of-place updates: the norm below saves grouped_xyz for backward, so
        # modifying it in place afterwards would break autograd.
        if self.relative_xyz:
            grouped_xyz = grouped_xyz - query_xyz.transpose(1, 2).unsqueeze(-1)  # relative position
        if self.normalize_dp:
            # per-neighbour norm (B, npoint, nsample) -> per-batch max (B,)
            max_norm = torch.amax(torch.sqrt(torch.sum(grouped_xyz ** 2, dim=1)), dim=(1, 2))
            grouped_xyz = grouped_xyz / max_norm.view(-1, 1, 1, 1)
        if features is None:
            return grouped_xyz, None
        grouped_features = grouping_operation(features, idx)
        return grouped_xyz, grouped_features
class FurthestPointSampling(Function):
    """Iterative furthest point sampling backed by ``pointnet2_cuda``.

    Exposed below as ``furthest_point_sample = FurthestPointSampling.apply``.
    Returns integer indices, so the op has no gradient.
    """

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint points that
        have the largest minimum distance to each other.
        :param ctx: autograd context
        :param xyz: (B, N, 3) contiguous coordinates, where N > npoint
        :param npoint: int, number of points in the sampled set
        :return:
            output: (B, npoint) int32 tensor of indices into the N axis of xyz,
            allocated on xyz.device
        """
        assert xyz.is_contiguous()
        B, N, _ = xyz.size()
        # Allocate on xyz's own device instead of the current default CUDA device,
        # so the kernel never mixes devices on multi-GPU setups.
        output = torch.empty(B, npoint, dtype=torch.int32, device=xyz.device)
        # Scratch buffer initialised to a large sentinel -- presumably the kernel's
        # running minimum distance per point.
        temp = torch.full((B, N), 1e10, dtype=torch.float, device=xyz.device)
        pointnet2_cuda.furthest_point_sampling_wrapper(
            B, N, npoint, xyz, temp, output)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(ctx, grad_out=None):
        # Index outputs are not differentiable: no gradient for xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
class PointPatchEmbed(nn.Module):
    """Embed a point cloud into one token per sampled patch.

    Patch centres are picked with furthest point sampling. For each centre, the
    ``group_size`` nearest points' features are gathered, projected with a Conv2d,
    and max-pooled over the neighbours.

    Args:
        sample_ratio (float): fraction of the N input points used as patch centres
            when ``sample_number`` is None.
        sample_number (int or None): number of patch centres; when not None it
            takes precedence over ``sample_ratio``.
        group_size (int): neighbours gathered per patch.
        in_channels (int): channels of the per-point input.
        channels (int): output embedding channels.
        kernel_size, stride: forwarded to the ``nn.Conv2d`` projection.
        normalize_dp, relative_xyz: forwarded to ``KNNGroup``.
    """

    def __init__(self,
                 sample_ratio=0.0625,
                 sample_number=1024,
                 group_size=32,
                 in_channels=6,
                 channels=1024,
                 kernel_size=1,
                 stride=1,
                 normalize_dp=False,
                 relative_xyz=True,
                 ):
        super().__init__()
        self.sample_ratio = sample_ratio
        self.sample_number = sample_number
        self.group_size = group_size
        self.sample_fn = furthest_point_sample
        self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """
        Args:
            x (Tensor): [B, N, C] points. Columns ``3:`` are used as the
                sampling/grouping coordinates and all C columns as features, so in
                practice C must be 6 -- NOTE(review): confirm the channel layout.

        Returns:
            Tensor: [B, channels, npoint, 1] patch embeddings (for kernel_size=1,
            stride=1).
        """
        # coordinates (trailing three channels of x)
        p = x[:, :, 3:].contiguous()
        if self.sample_number is not None:
            npoint = self.sample_number
        else:
            npoint = int(p.shape[1] * self.sample_ratio)
        idx = self.sample_fn(p, npoint).long()  # [B, npoint]
        center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))  # [B, npoint, 3]
        # query neighbours: [B, N, C] -> [B, C, N] -> grouped [B, C, npoint, group_size]
        _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous())
        # conv: [B, C, npoint, group_size] -> [B, channels, npoint, group_size];
        # max over neighbours -> [B, channels, npoint, 1]
        fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
        return fj
if __name__ == '__main__':
    # Smoke test: needs a CUDA device and the compiled pointnet2_cuda extension.
    model = PointPatchEmbed(channels=256).cuda()
    points = torch.rand(4, 16384, 6).cuda()
    with torch.no_grad():
        out = model(points)
    # With the default sample_number=1024 this prints (4, 256, 1024, 1).
    print(tuple(out.shape))