#include "../cuda_utils.h"
#include "sampling_cuda_kernel.h"

// One pairwise step of the block-wide arg-max reduction over shared memory.
// Slot idx1 receives the larger of the two distances and the point index that
// belongs to it; on a tie (or when v2 > v1 is false) slot idx1 keeps its index.
__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) {
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

// Furthest point sampling over a set of point segments ("batches").
//
// Launch layout: one block per batch (blockIdx.x selects the batch) and exactly
// block_size threads per block; the shared arrays and the per-thread stride are
// both sized by block_size. block_size must be a power of two for the tree
// reduction below to cover every slot.
//
// xyz        : point coordinates, read as xyz[k * 3 + {0, 1, 2}].
// offset     : offset[i] is the end (exclusive) of batch i in xyz; batch 0
//              starts at 0, batch i starts at offset[i - 1].
// new_offset : same layout as offset, but for the output idx array.
// tmp        : running minimum squared distance from each point to the set of
//              already selected points, indexed by the global point index k.
//              The kernel only ever takes min() against it, so it is
//              presumably pre-filled with a large value by the caller --
//              TODO confirm.
// idx        : output; idx[start_m .. end_m) receives the selected global
//              point indices of the batch.
template <unsigned int block_size>
__global__ void furthestsampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) {
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    const int bid = blockIdx.x;
    int start_n, end_n, start_m, end_m, old;
    if (bid == 0) {
        start_n = 0;
        end_n = offset[0];
        start_m = 0;
        end_m = new_offset[0];
        old = 0;
    } else {
        start_n = offset[bid - 1];
        end_n = offset[bid];
        start_m = new_offset[bid - 1];
        end_m = new_offset[bid];
        old = offset[bid - 1];
    }

    // A batch that requests no samples owns no slot in idx; writing
    // idx[start_m] would clobber another batch's output (or run past the end).
    // The condition depends only on blockIdx, so the whole block exits together
    // and no barrier is skipped by a subset of threads.
    if (end_m <= start_m) return;

    const int stride = block_size;
    const int tid = threadIdx.x;

    // The first sample of every batch is its first point.
    if (tid == 0) idx[start_m] = start_n;
    __syncthreads();

    for (int j = start_m + 1; j < end_m; j++) {
        int besti = start_n;
        float best = -1.0f;

        // Coordinates of the most recently selected point.
        const float x1 = xyz[old * 3 + 0];
        const float y1 = xyz[old * 3 + 1];
        const float z1 = xyz[old * 3 + 2];

        // Each thread walks its share of the batch, refreshes the running
        // minimum distance in tmp and remembers its local furthest point.
        for (int k = start_n + tid; k < end_n; k += stride) {
            const float x2 = xyz[k * 3 + 0];
            const float y2 = xyz[k * 3 + 1];
            const float z2 = xyz[k * 3 + 2];
            const float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            const float d2 = min(d, tmp[k]);
            tmp[k] = d2;
            if (d2 > best) {
                best = d2;
                besti = k;
            }
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Tree reduction: halve the number of active slots each round until
        // slot 0 holds the block-wide maximum distance and its point index.
#pragma unroll
        for (int active = block_size / 2; active > 0; active >>= 1) {
            if (tid < active) {
                __update(dists, dists_i, tid, tid + active);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0) idx[j] = old;

        // Every thread must have read dists_i[0] before any thread overwrites
        // its shared slot in the next iteration; without this barrier thread 0
        // can race ahead and replace slot 0 while other warps still read it.
        __syncthreads();
    }
}

// Host-side launcher: picks the kernel instantiation whose block_size matches
// the thread count returned by opt_n_threads(n) and launches one block per
// batch on the default stream.
//
// b is used as the grid size, i.e. it is assumed to be the number of batches
// (entries in offset / new_offset) -- TODO confirm against the caller.
// opt_n_threads is declared outside this file; the cases below cover the powers
// of two from 1 to 1024.
// Launch errors are not checked here; callers that need them should query the
// CUDA runtime after this call.
void furthestsampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) {
    // An empty grid is an invalid launch configuration; nothing to do anyway.
    if (b <= 0) return;

    unsigned int n_threads = opt_n_threads(n);
    switch (n_threads) {
        case 1024:
            furthestsampling_cuda_kernel<1024><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 512:
            furthestsampling_cuda_kernel<512><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 256:
            furthestsampling_cuda_kernel<256><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 128:
            furthestsampling_cuda_kernel<128><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 64:
            furthestsampling_cuda_kernel<64><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 32:
            furthestsampling_cuda_kernel<32><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 16:
            furthestsampling_cuda_kernel<16><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 8:
            furthestsampling_cuda_kernel<8><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 4:
            furthestsampling_cuda_kernel<4><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 2:
            furthestsampling_cuda_kernel<2><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        case 1:
            furthestsampling_cuda_kernel<1><<<b, n_threads, 0>>>(xyz, offset, new_offset, tmp, idx);
            break;
        default:
            // Unexpected thread count: fall back to the 512-thread variant and
            // launch exactly 512 threads so the block matches the size of the
            // kernel's shared arrays.
            furthestsampling_cuda_kernel<512><<<b, 512, 0>>>(xyz, offset, new_offset, tmp, idx);
    }
}