/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <iostream>
#include <tuple>
#include "utils/dispatch.cuh"
#include "utils/mink.cuh"
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize, etc.
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
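// Worked example of the chunking arithmetic (illustrative numbers, not from
// the original): with P1 = 1000 and blockDim.x = 256 (the launch below uses
// 256 threads), chunks_per_cloud = 1 + (1000 - 1) / 256 = 4, so
// chunks_to_do = 4 * N. With gridDim.x = 256, block 3 processes chunks
// 3, 259, 515, ...; chunk 10 belongs to cloud 10 / 4 = 2 and its threads
// handle points 256 * (10 % 4) = 512 through 767 of that cloud.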
template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t D,
const size_t K,
const size_t norm) {
// Store both dists and indices for knn in global memory.
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between points1[n, p1] and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
}
}
template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t norm) {
// Same idea as the previous version, but hoist D into a template argument
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
// for very large K and fairly large D.
scalar_t cur_point[D];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between cur_point and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
}
}
// This is a shim functor to allow us to dispatch using DispatchKernel1D
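// (DispatchKernel1D, from utils/dispatch.cuh, presumably instantiates this
// functor for every D in [V1_MIN_D, V1_MAX_D] and invokes the instantiation
// matching the runtime D; e.g. a runtime D of 3 would end up launching
// KNearestNeighborKernelV1<scalar_t, 3>. See dispatch.cuh for the details.)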
template <typename scalar_t, int64_t D>
struct KNearestNeighborV1Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
}
};
template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2,
const size_t norm) {
// Same general implementation as V1, but also hoist K into a template arg.
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV2Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t norm) {
// Same idea as V2, but use register indexing for thread-local arrays.
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
// As a result this version is always unsorted.
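// RegisterMinK (utils/mink.cuh) takes K as a template parameter, so its
// insertion loop can presumably be fully unrolled and min_dists / min_idxs
// accessed only with compile-time indices, which is what lets them stay in
// registers.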
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV3Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
constexpr int V1_MIN_D = 1;
constexpr int V1_MAX_D = 32;
constexpr int V2_MIN_D = 1;
constexpr int V2_MAX_D = 8;
constexpr int V2_MIN_K = 1;
constexpr int V2_MAX_K = 32;
constexpr int V3_MIN_D = 1;
constexpr int V3_MAX_D = 8;
constexpr int V3_MIN_K = 1;
constexpr int V3_MAX_K = 4;
bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
return min <= x && x <= max;
}
bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
if (version == 0) {
return true;
} else if (version == 1) {
return InBounds(V1_MIN_D, D, V1_MAX_D);
} else if (version == 2) {
return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
} else if (version == 3) {
return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
}
return false;
}
int ChooseVersion(const int64_t D, const int64_t K) {
for (int version = 3; version >= 1; version--) {
if (KnnCheckVersion(version, D, K)) {
return version;
}
}
return 0;
}
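// Illustrative example of the selection heuristic above: for D = 3, K = 8,
// version 3 is rejected (K > V3_MAX_K = 4) but version 2 is accepted
// (D <= V2_MAX_D and K <= V2_MAX_K), so ChooseVersion(3, 8) returns 2. For
// D = 64 none of the templated versions apply and we fall back to version 0,
// which handles arbitrary D and K.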
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
at::CheckedFrom c = "KNearestNeighborIdxCuda";
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
at::checkAllSameType(c, {p1_t, p2_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const int64_t K_64 = K;
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());
if (idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
if (version < 0) {
version = ChooseVersion(D, K);
} else if (!KnnCheckVersion(version, D, K)) {
int new_version = ChooseVersion(D, K);
std::cout << "WARNING: Requested KNN version " << version
<< " is not compatible with D = " << D << "; K = " << K
<< ". Falling back to version = " << new_version << std::endl;
version = new_version;
}
// At this point we should have a valid version no matter what data the user
// gave us. But we can check once more to be sure; however, this time we
// assert-fail, since failing at this point means there is a bug in our
// version selection or checking code.
AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");
const size_t threads = 256;
const size_t blocks = 256;
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(
p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
D,
K,
norm);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel1D<
KNearestNeighborV1Functor,
scalar_t,
V1_MIN_D,
V1_MAX_D>(
D,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
K,
norm);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV2Functor,
scalar_t,
V2_MIN_D,
V2_MAX_D,
V2_MIN_K,
V2_MAX_K>(
D,
K_64,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
norm);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV3Functor,
scalar_t,
V3_MIN_D,
V3_MAX_D,
V3_MIN_K,
V3_MAX_K>(
D,
K_64,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
norm);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
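// Hypothetical host-side usage of the forward op above (names and shapes are
// illustrative, not part of the original file):
//
//   at::Tensor p1 = ...;        // (N, P1, D) float CUDA tensor
//   at::Tensor p2 = ...;        // (N, P2, D) float CUDA tensor, same dtype
//   at::Tensor lengths1 = ...;  // (N,) int64, valid points per cloud in p1
//   at::Tensor lengths2 = ...;  // (N,) int64, valid points per cloud in p2
//   auto result = KNearestNeighborIdxCuda(
//       p1, p2, lengths1, lengths2, /*norm=*/2, /*K=*/8, /*version=*/-1);
//   at::Tensor idxs = std::get<0>(result);   // (N, P1, K) int64 neighbor ids
//   at::Tensor dists = std::get<1>(result);  // (N, P1, K) squared distances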
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
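// The forward op computes, for norm == 2, dist = sum_d (p1_d - p2_d)^2, so
// d(dist)/d(p1_d) = 2 * (p1_d - p2_d) and d(dist)/d(p2_d) is its negative;
// for norm == 1, dist = sum_d |p1_d - p2_d|, whose (sub)gradient is
// sign(p1_d - p2_d) and its negative. The kernel below multiplies these
// factors by the incoming grad_dists and scatters the results into grad_p1
// and grad_p2 with atomicAdd.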
__global__ void KNearestNeighborBackwardKernel(
const float* __restrict__ p1, // (N, P1, D)
const float* __restrict__ p2, // (N, P2, D)
const int64_t* __restrict__ lengths1, // (N,)
const int64_t* __restrict__ lengths2, // (N,)
const int64_t* __restrict__ idxs, // (N, P1, K)
const float* __restrict__ grad_dists, // (N, P1, K)
float* __restrict__ grad_p1, // (N, P1, D)
float* __restrict__ grad_p2, // (N, P2, D)
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t D,
const size_t norm) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
const size_t n = i / (P1 * K * D); // batch index
size_t rem = i % (P1 * K * D);
const size_t p1_idx = rem / (K * D); // index of point in p1
rem = rem % (K * D);
const size_t k = rem / D; // k-th nearest neighbor
const size_t d = rem % D; // d-th dimension in the feature vector
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
if ((p1_idx < num1) && (k < num2)) {
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
// If the index is the pad value of -1 then ignore it
if (p2_idx == -1) {
continue;
}
float diff = 0.0;
if (norm == 1) {
float sign =
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
? 1.0
: -1.0;
diff = grad_dist * sign;
} else { // norm is 2
diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
}
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
}
}
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
int norm,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
at::CheckedFrom c = "KNearestNeighborBackwardCuda";
at::checkAllSameGPU(
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
// This is nondeterministic because of the use of atomicAdd
at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const auto K = idxs.size(2);
TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
TORCH_CHECK(
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
TORCH_CHECK(grad_dists.size(0) == N);
TORCH_CHECK(grad_dists.size(1) == P1);
TORCH_CHECK(grad_dists.size(2) == K);
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}
const int blocks = 64;
const int threads = 512;
KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
p1.contiguous().data_ptr<float>(),
p2.contiguous().data_ptr<float>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
idxs.contiguous().data_ptr<int64_t>(),
grad_dists.contiguous().data_ptr<float>(),
grad_p1.data_ptr<float>(),
grad_p2.data_ptr<float>(),
N,
P1,
P2,
K,
D,
norm);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}