Spaces:
Runtime error
Runtime error
| # From kmeans_pytorch | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| def initialize(X, num_clusters): | |
| """ | |
| initialize cluster centers | |
| :param X: (torch.tensor) matrix | |
| :param num_clusters: (int) number of clusters | |
| :return: (np.array) initial state | |
| """ | |
| num_samples = len(X) | |
| indices = np.random.choice(num_samples, num_clusters, replace=False) | |
| initial_state = X[indices] | |
| return initial_state | |
| def kmeans( | |
| X, | |
| num_clusters, | |
| distance='euclidean', | |
| tol=1e-4, | |
| device=torch.device('cuda') | |
| ): | |
| """ | |
| perform kmeans | |
| :param X: (torch.tensor) matrix | |
| :param num_clusters: (int) number of clusters | |
| :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] | |
| :param tol: (float) threshold [default: 0.0001] | |
| :param device: (torch.device) device [default: cpu] | |
| :return: (torch.tensor, torch.tensor) cluster ids, cluster centers | |
| """ | |
| print(f'running k-means on {device}..') | |
| if distance == 'euclidean': | |
| pairwise_distance_function = pairwise_distance | |
| elif distance == 'cosine': | |
| pairwise_distance_function = pairwise_cosine | |
| else: | |
| raise NotImplementedError | |
| # convert to float | |
| X = X.float() | |
| # transfer to device | |
| X = X.to(device) | |
| # initialize | |
| initial_state = initialize(X, num_clusters) | |
| iteration = 0 | |
| tqdm_meter = tqdm(desc='[running kmeans]') | |
| while True: | |
| dis = pairwise_distance_function(X, initial_state) | |
| choice_cluster = torch.argmin(dis, dim=1) | |
| initial_state_pre = initial_state.clone() | |
| for index in range(num_clusters): | |
| selected = torch.nonzero(choice_cluster == index).squeeze().to(device) | |
| selected = torch.index_select(X, 0, selected) | |
| initial_state[index] = selected.mean(dim=0) | |
| center_shift = torch.sum( | |
| torch.sqrt( | |
| torch.sum((initial_state - initial_state_pre) ** 2, dim=1) | |
| )) | |
| # increment iteration | |
| iteration = iteration + 1 | |
| # update tqdm meter | |
| tqdm_meter.set_postfix( | |
| iteration=f'{iteration}', | |
| center_shift=f'{center_shift ** 2:0.6f}', | |
| tol=f'{tol:0.6f}' | |
| ) | |
| tqdm_meter.update() | |
| if center_shift ** 2 < tol: | |
| break | |
| return choice_cluster, initial_state | |
| def kmeans_predict( | |
| X, | |
| cluster_centers, | |
| distance='euclidean', | |
| device=torch.device('cpu') | |
| ): | |
| """ | |
| predict using cluster centers | |
| :param X: (torch.tensor) matrix | |
| :param cluster_centers: (torch.tensor) cluster centers | |
| :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] | |
| :param device: (torch.device) device [default: 'cpu'] | |
| :return: (torch.tensor) cluster ids | |
| """ | |
| print(f'predicting on {device}..') | |
| if distance == 'euclidean': | |
| pairwise_distance_function = pairwise_distance | |
| elif distance == 'cosine': | |
| pairwise_distance_function = pairwise_cosine | |
| else: | |
| raise NotImplementedError | |
| # convert to float | |
| X = X.float() | |
| # transfer to device | |
| X = X.to(device) | |
| dis = pairwise_distance_function(X, cluster_centers) | |
| choice_cluster = torch.argmin(dis, dim=1) | |
| return choice_cluster.cpu() | |
| def pairwise_distance(data1, data2): | |
| return torch.cdist(data1[None, :, :], data2[None, :, :])[0] | |
| def pairwise_cosine(data1, data2): | |
| # N*1*M | |
| A = data1.unsqueeze(dim=1) | |
| # 1*N*M | |
| B = data2.unsqueeze(dim=0) | |
| # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] | |
| A_normalized = A / A.norm(dim=-1, keepdim=True) | |
| B_normalized = B / B.norm(dim=-1, keepdim=True) | |
| cosine = A_normalized * B_normalized | |
| # return N*N matrix for pairwise distance | |
| cosine_dis = 1 - cosine.sum(dim=-1).squeeze() | |
| return cosine_dis | |