# Upstream provenance: jev-aleks, commit "scenedino init" (9e15541).
# from pykeops.torch import LazyTensor
from typing import Tuple
import matplotlib.pyplot as plt
import torch
from torch import nn, Tensor
class VisualizationModule(nn.Module):
    """Turn high-dimensional feature maps into colour images for inspection.

    Two visualizations are provided:

    * PCA -- :meth:`fit_pca` fits ``3 * reduce_images`` principal components
      and :meth:`transform_pca` projects features onto a group of three of
      them (stored in the ``batch_rgb_*`` attributes).
    * K-means -- :meth:`fit_transform_kmeans_batch` clusters feature vectors
      with cosine similarity and colours each cluster via the ``tab10``
      colormap.

    The fitted statistics are plain tensor attributes, not registered
    buffers, so they are not saved in ``state_dict`` and do not move with
    ``.to(device)``.
    """

    def __init__(self, in_channels, reduce_images=3):
        """
        Args:
            in_channels: Feature dimension C of the inputs.
            reduce_images: Number of 3-component groups fitted by PCA
                (``3 * reduce_images`` components in total); a group is
                selected with ``from_dim`` in :meth:`transform_pca`.
        """
        super().__init__()
        # Placeholder PCA state until fit_pca() runs: zero mean and the first
        # three coordinate axes as "components".
        self.batch_rgb_mean = torch.zeros(in_channels)
        self.batch_rgb_comp = torch.eye(in_channels, 3)
        self.reduce_images = reduce_images
        self.fitted_pca = False

        self.n_kmeans_clusters = 8
        self.kmeans_cluster_centers = torch.zeros(self.n_kmeans_clusters, in_channels)
        self.cmap_kmeans = plt.get_cmap("tab10")

    def fit_pca(self, batch_features, refit):
        """Fit PCA on an (N, C) feature matrix unless already fitted.

        Rows containing any NaN are dropped before fitting. Fitting happens on
        the first call, or on any call with ``refit`` true.

        Raises:
            ValueError: if ``batch_features`` is not 2-D.
        """
        if batch_features.dim() != 2:
            raise ValueError(f"Wrong dims for PCA: {batch_features.shape}")
        if self.fitted_pca and not refit:
            return
        valid_rows = ~torch.isnan(batch_features).any(dim=1)
        self._pca_fast(batch_features[valid_rows], num_components=3 * self.reduce_images)
        self.fitted_pca = True

    def transform_pca(self, features, norm, from_dim):
        """Project ``features`` (..., C) onto fitted components.

        Args:
            features: Tensor whose last dimension is C.
            norm: If true, L2-normalize each centred feature vector before
                projecting.
            from_dim: Index of the first component to use; components
                ``from_dim`` .. ``from_dim + 2`` are taken (fewer if that runs
                past the fitted components).

        Returns:
            Tensor of shape (..., 3) in the usual case.
        """
        # NOTE(review): _pca_fast fits components on mean-centred AND
        # std-scaled data, but only the mean is stored and subtracted here,
        # so this is not exactly the fitted PCA projection. Confirm intent
        # before changing it (doing so would alter the output colours).
        features = features - self.batch_rgb_mean
        if norm:
            features = features / torch.linalg.norm(features, dim=-1, keepdim=True)
        return features @ self.batch_rgb_comp[..., from_dim:from_dim + 3]

    def _pca_fast(self, data: Tensor, num_components: int = 3) -> None:
        """Fit PCA with PyTorch's randomized low-rank approximation.

        Stores the results on ``self`` instead of returning them:

        * ``batch_rgb_mean``: per-channel mean of ``data``, shape [1, C]
          (or [B, 1, C]).
        * ``batch_rgb_comp``: principal components as columns, shape
          [C, num_components] (or [B, C, num_components]).

        Args:
            data: Data matrix of shape [N, C] or [B, N, C].
            num_components: Number of principal components to keep.
        """
        # Standardize each channel before the decomposition.
        data_mean = data.mean(dim=-2, keepdim=True)
        data_normalize = (data - data_mean) / (data.std(dim=-2, keepdim=True) + 1e-08)
        # pca_lowrank requires q <= min(N, C); q is at least 6 here.
        u, _, v = torch.pca_lowrank(data_normalize, q=max(num_components, 6), niter=2, center=True)
        v = v.transpose(-1, -2)  # rows are components, as in scikit-learn
        u, v = self._svd_flip(u, v)
        pca_components = v[..., :num_components, :]
        self.batch_rgb_mean = data_mean
        self.batch_rgb_comp = pca_components.transpose(-1, -2)

    def _svd_flip(self, u: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        """Make SVD signs deterministic (scikit-learn's u-based ``svd_flip``).

        For each column of ``u``, the entry with the largest magnitude is made
        positive; the matching row of ``v`` is flipped with it.

        Args:
            u: Tensor of shape [..., N, q].
            v: Tensor of shape [..., q, C] (components as rows).

        Returns:
            The sign-corrected ``(u, v)``.
        """
        max_abs_rows = u.abs().argmax(dim=-2, keepdim=True)  # [..., 1, q]
        signs = torch.gather(u, -2, max_abs_rows).sign()  # [..., 1, q]
        return u * signs, v * signs.transpose(-1, -2)

    def _labels_to_rgb(self, labels, shape):
        """Colour integer cluster labels using ``self.cmap_kmeans``.

        ``labels`` is reshaped to ``shape`` and normalized by
        ``n_kmeans_clusters - 1`` before the colormap lookup. The alpha channel
        is dropped and a size-1 dim at position -2 is squeezed out.

        Returns:
            CPU float tensor of shape ``(*shape, 3)`` (modulo the squeeze).
        """
        labels = labels.reshape(shape).float().cpu().numpy()
        colors = self.cmap_kmeans(labels / (self.n_kmeans_clusters - 1))[..., :3]
        return torch.Tensor(colors).squeeze(-2)

    def old_fit_transform_kmeans_batch(self, batch_features, subsample_size=20000):
        """Legacy k-means colouring built on the third-party ``torch_kmeans``.

        ``batch_features`` (B, ..., C) is flattened to (B, N, C). When
        ``subsample_size`` is smaller than N, the model is fit on a random
        subset of positions (the same subset for every batch element) and
        then used to label all N positions.
        """
        feats_map_flattened = batch_features.flatten(1, -2)
        from torch_kmeans import KMeans, CosineSimilarity
        kmeans_engine = KMeans(n_clusters=self.n_kmeans_clusters, distance=CosineSimilarity)
        n = feats_map_flattened.size(1)
        if subsample_size is not None and subsample_size < n:
            indices = torch.randperm(n)[:subsample_size]
            kmeans_engine.fit(feats_map_flattened[:, indices])
        else:
            kmeans_engine.fit(feats_map_flattened)
        labels = kmeans_engine.predict(feats_map_flattened)
        return self._labels_to_rgb(labels, batch_features.shape[:-1])

    def fit_transform_kmeans_batch(self, batch_features):
        """Cluster every feature vector and colour each position by cluster.

        All leading dims of ``batch_features`` (..., C) are flattened into one
        sample set, so clusters are shared across the whole batch. The fitted
        centroids are stored in ``self.kmeans_cluster_centers``.
        """
        feats_map_flattened = batch_features.flatten(0, -2)
        with torch.no_grad():
            labels, centers = self._KMeans_cosine(feats_map_flattened.float(), K=self.n_kmeans_clusters)
        self.kmeans_cluster_centers = centers
        return self._labels_to_rgb(labels, batch_features.shape[:-1])

    def _KMeans_cosine(self, x, K=19, Niter=100):
        """Implements Lloyd's algorithm for the Cosine similarity metric.

        Uses dense PyTorch ops; the similarity matrix is only (N, K).

        Args:
            x: (N, D) samples.
            K: Number of clusters; the first K samples seed the centroids.
            Niter: Number of assignment/update iterations (must be >= 1).

        Returns:
            cl: (N,) int64 cluster index per sample, from the last assignment.
            c: (K, D) L2-normalized centroids after the last update.

        Raises:
            ValueError: if ``Niter < 1`` or there are fewer than ``K`` samples.
        """
        N, _ = x.shape
        if Niter < 1:
            raise ValueError(f"Niter must be >= 1, got {Niter}")
        if N < K:
            raise ValueError(f"Need at least K={K} samples, got {N}")
        # Simplistic initialization: the first K samples, normalized.
        c = torch.nn.functional.normalize(x[:K].clone(), dim=1, p=2)
        for _ in range(Niter):
            # E step: ||x_i|| is a per-row factor, so the argmax of the plain
            # dot product equals the argmax of cosine similarity.
            cl = (x @ c.T).argmax(dim=1)
            # M step: centroid = normalized sum of its members; an empty
            # cluster sums to zero and stays the zero vector after normalize.
            c = torch.zeros_like(c).index_add_(0, cl, x)
            c = torch.nn.functional.normalize(c, dim=1, p=2)
        return cl, c