from itertools import chain
import torch
from torch import nn
import torch.nn.functional as F
from datasets.kitti_360.labels import labels as kitti_labels
from datasets.kitti_360.labels import trainId2label
from multiprocessing import Pool
# Optional dependencies: pykeops backs KMeansIterHead and the local .crf module
# backs CRF post-processing; import lazily so the demo runs without them.
try:
    from pykeops.torch import LazyTensor
except ImportError:
    LazyTensor = None
try:
    from .crf import dense_crf
except ImportError:
    dense_crf = None
def _five_crop(features, sample_factor=1):
    # features: [n, v, h, w, 1, c]; returns the center crop plus four off-center
    # crops, stacked along the batch dimension and subsampled by sample_factor.
    _, _, h, w, _, _ = features.shape
    assert h % (4 * sample_factor) == 0 and w % (4 * sample_factor) == 0
    center_shift = sample_factor // 2
    crop_length = min(h, w) // 4
crop_centers = [
(h//2, w//2),
(3*h//4, w//4),
(3*h//4, 3*w//4),
(h//4, w//4),
(h//4, 3*w//4),
]
result = torch.cat([
features[:, :,
crop_center[0]-crop_length+center_shift : crop_center[0]+crop_length+center_shift : sample_factor,
crop_center[1]-crop_length+center_shift : crop_center[1]+crop_length+center_shift : sample_factor]
for crop_center in crop_centers
])
return result
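
# A minimal usage sketch of _five_crop (shapes are illustrative assumptions,
# not values taken from the training pipeline):
#
#   feats = torch.randn(2, 3, 16, 16, 1, 64)    # [n, v, h, w, 1, c]
#   crops = _five_crop(feats, sample_factor=1)  # -> [10, 3, 8, 8, 1, 64]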
def _norm(x):
return F.normalize(x, dim=-1, eps=1e-10)
class SemanticHead(nn.Module):
def __init__(self,
n_classes,
gt_classes,
input_dim,
code_dim,
buffer_size,
patch_sample_size,
knn_neighbors,
mode,
mlp_head,
apply_crf,
):
super().__init__()
self.n_classes = n_classes
self.gt_classes = gt_classes
self.input_dim = input_dim
self.code_dim = code_dim
self.knn_neighbors = knn_neighbors
self.mode = mode
self.apply_crf = apply_crf
        self.buffer_size = buffer_size
        self.buffer_idx = 0  # Ring-buffer write cursor
        self.buffer_filled = 1  # Valid rows (starts at 1 so randint never sees an empty range)
        self.dino_patch_buffer = torch.zeros((buffer_size, patch_sample_size, input_dim), device="cuda")
        self.dino_gap_buffer = torch.zeros((buffer_size, input_dim), device="cuda")
self.direct_cluster_head = KMeansParamHead(n_classes, gt_classes, input_dim)
self.stego_head = StegoClusterHead(input_dim, code_dim)
self.stego_cluster_head = KMeansParamHead(n_classes, gt_classes, code_dim)
if mlp_head:
self.direct_linear_head = MLPHead(input_dim, gt_classes)
self.stego_linear_head = MLPHead(code_dim, gt_classes)
else:
self.direct_linear_head = LinearHead(input_dim, gt_classes)
self.stego_linear_head = LinearHead(code_dim, gt_classes)
        self.label_colors = [torch.Tensor(trainId2label[train_id].color) for train_id in range(gt_classes)]
        self.label_colors.append(torch.Tensor([0, 0, 0]))  # Trailing black row for the ignore label (-1)
        self.label_colors = torch.stack(self.label_colors, dim=0).to("cuda") / 255.0
self.dropout = nn.Dropout2d(p=.1)
self.dropout1d = nn.Dropout1d(p=.1)
    @classmethod
    def from_conf(cls, config):
        return cls(
n_classes=config.n_classes,
gt_classes=config.gt_classes,
input_dim=config.input_dim,
code_dim=config.code_dim,
buffer_size=config.buffer_size,
patch_sample_size=config.patch_sample_size,
knn_neighbors=config.knn_neighbors,
mode=config.get("mode", "2d"),
mlp_head=config.get("mlp_head", False),
apply_crf=config.get("apply_crf", False),
)
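
    # A hypothetical config sketch for from_conf, assuming an OmegaConf-style
    # object (key names mirror the constructor; the values are illustrative):
    #
    #   from omegaconf import OmegaConf
    #   config = OmegaConf.create({
    #       "n_classes": 27, "gt_classes": 19, "input_dim": 384,
    #       "code_dim": 90, "buffer_size": 1024, "patch_sample_size": 121,
    #       "knn_neighbors": 7,
    #   })
    #   head = SemanticHead.from_conf(config).cuda()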
def forward(self, features, mode="stego_kmeans"):
features = _norm(features)
if mode == "stego_kmeans":
stego_features = self.stego_head(features)
return self.stego_cluster_head(stego_features)["segs_pred"]
elif mode == "stego_linear":
stego_features = self.stego_head(features)
return self.stego_linear_head(stego_features)["segs_pred"]
elif mode == "direct_kmeans":
return self.direct_cluster_head(features)["segs_pred"]
elif mode == "direct_linear":
return self.direct_linear_head(features)["segs_pred"]
else:
raise NotImplementedError(f"Mode '{mode}' is not known!")
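
    # Continuing the from_conf sketch above, a minimal inference call
    # (the feature shape is an illustrative assumption):
    #
    #   head.eval()
    #   dino_feats = torch.randn(1, 16, 16, 384, device="cuda")
    #   segs = head(dino_feats, mode="stego_kmeans")  # -> [1, 16, 16] labels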
def forward_training(self, data, visualize=False, sample_factor=4): # TODO: visualization
rgb_image = data["coarse"][0]["rgb"].detach()
dino_features = data["coarse"][0]["dino_features"].detach() # [n, v, h, w, 1, c]
dino_features = _norm(dino_features)
n, v, h, w, _, c = dino_features.shape
reshaped_dino_features = dino_features.squeeze(-2).flatten(0, 1) # [n*v, h, w, c]
stego_features = self.stego_head(reshaped_dino_features).reshape(n, v, h, w, 1, -1)
        # Channel-wise dropout on the DINO features (Dropout2d expects NCHW)
        dino_features = self.dropout(dino_features.reshape(n * v, h, w, c).permute(0, 3, 1, 2))
        dino_features = dino_features.permute(0, 2, 3, 1).reshape(n, v, h, w, 1, c)
data["segmentation"] = {}
if data["sample_surface_sigma"] is not None:
if self.mode == "3d":
cropped_dino_features = data["sample_surface_dino_features"].detach().squeeze(0)
cropped_dino_features = _norm(cropped_dino_features)
cropped_dino_features = self.dropout1d(cropped_dino_features.swapaxes(-2, -1)).swapaxes(-2, -1)
stego_self_features = self.stego_head(cropped_dino_features.unsqueeze(1)).squeeze(1)
elif self.mode == "2d":
dino_features = dino_features[:, :1]
stego_features = stego_features[:, :1]
# Single view
cropped_dino_features = _five_crop(dino_features, sample_factor).flatten(0, 1).flatten(1, 2).squeeze(-2)
stego_self_features = _five_crop(stego_features, sample_factor).flatten(0, 1).flatten(1, 2).squeeze(-2)
dino_feature_gap = cropped_dino_features.mean(dim=-2)
dino_feature_gap = _norm(dino_feature_gap)
            # Training only: update the kNN ring buffers
            if self.training:
                new_idx = self._update_buffer(self.dino_patch_buffer, cropped_dino_features)
                gap_idx = self._update_buffer(self.dino_gap_buffer, dino_feature_gap)
                assert new_idx == gap_idx  # Both buffers must advance in lockstep
                if new_idx < self.buffer_idx:
                    self.buffer_filled = self.buffer_size  # Write cursor wrapped around
                else:
                    self.buffer_filled = max(new_idx, self.buffer_filled)
                self.buffer_idx = new_idx
            # Draw kNN and random samples from the buffer
            pairwise_cos_sims = torch.einsum("nf,mf->nm", dino_feature_gap, self.dino_gap_buffer)
            topk_indices = torch.topk(pairwise_cos_sims, self.knn_neighbors + 1, dim=1)[1][:, 1:]  # (n, k_nn); skip the top-1 self match
            n_crops = cropped_dino_features.size(0)  # Avoid shadowing the batch size n
            random_nn_indices = topk_indices[torch.arange(n_crops), torch.randint(self.knn_neighbors, (n_crops,))]
            dino_nn_features = self.dino_patch_buffer[random_nn_indices].detach()
            stego_nn_features = self.stego_head(dino_nn_features)
            random_indices = torch.randint(self.buffer_filled, (n_crops,))
dino_random_features = self.dino_patch_buffer[random_indices].detach()
stego_random_features = self.stego_head(dino_random_features)
stego_corr = {
"dino_self_corr": self._compute_stego_correlation(cropped_dino_features, cropped_dino_features),
"stego_self_corr": self._compute_stego_correlation(stego_self_features, stego_self_features),
"dino_nn_corr": self._compute_stego_correlation(cropped_dino_features, dino_nn_features),
"stego_nn_corr": self._compute_stego_correlation(stego_self_features, stego_nn_features),
"dino_random_corr": self._compute_stego_correlation(cropped_dino_features, dino_random_features),
"stego_random_corr": self._compute_stego_correlation(stego_self_features, stego_random_features),
}
data["segmentation"]["stego_corr"] = stego_corr
else:
data["sample_surface_sigma"] = torch.Tensor([0.0])
data["sample_surface_dino_features"] = torch.Tensor([0.0])
        # IMPORTANT: train the heads on detached features only!
dino_features = dino_features.detach()
stego_features = stego_features.detach()
direct_cluster_result = self.direct_cluster_head(dino_features)
stego_cluster_result = self.stego_cluster_head(stego_features)
data["segmentation"]["results"] = {
"direct_cluster": direct_cluster_result,
"stego_cluster": stego_cluster_result,
}
data["segmentation"]["visualization"] = {
"direct_cluster": self.visualize(direct_cluster_result["segs_pred"]),
"stego_cluster": self.visualize(stego_cluster_result["segs_pred"])
}
if "segs" in data:
seg_target = self.map_kitti_id_to_train_id(data["segs"][0]).to(stego_features.device)
direct_linear_result = self.direct_linear_head(dino_features, seg_target)
stego_linear_result = self.stego_linear_head(stego_features, seg_target)
data["segmentation"]["target"] = seg_target
data["segmentation"]["results"]["direct_linear"] = direct_linear_result
data["segmentation"]["results"]["stego_linear"] = stego_linear_result
data["segmentation"]["visualization"]["target"] = self.visualize(seg_target)
        if self.apply_crf:
            # Snapshot the keys: the loop below adds new "*_crf" entries
            result_names = list(data["segmentation"]["results"].keys())
            for result_name in result_names:
pred_no_crf = data["segmentation"]["results"][result_name]["segs_pred"]
pred_crf = self.forward_crf(pred_no_crf, rgb_image)
data["segmentation"]["results"][result_name + "_crf"] = {"segs_pred": pred_crf}
for result_name, result in data["segmentation"]["results"].items():
data["segmentation"]["visualization"][result_name] = self.visualize(result["segs_pred"])
return data
    def forward_crf(self, pred_no_crf, rgb_image):
        assert dense_crf is not None, "CRF post-processing requires the optional .crf module"
        # Fixed class count keeps the one-hot width independent of the labels present
        pred_no_crf_logits = F.one_hot(pred_no_crf.squeeze(), self.gt_classes).permute(2, 0, 1).float()
        pred_crf = torch.Tensor(dense_crf(rgb_image.squeeze().permute(2, 0, 1), pred_no_crf_logits)).to(pred_no_crf.device)
pred_crf = pred_crf.argmax(dim=0).reshape(pred_no_crf.shape)
return pred_crf
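
    # A minimal sketch of the expected contract, assuming a STEGO-style
    # dense_crf(rgb [3, H, W], unary [K, H, W]) -> [K, H, W] helper in .crf
    # (shapes are assumptions based on the permutes above):
    #
    #   rgb = torch.rand(1, 64, 64, 3)             # [1, H, W, 3]
    #   pred = torch.randint(0, 19, (1, 64, 64))   # [1, H, W]
    #   refined = head.forward_crf(pred, rgb)      # same shape as pred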
def update_model_eval(self, metrics):
self.direct_cluster_head.pseudo_assignment[:] = metrics["direct_cluster_assignment"]
self.stego_cluster_head.pseudo_assignment[:] = metrics["stego_cluster_assignment"]
    def map_kitti_id_to_train_id(self, labels):
        result = torch.zeros_like(labels, dtype=torch.long)
        for kitti_label in kitti_labels:
            result[labels == kitti_label.id] = kitti_label.trainId
        result[result == 255] = -1  # trainId 255 marks ignored classes -> ignore_index
        return result
    def visualize(self, labels):
        # Index -1 (ignore label) picks the trailing black row of label_colors
        label_map = self.label_colors[labels.long()]
        return label_map
def parameters_lr(self):
return [
(1.0, self.stego_head.parameters()),
(10.0, self.direct_cluster_head.parameters()),
(10.0, self.stego_cluster_head.parameters()),
(10.0, self.direct_linear_head.parameters()),
(10.0, self.stego_linear_head.parameters()),
]
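
    # A minimal sketch of consuming parameters_lr() with an optimizer
    # (base_lr is an assumed name, not defined in this module):
    #
    #   param_groups = [{"params": params, "lr": base_lr * scale}
    #                   for scale, params in head.parameters_lr()]
    #   optimizer = torch.optim.Adam(param_groups)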
def _compute_stego_correlation(self, tensor1, tensor2):
corr = torch.einsum("npf,nqf->npq", _norm(tensor1), _norm(tensor2))
return corr
def _update_buffer(self, buffer, x):
n = x.size(0)
if n >= self.buffer_size:
buffer[:] = x[-self.buffer_size:]
new_buffer_idx = 0 # Reset write index
else:
indices = (torch.arange(n) + self.buffer_idx) % self.buffer_size
buffer[indices] = x
new_buffer_idx = (self.buffer_idx + n) % self.buffer_size
return new_buffer_idx
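
    # Ring-buffer semantics in brief (sizes are illustrative assumptions):
    # with buffer_size = 4 and buffer_idx = 3, writing 2 rows fills slots
    # [3, 0] and returns cursor 1; writing >= 4 rows keeps only the newest
    # buffer_size rows and resets the cursor to 0.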
class StegoClusterHead(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if mid_channels is None:
mid_channels = in_channels
self.linear_path = nn.Sequential(
nn.Conv2d(in_channels, out_channels, (1, 1)),
nn.Dropout2d(p=.1),
)
self.nonlinear_path = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, (1, 1)),
nn.ReLU(),
nn.Conv2d(mid_channels, out_channels, (1, 1)),
nn.Dropout2d(p=.1),
)
    def forward(self, x):
        # Move channels to dim -3 so the 1x1 convs see NCHW-style input
        x = x.swapaxes(-1, -3)
        result = self.linear_path(x) + self.nonlinear_path(x)
return _norm(result.swapaxes(-1, -3)).to(x.dtype)
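
# A minimal usage sketch of StegoClusterHead (dimensions are illustrative
# assumptions):
#
#   stego = StegoClusterHead(in_channels=384, out_channels=90)
#   codes = stego(torch.randn(2, 16, 16, 384))  # -> [2, 16, 16, 90], L2-normalized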
class KMeansParamHead(nn.Module):
def __init__(self,
n_classes: int,
gt_classes: int,
dim: int,
):
super().__init__()
self.n_classes = n_classes
self.dim = dim
        self.init_type = "random"  # Alternative: "kmeans++" (see forward)
        self.cluster_centers = torch.nn.Parameter(torch.randn(self.n_classes, self.dim))
        self.centroids_initialized = False  # Plain attribute: centers re-initialize after loading a checkpoint
        self.register_buffer("pseudo_assignment", torch.arange(0, n_classes).remainder(gt_classes))
def forward(self, features, weight=None):
features_flat = features.flatten(0, -2) # (n, d)
if weight is not None:
weight_flat = weight.flatten()
else:
weight_flat = torch.ones(features_flat.size(0), device=features.device)
        # Lazy centroid init on the first training forward (k-means++ or random)
        if not self.centroids_initialized and self.training:
if self.init_type == "kmeans++":
cluster_centers = torch.empty(self.n_classes, self.dim, device=features.device)
first_idx = torch.randint(0, features_flat.size(0), (1,))
cluster_centers[0] = features_flat[first_idx]
for k in range(1, self.n_classes):
current_centroids = cluster_centers[:k] # (k, d)
similarity = (current_centroids @ features_flat.transpose(1, 0)) # (k, n)
max_similarity = similarity.max(dim=0).values # Closest centroid per point
distances = 1 - max_similarity
probabilities = distances ** 2
probabilities /= probabilities.sum()
next_idx = torch.multinomial(probabilities, 1)
cluster_centers[k] = features_flat[next_idx]
self.cluster_centers.data = cluster_centers
else:
self.cluster_centers.data = torch.randn(self.n_classes, self.dim, device=self.cluster_centers.device)
self.centroids_initialized = True
class_labels, cluster_loss, _ = self._kmeans_cosine(features_flat)
pseudo_segs_pred = class_labels.view(*features.shape[:-1])
result = {
"pseudo_segs_pred": pseudo_segs_pred,
"segs_pred": self._assign_pseudo_labels(pseudo_segs_pred),
"loss": torch.mean(cluster_loss * weight_flat)
}
return result
def _assign_pseudo_labels(self, pseudo_labels):
return self.pseudo_assignment[pseudo_labels.cpu()].long()
def _kmeans_cosine(self, features):
normed_clusters = F.normalize(self.cluster_centers, dim=1)
normed_features = F.normalize(features, dim=1)
inner_products = normed_features.matmul(normed_clusters.t())
class_labels = torch.argmax(inner_products, dim=1)
# cluster_probs = F.softmax(inner_products, dim=1)
cluster_probs = F.one_hot(class_labels, normed_clusters.shape[0]).to(torch.float32)
cluster_loss = -(cluster_probs * inner_products).sum(1)
# return nn.functional.log_softmax(inner_products * alpha, dim=1)
return class_labels, cluster_loss, cluster_probs
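
# A minimal usage sketch of KMeansParamHead (shapes are illustrative
# assumptions; features should be L2-normalized, as in SemanticHead):
#
#   kmeans = KMeansParamHead(n_classes=27, gt_classes=19, dim=90).cuda()
#   out = kmeans(_norm(torch.randn(2, 16, 16, 90, device="cuda")))
#   out["segs_pred"].shape  # -> torch.Size([2, 16, 16]); out["loss"] is a scalar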
class KMeansIterHead(nn.Module):
def __init__(self,
n_classes: int,
gt_classes: int,
dim: int,
reassignment_threshold: int = 5000,
kmeans_update_factor: float = 1.0,
training_chunk: int = 100000,
):
super().__init__()
self.n_classes = n_classes
self.dim = dim
self.reassignment_threshold = reassignment_threshold
self.kmeans_update_factor = kmeans_update_factor
self.training_chunk = training_chunk
self.centroids_initialized = False
self.register_buffer("cluster_centers", torch.empty(self.n_classes, self.dim, device="cuda"))
self.register_buffer("pseudo_assignment", torch.arange(0, n_classes).remainder(gt_classes))
def forward(self, features):
features_flat = features.flatten(0, -2) # (n, d)
# K-means++ init
if not self.centroids_initialized and self.training:
first_idx = torch.randint(0, features_flat.size(0), (1,))
self.cluster_centers[0] = features_flat[first_idx]
for k in range(1, self.n_classes):
current_centroids = self.cluster_centers[:k] # (k, d)
similarity = (current_centroids @ features_flat.transpose(1, 0)) # (k, n)
max_similarity = similarity.max(dim=0).values # Closest centroid per point
distances = 1 - max_similarity
probabilities = distances ** 2
probabilities /= probabilities.sum()
next_idx = torch.multinomial(probabilities, 1)
self.cluster_centers[k] = features_flat[next_idx]
self.centroids_initialized = True
class_labels = self._kmeans_cosine(features_flat)
pseudo_segs_pred = class_labels.view(*features.shape[:-1])
result = {
"pseudo_segs_pred": pseudo_segs_pred,
"segs_pred": self._assign_pseudo_labels(pseudo_segs_pred),
}
return result
def _assign_pseudo_labels(self, pseudo_labels):
return self.pseudo_assignment[pseudo_labels.cpu()].long()
    def _kmeans_cosine(self, features):
        """Implements Lloyd's algorithm for the cosine similarity metric."""
        assert LazyTensor is not None, "KMeansIterHead requires pykeops"
        features = F.normalize(features, dim=1, p=2)
        n, d = features.shape
        x_i = LazyTensor(features.view(n, 1, d).contiguous())  # (n, 1, d) samples
        c_j = LazyTensor(self.cluster_centers.view(1, self.n_classes, d).contiguous())  # (1, n_classes, d) centroids
        s_ij = x_i | c_j  # (n, n_classes) symbolic Gram matrix of dot products
        class_labels = s_ij.argmax(dim=1).long().view(-1)  # Points -> nearest cluster
if self.training:
class_labels_count = class_labels.bincount(minlength=self.n_classes)
cluster_center_update = torch.zeros_like(self.cluster_centers)
            if self.training_chunk:
                # Chunked accumulation: each chunk must scatter its own feature rows
                for i in range(0, n, self.training_chunk):
                    chunk = slice(i, min(i + self.training_chunk, n))
                    cluster_center_update.scatter_add_(0, class_labels[chunk, None].repeat(1, d), features[chunk])
            else:
                cluster_center_update.scatter_add_(0, class_labels[:, None].repeat(1, d), features)
cluster_center_update = F.normalize(cluster_center_update)
            update_factor = self.kmeans_update_factor * (class_labels_count > self.reassignment_threshold)  # Only move well-populated centroids
update_factor = update_factor.unsqueeze(-1)
self.cluster_centers[:] = F.normalize(cluster_center_update * update_factor + self.cluster_centers * (1-update_factor))
return class_labels
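
# A minimal usage sketch of KMeansIterHead (requires pykeops; shapes are
# illustrative assumptions):
#
#   kmeans = KMeansIterHead(n_classes=27, gt_classes=19, dim=90)
#   kmeans.train()
#   out = kmeans(_norm(torch.randn(4, 16, 16, 90, device="cuda")))
#   out["pseudo_segs_pred"].shape  # -> torch.Size([4, 16, 16])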
class LinearHead(nn.Module):
def __init__(self,
dim: int,
gt_classes: int
):
super().__init__()
self.linear = torch.nn.Linear(dim, gt_classes)
def forward(self, features, target=None):
logit = self.linear(features).float()
result = {
"segs_pred": logit.argmax(-1),
}
        if target is not None:
            target = target.long().to(logit.device)
            # First view only; class dim moved to position 1, as cross_entropy expects
            result["loss"] = F.cross_entropy(logit[:, 0].movedim(-1, 1).squeeze(-1), target, ignore_index=-1)
return result
class MLPHead(nn.Module):
def __init__(self,
dim: int,
gt_classes: int
):
super().__init__()
self.linear1 = torch.nn.Linear(dim, 2*dim)
self.linear2 = torch.nn.Linear(2*dim, gt_classes)
self.activation = torch.nn.ReLU()
def forward(self, features, target=None):
features = self.linear1(features)
features = self.activation(features)
logit = self.linear2(features).float()
result = {
"segs_pred": logit.argmax(-1),
}
        if target is not None:
            target = target.long().to(logit.device)
            # First view only; class dim moved to position 1, as cross_entropy expects
            result["loss"] = F.cross_entropy(logit[:, 0].movedim(-1, 1).squeeze(-1), target, ignore_index=-1)
return result
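
# A minimal training-step sketch for the probe heads (shapes are illustrative
# assumptions; -1 entries in the target are ignored by the loss):
#
#   probe = LinearHead(dim=384, gt_classes=19)
#   feats = torch.randn(2, 1, 16, 16, 1, 384)   # [n, v, h, w, 1, c]
#   target = torch.randint(-1, 19, (2, 16, 16))
#   out = probe(feats, target)                  # out["loss"] is a scalar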