from typing import Tuple

import matplotlib.pyplot as plt
import torch
from pykeops.torch import LazyTensor  # required by _KMeans_cosine below
from torch import nn, Tensor
class VisualizationModule(nn.Module):
    def __init__(self, in_channels, reduce_images=3):
        super().__init__()
        # PCA statistics, filled in by _pca_fast.
        self.batch_rgb_mean = torch.zeros(in_channels)
        self.batch_rgb_std = torch.ones(in_channels)
        self.batch_rgb_comp = torch.eye(in_channels, 3)
        self.reduce_images = reduce_images
        self.fitted_pca = False
        # K-means state for cluster-map visualizations.
        self.n_kmeans_clusters = 8
        self.kmeans_cluster_centers = torch.zeros(self.n_kmeans_clusters, in_channels)
        self.cmap_kmeans = plt.get_cmap("tab10")
    def fit_pca(self, batch_features, refit):
        if batch_features.dim() > 2:
            raise ValueError(f"Wrong dims for PCA: {batch_features.shape}")
        if not self.fitted_pca or refit:
            # Drop rows that contain NaN values before fitting.
            batch_features = batch_features[~torch.isnan(batch_features).any(dim=1)]
            self._pca_fast(batch_features, num_components=3 * self.reduce_images)
            self.fitted_pca = True
    def transform_pca(self, features, norm, from_dim):
        # Standardize with the statistics stored by _pca_fast, so the
        # projection happens in the same space the components were fitted in.
        features = (features - self.batch_rgb_mean) / self.batch_rgb_std
        if norm:
            features = features / torch.linalg.norm(features, dim=-1, keepdim=True)
        # Project onto three consecutive principal components -> RGB channels.
        return features @ self.batch_rgb_comp[..., from_dim:from_dim + 3]
    def _pca_fast(self, data: Tensor, num_components: int = 3) -> None:
        """Fit PCA using PyTorch's fast low-rank approximation.

        The fitted mean, std, and principal components are stored on the
        module (batch_rgb_mean, batch_rgb_std, batch_rgb_comp) for later
        use in transform_pca.

        Args:
            data (Tensor): Data matrix of the shape [N, C] or [B, N, C].
            num_components (int): Number of principal components to be used.
        """
        # Standardize the data (zero mean, unit variance per channel).
        data_mean = data.mean(dim=-2, keepdim=True)
        data_std = data.std(dim=-2, keepdim=True) + 1e-08
        data_normalize = (data - data_mean) / data_std
        # Perform fast low-rank PCA; q oversamples the requested rank for
        # accuracy, and the data is already centered, so center=True is a no-op.
        u, _, v = torch.pca_lowrank(data_normalize, q=max(num_components, 6), niter=2, center=True)
        v = v.transpose(-1, -2)
        # Perform SVD flip to make the component signs deterministic.
        u, v = self._svd_flip(u, v)  # type: Tensor, Tensor
        # Slice and transpose the PCA components to match scikit-learn.
        if data_normalize.ndim == 2:
            pca_components = v[:num_components]
        else:
            pca_components = v[:, :num_components]
        self.batch_rgb_mean = data_mean
        self.batch_rgb_std = data_std
        self.batch_rgb_comp = pca_components.transpose(-1, -2)
    def _svd_flip(self, u: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        """Perform SVD flip to resolve the sign ambiguity of the SVD.

        Args:
            u (Tensor): u matrix of the shape [N, C] or [B, N, C].
            v (Tensor): v matrix of the shape [C, C] or [B, C, C].

        Returns:
            u (Tensor): Fixed u matrix of the shape [N, C] or [B, N, C].
            v (Tensor): Fixed v matrix of the shape [C, C] or [B, C, C].
        """
        # Flip signs so that the largest-magnitude entry in each column of u
        # is positive, matching scikit-learn's svd_flip convention.
        max_abs: Tensor = torch.abs(u).argmax(dim=-2)
        indexes: Tensor = torch.arange(u.shape[-1], device=u.device)
        if u.ndim == 2:
            signs: Tensor = torch.sign(u[max_abs, indexes])
            u = u * signs
            v = v * signs.unsqueeze(dim=-1)
        else:
            # TODO: replace this Python loop with a batched gather in the future.
            signs = torch.stack(
                [torch.sign(u[batch_index, max_abs[batch_index], indexes]) for batch_index in range(u.shape[0])], dim=0
            )
            u = u * signs.unsqueeze(dim=1)
            v = v * signs.unsqueeze(dim=-1)
        return u, v
    def old_fit_transform_kmeans_batch(self, batch_features, subsample_size=20000):
        feats_map_flattened = batch_features.flatten(1, -2)
        # Imported lazily so torch_kmeans stays an optional dependency.
        from torch_kmeans import CosineSimilarity, KMeans

        kmeans_engine = KMeans(n_clusters=self.n_kmeans_clusters, distance=CosineSimilarity)
        n = feats_map_flattened.size(1)
        if subsample_size is not None and subsample_size < n:
            # Fit on a random subsample to keep K-means tractable, then
            # predict labels for all features below.
            indices = torch.randperm(n)[:subsample_size]
            feats_map_subsampled = feats_map_flattened[:, indices]
            kmeans_engine.fit(feats_map_subsampled)
        else:
            kmeans_engine.fit(feats_map_flattened)
        labels = kmeans_engine.predict(feats_map_flattened)
        labels = labels.reshape(batch_features.shape[:-1]).float().cpu().numpy()
        # Map cluster indices to RGB colors via the qualitative colormap.
        label_map = self.cmap_kmeans(labels / (self.n_kmeans_clusters - 1))[..., :3]
        label_map = torch.from_numpy(label_map).float().squeeze(-2)
        return label_map
    def fit_transform_kmeans_batch(self, batch_features):
        # Flatten everything but the channel dim: [..., C] -> [N, C].
        feats_map_flattened = batch_features.flatten(0, -2)
        with torch.no_grad():
            cl, c = self._KMeans_cosine(feats_map_flattened.float(), K=self.n_kmeans_clusters)
        self.kmeans_cluster_centers = c
        labels = cl.reshape(batch_features.shape[:-1]).float().cpu().numpy()
        # Map cluster indices to RGB colors via the qualitative colormap.
        label_map = self.cmap_kmeans(labels / (self.n_kmeans_clusters - 1))[..., :3]
        label_map = torch.from_numpy(label_map).float().squeeze(-2)
        return label_map
    def _KMeans_cosine(self, x, K=19, Niter=100):
        """Implements Lloyd's algorithm for the cosine similarity metric."""
        N, D = x.shape  # Number of samples, dimension of the ambient space
        c = x[:K, :].clone()  # Simplistic initialization for the centroids
        # Normalize the centroids for the cosine similarity:
        c[:] = torch.nn.functional.normalize(c, dim=1, p=2)
        # The LazyTensors are symbolic views of x and c, so the in-place
        # centroid updates below are picked up on every iteration.
        x_i = LazyTensor(x.view(N, 1, D))  # (N, 1, D) samples
        c_j = LazyTensor(c.view(1, K, D))  # (1, K, D) centroids
        # K-means loop:
        # - x is the (N, D) point cloud,
        # - cl is the (N,) vector of class labels,
        # - c is the (K, D) cloud of cluster centroids.
        for _ in range(Niter):
            # E step: assign points to the closest cluster -------------------
            S_ij = x_i | c_j  # (N, K) symbolic Gram matrix of dot products
            cl = S_ij.argmax(dim=1).long().view(-1)  # Points -> nearest cluster
            # M step: update the centroids to the normalized cluster average:
            # Compute the sum of points per cluster:
            c.zero_()
            c.scatter_add_(0, cl[:, None].repeat(1, D), x)
            # Normalize the centroids, in place:
            c[:] = torch.nn.functional.normalize(c, dim=1, p=2)
        return cl, c
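
# Minimal usage sketch (illustrative, not part of the original module):
# random features stand in for real backbone activations, and the shapes
# B, H, W, C below are assumptions chosen for the demo.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, W, C = 2, 16, 16, 64
    features = torch.randn(B, H, W, C)

    viz = VisualizationModule(in_channels=C, reduce_images=1)

    # PCA visualization: fit on the flattened [N, C] batch, then project
    # each feature vector onto three principal components -> an RGB map.
    viz.fit_pca(features.reshape(-1, C), refit=True)
    rgb = viz.transform_pca(features, norm=False, from_dim=0)
    print("PCA RGB map:", rgb.shape)  # expected: [2, 16, 16, 3]

    # K-means visualization (requires a working pykeops install): cluster
    # the features with cosine K-means and color-code the cluster labels.
    label_map = viz.fit_transform_kmeans_batch(features)
    print("K-means label map:", label_map.shape)  # expected: [2, 16, 16, 3]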