Spaces:
Runtime error
Runtime error
| from typing import Tuple | |
| import torch | |
| from torch.autograd import Function | |
| from ..utils import ext_loader | |
# Kernels this module needs from the compiled ``_ext`` extension.
_THREE_NN_KERNELS = ['three_nn_forward']
# Load ``_ext`` through the project's ``ext_loader`` helper, requesting the
# kernel(s) listed above; ``ThreeNN.forward`` calls
# ``ext_module.three_nn_forward``.
ext_module = ext_loader.load_ext('_ext', _THREE_NN_KERNELS)
class ThreeNN(Function):
    """Find the top-3 nearest neighbors of the target set from the source set.

    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
    for more details.

    The op is non-differentiable: ``backward`` returns no gradient for
    either input. Call it through the functional alias ``three_nn``.
    """

    @staticmethod
    def forward(ctx, target: torch.Tensor,
                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Look up the 3 nearest ``source`` points of every ``target`` point.

        Args:
            target (Tensor): shape (B, N, 3), points set that needs to
                find the nearest neighbors.
            source (Tensor): shape (B, M, 3), points set that is used
                to find the nearest neighbors of points in target set.

        Returns:
            tuple[Tensor, Tensor]:

            - dist (Tensor): shape (B, N, 3), float32 L2 distance of each
              point in target set to its 3 nearest neighbors.
            - idx (Tensor): shape (B, N, 3), int32 indices of those
              neighbors in ``source``; marked non-differentiable.

        Raises:
            ValueError: if ``target``/``source`` are not 3-D tensors whose
                last dimension is 3, or if their batch sizes differ.
        """
        # Enforce the documented contract before handing buffers to the
        # compiled kernel, which receives only b/n/m and cannot check it.
        if target.dim() != 3 or source.dim() != 3:
            raise ValueError(
                'target and source must be 3-D tensors, got shapes '
                f'{tuple(target.shape)} and {tuple(source.shape)}')
        if target.size(2) != 3 or source.size(2) != 3:
            raise ValueError(
                'target and source must have a last dimension of 3, got '
                f'{tuple(target.shape)} and {tuple(source.shape)}')
        if target.size(0) != source.size(0):
            raise ValueError(
                'target and source must share the batch size, got '
                f'{target.size(0)} and {source.size(0)}')

        # Give the compiled kernel densely laid-out inputs.
        target = target.contiguous()
        source = source.contiguous()

        B, N, _ = target.size()
        m = source.size(1)

        # Allocate the outputs on the inputs' device. The legacy
        # torch.cuda.FloatTensor/IntTensor constructors always used the
        # *current* CUDA device, which breaks when target lives on another
        # GPU. The buffers are uninitialized; the kernel fills them.
        device = target.device
        dist2 = torch.empty((B, N, 3), dtype=torch.float32, device=device)
        idx = torch.empty((B, N, 3), dtype=torch.int32, device=device)

        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
        if torch.__version__ != 'parrots':
            # idx is an integer index tensor; no gradient flows through it.
            ctx.mark_non_differentiable(idx)

        # dist2 holds squared distances; take the root to return L2 distances.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None) -> Tuple[None, None]:
        """Propagate no gradient to ``target`` or ``source``.

        ``a`` and ``b`` receive the incoming gradients of the two outputs
        of ``forward`` and are ignored.
        """
        return None, None


three_nn = ThreeNN.apply