/* Ball Query with BatchIdx & Clustering Algorithm Written by Li Jiang All Rights Reserved 2020. */ #include #include #include #include #include #include #include #include #include int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream); using Int = int32_t; class ConnectedComponent{ public: std::vector pt_idxs {}; ConnectedComponent(){}; void addPoint(Int pt_idx) { pt_idxs.push_back(pt_idx); } }; using ConnectedComponents = std::vector; /* ================================== ballquery_batch_p ================================== */ // input xyz: (n, 3) float // input batch_idxs: (n) int // input batch_offsets: (B+1) int, batch_offsets[-1] // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n // output start_len: (n, 2), int int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius){ const float *xyz = xyz_tensor.data(); const int *batch_idxs = batch_idxs_tensor.data(); const int *batch_offsets = batch_offsets_tensor.data(); int *idx = idx_tensor.data(); int *start_len = start_len_tensor.data(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int cumsum = ballquery_batch_p_cuda(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, stream); return cumsum; } /* ================================== bfs_cluster ================================== */ ConnectedComponent find_cc(Int idx, int *semantic_label, Int *ball_query_idxs, int *start_len, int *visited){ ConnectedComponent cc; cc.addPoint(idx); visited[idx] = 1; std::queue Q; assert(Q.empty()); Q.push(idx); while(!Q.empty()){ Int cur = Q.front(); Q.pop(); int start = start_len[cur * 2]; int len = start_len[cur * 2 + 1]; int label_cur = semantic_label[cur]; for(Int i = start; i < start + len; i++){ Int idx_i = ball_query_idxs[i]; if(semantic_label[idx_i] != label_cur) continue; if(visited[idx_i] == 1) continue; cc.addPoint(idx_i); visited[idx_i] = 1; Q.push(idx_i); } } return cc; } //input: semantic_label, int, N //input: ball_query_idxs, Int, (nActive) //input: start_len, int, (N, 2) //output: clusters, CCs int get_clusters(int *semantic_label, Int *ball_query_idxs, int *start_len, const Int nPoint, int threshold, ConnectedComponents &clusters){ int visited[nPoint] = {0}; int sumNPoint = 0; for(Int i = 0; i < nPoint; i++){ if(visited[i] == 0){ ConnectedComponent CC = find_cc(i, semantic_label, ball_query_idxs, start_len, visited); if((int)CC.pt_idxs.size() >= threshold){ clusters.push_back(CC); sumNPoint += (int)CC.pt_idxs.size(); } } } return sumNPoint; } void fill_cluster_idxs_(ConnectedComponents &CCs, int *cluster_idxs, int *cluster_offsets){ for(int i = 0; i < (int)CCs.size(); i++){ cluster_offsets[i + 1] = cluster_offsets[i] + (int)CCs[i].pt_idxs.size(); for(int j = 0; j < (int)CCs[i].pt_idxs.size(); j++){ int idx = CCs[i].pt_idxs[j]; cluster_idxs[(cluster_offsets[i] + j) * 2 + 0] = i; cluster_idxs[(cluster_offsets[i] + j) * 2 + 1] = idx; } } } //input: semantic_label, int, N //input: ball_query_idxs, int, (nActive) //input: start_len, int, (N, 2) //output: cluster_idxs, int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N //output: cluster_offsets, int (nCluster + 1) void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold){ int *semantic_label = semantic_label_tensor.data(); Int *ball_query_idxs = ball_query_idxs_tensor.data(); int *start_len = start_len_tensor.data(); ConnectedComponents CCs; int sumNPoint = get_clusters(semantic_label, ball_query_idxs, start_len, N, threshold, CCs); int nCluster = (int)CCs.size(); cluster_idxs_tensor.resize_({sumNPoint, 2}); cluster_offsets_tensor.resize_({nCluster + 1}); cluster_idxs_tensor.zero_(); cluster_offsets_tensor.zero_(); int *cluster_idxs = cluster_idxs_tensor.data(); int *cluster_offsets = cluster_offsets_tensor.data(); fill_cluster_idxs_(CCs, cluster_idxs, cluster_offsets); } //------------------------------------API------------------------------------------ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ m.def("ballquery_batch_p", &ballquery_batch_p, "ballquery_batch_p"); m.def("bfs_cluster", &bfs_cluster, "bfs_cluster"); }