Spaces:
Running
Running
| /* | |
| Ball Query with BatchIdx & Clustering Algorithm | |
| Written by Li Jiang | |
| All Rights Reserved 2020. | |
| */ | |
| int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream); | |
// 32-bit index type used for point indices throughout the clustering code.
using Int = int32_t;

// One connected component: the point indices collected by a single BFS flood fill,
// stored in the order they were added.
class ConnectedComponent {
public:
    std::vector<Int> pt_idxs{};

    ConnectedComponent() {}

    // Appends a single point index to the component.
    void addPoint(Int pt_idx) { pt_idxs.emplace_back(pt_idx); }
};

// The set of components produced by a clustering pass.
using ConnectedComponents = std::vector<ConnectedComponent>;
| /* ================================== ballquery_batch_p ================================== */ | |
| // input xyz: (n, 3) float | |
| // input batch_idxs: (n) int | |
| // input batch_offsets: (B+1) int, batch_offsets[-1] | |
| // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n | |
| // output start_len: (n, 2), int | |
| int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius){ | |
| const float *xyz = xyz_tensor.data<float>(); | |
| const int *batch_idxs = batch_idxs_tensor.data<int>(); | |
| const int *batch_offsets = batch_offsets_tensor.data<int>(); | |
| int *idx = idx_tensor.data<int>(); | |
| int *start_len = start_len_tensor.data<int>(); | |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | |
| int cumsum = ballquery_batch_p_cuda(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, stream); | |
| return cumsum; | |
| } | |
| /* ================================== bfs_cluster ================================== */ | |
| ConnectedComponent find_cc(Int idx, int *semantic_label, Int *ball_query_idxs, int *start_len, int *visited){ | |
| ConnectedComponent cc; | |
| cc.addPoint(idx); | |
| visited[idx] = 1; | |
| std::queue<Int> Q; | |
| assert(Q.empty()); | |
| Q.push(idx); | |
| while(!Q.empty()){ | |
| Int cur = Q.front(); Q.pop(); | |
| int start = start_len[cur * 2]; | |
| int len = start_len[cur * 2 + 1]; | |
| int label_cur = semantic_label[cur]; | |
| for(Int i = start; i < start + len; i++){ | |
| Int idx_i = ball_query_idxs[i]; | |
| if(semantic_label[idx_i] != label_cur) continue; | |
| if(visited[idx_i] == 1) continue; | |
| cc.addPoint(idx_i); | |
| visited[idx_i] = 1; | |
| Q.push(idx_i); | |
| } | |
| } | |
| return cc; | |
| } | |
| //input: semantic_label, int, N | |
| //input: ball_query_idxs, Int, (nActive) | |
| //input: start_len, int, (N, 2) | |
| //output: clusters, CCs | |
| int get_clusters(int *semantic_label, Int *ball_query_idxs, int *start_len, const Int nPoint, int threshold, ConnectedComponents &clusters){ | |
| int visited[nPoint] = {0}; | |
| int sumNPoint = 0; | |
| for(Int i = 0; i < nPoint; i++){ | |
| if(visited[i] == 0){ | |
| ConnectedComponent CC = find_cc(i, semantic_label, ball_query_idxs, start_len, visited); | |
| if((int)CC.pt_idxs.size() >= threshold){ | |
| clusters.push_back(CC); | |
| sumNPoint += (int)CC.pt_idxs.size(); | |
| } | |
| } | |
| } | |
| return sumNPoint; | |
| } | |
| void fill_cluster_idxs_(ConnectedComponents &CCs, int *cluster_idxs, int *cluster_offsets){ | |
| for(int i = 0; i < (int)CCs.size(); i++){ | |
| cluster_offsets[i + 1] = cluster_offsets[i] + (int)CCs[i].pt_idxs.size(); | |
| for(int j = 0; j < (int)CCs[i].pt_idxs.size(); j++){ | |
| int idx = CCs[i].pt_idxs[j]; | |
| cluster_idxs[(cluster_offsets[i] + j) * 2 + 0] = i; | |
| cluster_idxs[(cluster_offsets[i] + j) * 2 + 1] = idx; | |
| } | |
| } | |
| } | |
| //input: semantic_label, int, N | |
| //input: ball_query_idxs, int, (nActive) | |
| //input: start_len, int, (N, 2) | |
| //output: cluster_idxs, int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N | |
| //output: cluster_offsets, int (nCluster + 1) | |
| void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor, | |
| at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold){ | |
| int *semantic_label = semantic_label_tensor.data<int>(); | |
| Int *ball_query_idxs = ball_query_idxs_tensor.data<Int>(); | |
| int *start_len = start_len_tensor.data<int>(); | |
| ConnectedComponents CCs; | |
| int sumNPoint = get_clusters(semantic_label, ball_query_idxs, start_len, N, threshold, CCs); | |
| int nCluster = (int)CCs.size(); | |
| cluster_idxs_tensor.resize_({sumNPoint, 2}); | |
| cluster_offsets_tensor.resize_({nCluster + 1}); | |
| cluster_idxs_tensor.zero_(); | |
| cluster_offsets_tensor.zero_(); | |
| int *cluster_idxs = cluster_idxs_tensor.data<int>(); | |
| int *cluster_offsets = cluster_offsets_tensor.data<int>(); | |
| fill_cluster_idxs_(CCs, cluster_idxs, cluster_offsets); | |
| } | |
| //------------------------------------API------------------------------------------ | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ | |
| m.def("ballquery_batch_p", &ballquery_batch_p, "ballquery_batch_p"); | |
| m.def("bfs_cluster", &bfs_cluster, "bfs_cluster"); | |
| } | |