|
import torch |
|
from torch.autograd import Function |
|
|
|
from pointops._C import farthest_point_sampling_cuda |
|
|
|
|
|
class FarthestPointSampling(Function): |
|
@staticmethod |
|
def forward(ctx, xyz, offset, new_offset): |
|
""" |
|
input: coords: (n, 3), offset: (b), new_offset: (b) |
|
output: idx: (m) |
|
""" |
|
assert xyz.is_contiguous() |
|
n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] |
|
for i in range(1, b): |
|
n_max = max(offset[i] - offset[i - 1], n_max) |
|
idx = torch.cuda.IntTensor(new_offset[b - 1].item()).zero_() |
|
tmp = torch.cuda.FloatTensor(n).fill_(1e10) |
|
farthest_point_sampling_cuda( |
|
b, n_max, xyz, offset.int(), new_offset.int(), tmp, idx |
|
) |
|
del tmp |
|
return idx |
|
|
|
|
|
farthest_point_sampling = FarthestPointSampling.apply |
|
|