/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
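
// Headers this file relies on. The utils/ paths follow the usual PyTorch3D
// layout and are an assumption; adjust them to wherever MinK, RegisterMinK
// and the DispatchKernel helpers live in your tree.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include <tuple>

#include "utils/dispatch.cuh"
#include "utils/mink.cuh"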

// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
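
// Worked example of the chunking scheme above (numbers chosen purely for
// illustration): with blockDim.x = 256, P1 = 1000 and N = 4 we get
// chunks_per_cloud = 1 + (1000 - 1) / 256 = 4 and chunks_to_do = 16.
// Chunk 6 then belongs to cloud 6 / 4 = 1 and covers points
// 256 * (6 % 4) = 512 .. 767 of that cloud; threads whose p1 lands beyond
// lengths1[n] simply skip the chunk.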

template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    const int64_t* __restrict__ lengths1,
    const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t D,
    const size_t K,
    const size_t norm) {
  // Store both dists and indices for knn in global memory.
  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t p1 = start_point + threadIdx.x;
    if (p1 >= lengths1[n])
      continue;
    int offset = n * P1 * K + p1 * K;
    int64_t length2 = lengths2[n];
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
    for (int p2 = 0; p2 < length2; ++p2) {
      // Find the distance between points1[n, p1] and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
        scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
        scalar_t diff = coord1 - coord2;
        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
  }
}
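
// Note: in this fully generic version every (p1, p2) pair re-reads the
// coordinates of points1[n, p1] from global memory, and MinK keeps its K best
// candidates directly in the global dists/idxs buffers. The V1-V3 variants
// below trade generality for keeping more of this state thread-local.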

template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    const int64_t* __restrict__ lengths1,
    const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t K,
    const size_t norm) {
  // Same idea as the previous version, but hoist D into a template argument
  // so we can cache the current point in a thread-local array. We still store
  // the current best K dists and indices in global memory, so this should work
  // for very large K and fairly large D.
  scalar_t cur_point[D];
  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t p1 = start_point + threadIdx.x;
    if (p1 >= lengths1[n])
      continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    int offset = n * P1 * K + p1 * K;
    int64_t length2 = lengths2[n];
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
    for (int p2 = 0; p2 < length2; ++p2) {
      // Find the distance between cur_point and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
  }
}
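
// With D known at compile time, cur_point has a fixed size and the d-loops
// have compile-time trip counts, which gives the compiler the option of
// keeping cur_point in registers and unrolling those loops.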

// This is a shim functor to allow us to dispatch using DispatchKernel1D
template <typename scalar_t, int64_t D>
struct KNearestNeighborV1Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      const int64_t* __restrict__ lengths1,
      const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2,
      const size_t K,
      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
  }
};
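
// The functor is consumed by DispatchKernel1D (from the dispatch utilities),
// which maps the runtime D onto the matching template instantiation. The
// forward entry point below invokes it as:
//
//   DispatchKernel1D<KNearestNeighborV1Functor, scalar_t, V1_MIN_D, V1_MAX_D>(
//       D, blocks, threads, ...kernel arguments...);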

template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    const int64_t* __restrict__ lengths1,
    const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const int64_t N,
    const int64_t P1,
    const int64_t P2,
    const size_t norm) {
  // Same general implementation as V1, but also hoist K into a template arg.
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t p1 = start_point + threadIdx.x;
    if (p1 >= lengths1[n])
      continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    int64_t length2 = lengths2[n];
    MinK<scalar_t, int> mink(min_dists, min_idxs, K);
    for (int p2 = 0; p2 < length2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
    for (int k = 0; k < mink.size(); ++k) {
      idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
      dists[n * P1 * K + p1 * K + k] = min_dists[k];
    }
  }
}
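
// With K fixed at compile time the best-K state (min_dists / min_idxs) lives
// in per-thread arrays and is written back to global memory only once per p1
// point, instead of MinK updating the global buffers on every candidate as in
// V0/V1.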

// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV2Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      const int64_t* __restrict__ lengths1,
      const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const int64_t N,
      const int64_t P1,
      const int64_t P2,
      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
  }
};

template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    const int64_t* __restrict__ lengths1,
    const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t norm) {
  // Same idea as V2, but use register indexing for thread-local arrays.
  // Enabling sorting for this version leads to huge slowdowns; I suspect
  // that it forces min_dists into local memory rather than registers.
  // As a result this version is always unsorted.
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t p1 = start_point + threadIdx.x;
    if (p1 >= lengths1[n])
      continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    int64_t length2 = lengths2[n];
    RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
    for (int p2 = 0; p2 < length2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
    for (int k = 0; k < mink.size(); ++k) {
      idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
      dists[n * P1 * K + p1 * K + k] = min_dists[k];
    }
  }
}

// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV3Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      const int64_t* __restrict__ lengths1,
      const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2,
      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
  }
};

constexpr int V1_MIN_D = 1;
constexpr int V1_MAX_D = 32;
constexpr int V2_MIN_D = 1;
constexpr int V2_MAX_D = 8;
constexpr int V2_MIN_K = 1;
constexpr int V2_MAX_K = 32;
constexpr int V3_MIN_D = 1;
constexpr int V3_MAX_D = 8;
constexpr int V3_MIN_K = 1;
constexpr int V3_MAX_K = 4;
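
// Summary of when each kernel variant is applicable:
//   V0: any D, any K (fully dynamic)
//   V1: 1 <= D <= 32, any K
//   V2: 1 <= D <= 8,  1 <= K <= 32
//   V3: 1 <= D <= 8,  1 <= K <= 4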

bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
  return min <= x && x <= max;
}

bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
  if (version == 0) {
    return true;
  } else if (version == 1) {
    return InBounds(V1_MIN_D, D, V1_MAX_D);
  } else if (version == 2) {
    return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
  } else if (version == 3) {
    return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
  }
  return false;
}

int ChooseVersion(const int64_t D, const int64_t K) {
  for (int version = 3; version >= 1; version--) {
    if (KnnCheckVersion(version, D, K)) {
      return version;
    }
  }
  return 0;
}
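
// ChooseVersion prefers the most specialized kernel that fits. For example,
// D = 3, K = 5 fails the V3 check (K > V3_MAX_K = 4) but passes the V2 check,
// so version 2 is chosen; D = 16, K = 5 only fits V1; D = 40 falls back to
// the fully dynamic V0.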

std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int norm,
    const int K,
    int version) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
  at::CheckedFrom c = "KNearestNeighborIdxCuda";
  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
  at::checkAllSameType(c, {p1_t, p2_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p2.size(2);
  const int64_t K_64 = K;

  TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
  TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::zeros({N, P1, K}, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

  if (idxs.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(idxs, dists);
  }

  if (version < 0) {
    version = ChooseVersion(D, K);
  } else if (!KnnCheckVersion(version, D, K)) {
    int new_version = ChooseVersion(D, K);
    std::cout << "WARNING: Requested KNN version " << version
              << " is not compatible with D = " << D << "; K = " << K
              << ". Falling back to version = " << new_version << std::endl;
    version = new_version;
  }

  // At this point we should have a valid version no matter what data the user
  // gave us. But we check once more to be sure; this time we assert, since
  // failing at this point means there is a bug in our version selection or
  // checking code.
  AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");

  const size_t threads = 256;
  const size_t blocks = 256;

  if (version == 0) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
          p1.contiguous().data_ptr<scalar_t>(),
          p2.contiguous().data_ptr<scalar_t>(),
          lengths1.contiguous().data_ptr<int64_t>(),
          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          D,
          K,
          norm);
    }));
  } else if (version == 1) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel1D<
          KNearestNeighborV1Functor,
          scalar_t,
          V1_MIN_D,
          V1_MAX_D>(
          D,
          blocks,
          threads,
          p1.contiguous().data_ptr<scalar_t>(),
          p2.contiguous().data_ptr<scalar_t>(),
          lengths1.contiguous().data_ptr<int64_t>(),
          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          K,
          norm);
    }));
  } else if (version == 2) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV2Functor,
          scalar_t,
          V2_MIN_D,
          V2_MAX_D,
          V2_MIN_K,
          V2_MAX_K>(
          D,
          K_64,
          blocks,
          threads,
          p1.contiguous().data_ptr<scalar_t>(),
          p2.contiguous().data_ptr<scalar_t>(),
          lengths1.contiguous().data_ptr<int64_t>(),
          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          norm);
    }));
  } else if (version == 3) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV3Functor,
          scalar_t,
          V3_MIN_D,
          V3_MAX_D,
          V3_MIN_K,
          V3_MAX_K>(
          D,
          K_64,
          blocks,
          threads,
          p1.contiguous().data_ptr<scalar_t>(),
          p2.contiguous().data_ptr<scalar_t>(),
          lengths1.contiguous().data_ptr<int64_t>(),
          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          norm);
    }));
  }
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(idxs, dists);
}
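
// Usage sketch (the tensor names are placeholders): given float tensors p1 of
// shape (N, P1, D) and p2 of shape (N, P2, D) on the same CUDA device, plus
// int64 lengths1 (N,) and lengths2 (N,) giving the valid point counts per
// cloud, a call such as
//
//   auto out = KNearestNeighborIdxCuda(
//       p1, p2, lengths1, lengths2, /*norm=*/2, /*K=*/3, /*version=*/-1);
//   at::Tensor idxs = std::get<0>(out);   // (N, P1, K) int64 neighbor ids
//   at::Tensor dists = std::get<1>(out);  // (N, P1, K) squared L2 distances
//
// lets ChooseVersion pick the kernel; entries beyond lengths1[n] keep the
// zeros the outputs were initialized with.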

// ------------------------------------------------------------- //
//                      Backward Operators                        //
// ------------------------------------------------------------- //

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void KNearestNeighborBackwardKernel(
    const float* __restrict__ p1, // (N, P1, D)
    const float* __restrict__ p2, // (N, P2, D)
    const int64_t* __restrict__ lengths1, // (N,)
    const int64_t* __restrict__ lengths2, // (N,)
    const int64_t* __restrict__ idxs, // (N, P1, K)
    const float* __restrict__ grad_dists, // (N, P1, K)
    float* __restrict__ grad_p1, // (N, P1, D)
    float* __restrict__ grad_p2, // (N, P2, D)
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t K,
    const size_t D,
    const size_t norm) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = gridDim.x * blockDim.x;

  for (size_t i = tid; i < N * P1 * K * D; i += stride) {
    const size_t n = i / (P1 * K * D); // batch index
    size_t rem = i % (P1 * K * D);
    const size_t p1_idx = rem / (K * D); // index of point in p1
    rem = rem % (K * D);
    const size_t k = rem / D; // k-th nearest neighbor
    const size_t d = rem % D; // d-th dimension in the feature vector

    const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
    const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
    if ((p1_idx < num1) && (k < num2)) {
      const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
      // index of point in p2 corresponding to the k-th nearest neighbor
      const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
      // If the index is the pad value of -1 then ignore it
      if (p2_idx == -1) {
        continue;
      }
      float diff = 0.0;
      if (norm == 1) {
        float sign =
            (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
            ? 1.0
            : -1.0;
        diff = grad_dist * sign;
      } else { // norm is 2
        diff = 2.0 * grad_dist *
            (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
      }
      atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
      atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
    }
  }
}
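
// The gradient expressions above follow from differentiating the per-pair
// distance: for norm == 2, dist = sum_d (p1_d - p2_d)^2, so
// d(dist)/d(p1_d) = 2 * (p1_d - p2_d) and the p2 gradient is its negation;
// for norm == 1, dist = sum_d |p1_d - p2_d|, whose derivative is the sign
// term. The atomicAdd scatter-adds accumulate contributions from all K
// neighbors of all p1 points, which is why this backward pass is
// nondeterministic.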

std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    int norm,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
      idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
  at::CheckedFrom c = "KNearestNeighborBackwardCuda";
  at::checkAllSameGPU(
      c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
  at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});

  // This is nondeterministic because of atomicAdd
  at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p2.size(2);
  const auto K = idxs.size(2);

  TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
  TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
  TORCH_CHECK(
      idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
  TORCH_CHECK(grad_dists.size(0) == N);
  TORCH_CHECK(grad_dists.size(1) == P1);
  TORCH_CHECK(grad_dists.size(2) == K);

  auto grad_p1 = at::zeros({N, P1, D}, p1.options());
  auto grad_p2 = at::zeros({N, P2, D}, p2.options());

  if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(grad_p1, grad_p2);
  }

  const int blocks = 64;
  const int threads = 512;

  KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
      p1.contiguous().data_ptr<float>(),
      p2.contiguous().data_ptr<float>(),
      lengths1.contiguous().data_ptr<int64_t>(),
      lengths2.contiguous().data_ptr<int64_t>(),
      idxs.contiguous().data_ptr<int64_t>(),
      grad_dists.contiguous().data_ptr<float>(),
      grad_p1.data_ptr<float>(),
      grad_p2.data_ptr<float>(),
      N,
      P1,
      P2,
      K,
      D,
      norm);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_p1, grad_p2);
}
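
// Usage sketch (hypothetical tensors): with idxs produced by
// KNearestNeighborIdxCuda above and an upstream gradient grad_dists of shape
// (N, P1, K),
//
//   auto grads = KNearestNeighborBackwardCuda(
//       p1, p2, lengths1, lengths2, idxs, /*norm=*/2, grad_dists);
//
// returns grad_p1 of shape (N, P1, D) and grad_p2 of shape (N, P2, D); float
// inputs only, per the TODO above the kernel. The fixed 64 x 512 launch
// covers all N * P1 * K * D contributions via the kernel's grid-stride loop.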