Spaces:
Running
on
Zero
Running
on
Zero
| # %% | |
| import numpy as np | |
| import torch | |
def build_tree(all_dots, dist='euclidean'):
    """Build a coarse-to-fine tree over samples via farthest-point ordering.

    The samples are visited greedily: the first node is the one with the
    smallest mean distance to all samples, and every following node is the
    not-yet-visited sample that is farthest from the set visited so far.
    The distance at which a node is picked determines its level, and each
    node is attached to its nearest node among all shallower levels.

    Args:
        all_dots: 2-D array of shape (num_samples, dim), one sample per row.
            Expected to be a NumPy array (it is passed to NumPy ufuncs and
            to ``torch.tensor``).
        dist: ``'euclidean'`` or ``'cosine'``. The cosine variant computes
            ``1 - all_dots @ all_dots.T`` and therefore assumes the rows
            are already unit-normalized.

    Returns:
        A tuple ``(edges, levels)``:

        * ``edges`` -- integer array of shape (num_samples, 2). Row ``k``
          is ``[node, parent]`` for the k-th visited node. The root is
          listed first and is its own parent.
        * ``levels`` -- integer array of shape (num_samples,), aligned with
          the rows of ``edges``. The root has level 0. With ``d1`` the
          pick distance of the second node, a node picked at distance
          ``d > 0`` has level ``floor(log2(d1 / d)) + 1``. Nodes picked at
          a non-positive distance (exact duplicates) are placed one level
          below the deepest regular level.

    Raises:
        ValueError: if ``dist`` is neither ``'euclidean'`` nor ``'cosine'``.
    """
    num_sample = all_dots.shape[0]

    # Dense pairwise distance matrix A of shape (num_sample, num_sample).
    if dist == 'euclidean':
        A = all_dots[:, None] - all_dots[None, :]
        A = (A ** 2).sum(-1)
        A = np.sqrt(A)
        A = torch.tensor(A)
    elif dist == 'cosine':
        # assume all_dots is normalized
        A = all_dots @ all_dots.T
        A = torch.tensor(A)
        A = 1 - A
    else:
        raise ValueError('dist must be euclidean or cosine')

    # --- Farthest-point traversal -------------------------------------
    start_idx = torch.argmin(A.mean(dim=1)).item()
    indices = [start_idx]
    # The root has no pick distance; this placeholder is never read.
    distances = [float('inf')]

    # min_dist[j] is the distance from sample j to its closest visited
    # node. It is updated incrementally instead of re-reducing A[indices]
    # on every step. Visited entries are forced to -inf so they can never
    # be picked again, even when all remaining samples sit at distance 0.
    already_visited = float('-inf')
    min_dist = A[start_idx].clone()  # clone: A[start_idx] is a view of A
    min_dist[start_idx] = already_visited
    for _ in range(num_sample - 1):
        next_idx = torch.argmax(min_dist).item()
        indices.append(next_idx)
        distances.append(min_dist[next_idx].item())
        min_dist = torch.minimum(min_dist, A[next_idx])
        min_dist[next_idx] = already_visited

    indices = np.array(indices)
    distances = np.array(distances)

    # --- Level assignment ---------------------------------------------
    levels = np.zeros(num_sample, dtype=int)
    if num_sample > 1:
        later = distances[1:]
        positive = later > 0
        # Default of 1 covers the degenerate case where every sample
        # coincides with the root (no positive pick distance at all).
        later_levels = np.ones(later.shape, dtype=int)
        if positive.any():
            # Pick distances never increase, so distances[1] is the
            # largest one and is positive whenever any distance is.
            ratio = distances[1] / later[positive]
            later_levels[positive] = np.floor(np.log2(ratio)).astype(int) + 1
            # Zero-distance picks would need an infinite level; put them
            # one level below the deepest regular one instead.
            later_levels[~positive] = later_levels[positive].max() + 1
        levels[1:] = later_levels

    # --- Parent assignment --------------------------------------------
    pi_indices = np.empty(num_sample, dtype=indices.dtype)
    pi_indices[0] = indices[0]  # the root is its own parent
    # Only levels that actually contain nodes are visited.
    for i_level in np.unique(levels[1:]):
        in_level = levels == i_level
        in_earlier = levels < i_level
        current_nodes = indices[in_level]
        earlier_nodes = indices[in_earlier]
        # Rows: candidate parents; columns: nodes of the current level.
        sub = A[torch.as_tensor(earlier_nodes)][:, torch.as_tensor(current_nodes)]
        nearest = sub.min(dim=0).indices.numpy()
        # Writing through the mask keeps parents aligned with `indices`.
        pi_indices[in_level] = earlier_nodes[nearest]

    edges = np.stack([indices, pi_indices], axis=1)
    return edges, levels
def find_connected_component(edges, start_node):
    """Return every node reachable from ``start_node`` in an undirected graph.

    Args:
        edges: iterable of 2-element edges ``(a, b)``, for example a list of
            tuples or an array of shape (num_edges, 2). Each edge is
            treated as undirected.
        start_node: node at which the search starts. It does not have to
            appear in ``edges``; in that case the result contains only
            ``start_node``.

    Returns:
        A NumPy array holding the nodes of the connected component that
        contains ``start_node``. The element order is the iteration order
        of a set and should not be relied upon.
    """
    # Imported locally so the function does not depend on module-level
    # imports beyond numpy.
    from collections import defaultdict, deque

    # Undirected adjacency list: every edge is recorded on both endpoints.
    adjacency_list = defaultdict(list)
    for a, b in edges:
        adjacency_list[a].append(b)
        adjacency_list[b].append(a)

    # Breadth-first search. A deque gives O(1) pops from the front,
    # whereas list.pop(0) shifts the whole queue on every pop.
    connected_component = set()
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        if node in connected_component:
            continue
        connected_component.add(node)
        # .get() avoids inserting an empty entry for an isolated node.
        queue.extend(
            neighbor
            for neighbor in adjacency_list.get(node, ())
            if neighbor not in connected_component
        )
    return np.array(list(connected_component))