import os
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.decomposition import PCA

from model.modules.new_object_detection import *


class DIFTLatentStore:
    """Cache of intermediate feature maps, keyed by timestep and layer index.

    Only features whose timestep is in ``steps`` and whose layer index is in
    ``up_ft_indices`` are recorded. Entries are keyed as
    ``f"{int(t)}_{layer_index}"``.
    """

    def __init__(self, steps: List[int], up_ft_indices: List[int]):
        self.steps = steps
        self.up_ft_indices = up_ft_indices
        self.dift_features = {}
        self.smoothed_dift_features = {}

    def __call__(self, features: torch.Tensor, t: int, layer_index: int):
        """Store ``features`` if ``(t, layer_index)`` is one of the tracked pairs."""
        if t in self.steps and layer_index in self.up_ft_indices:
            self.dift_features[f'{int(t)}_{layer_index}'] = features

    def smooth(self, kernel_size=3, sigma=1):
        """Gaussian-smooth every stored feature tensor that is not smoothed yet.

        Each stored tensor is iterated over its first dimension. Every element
        must be 3-D, as :func:`gaussian_smooth` requires, so stored tensors are
        presumably (B, C, H, W). Keys that were already smoothed are skipped,
        so a later call with different ``kernel_size``/``sigma`` has no effect
        on them until :meth:`reset`.
        """
        for key, value in self.dift_features.items():
            if key not in self.smoothed_dift_features:
                self.smoothed_dift_features[key] = torch.stack(
                    [gaussian_smooth(x, kernel_size=kernel_size, sigma=sigma) for x in value],
                    dim=0,
                )

    def copy(self):
        """Return a new store with cloned raw features.

        Smoothed features are NOT copied.
        """
        copy_dift = DIFTLatentStore(self.steps, self.up_ft_indices)
        for key, value in self.dift_features.items():
            copy_dift.dift_features[key] = value.clone()
        return copy_dift

    def reset(self):
        """Drop all raw and smoothed features."""
        self.dift_features = {}
        self.smoothed_dift_features = {}


def gaussian_smooth(input_tensor, kernel_size=3, sigma=1):
    """Blur each channel of a (C, H, W) tensor with a normalised 2-D Gaussian.

    Args:
        input_tensor: 3-D tensor. Every slice along dim 0 is filtered
            independently.
        kernel_size: side length of the square kernel.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        Tensor with the same dtype and device as the input. Zero padding of
        ``kernel_size // 2`` keeps H and W unchanged for odd ``kernel_size``.
    """
    center = (kernel_size - 1) / 2
    kernel = np.fromfunction(
        lambda x, y: (1 / (2 * np.pi * sigma ** 2))
        * np.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2)),
        (kernel_size, kernel_size),
    )
    kernel = torch.tensor(kernel / kernel.sum(), dtype=torch.float32)
    kernel = kernel.to(input_tensor.dtype).to(input_tensor.device)[None, None]  # (1, 1, k, k)
    # Treat channels as a batch of single-channel images: one conv call
    # instead of a Python loop over channels.
    return F.conv2d(input_tensor.unsqueeze(1), kernel, padding=kernel_size // 2)[:, 0]


def cos_dist(a, b):
    """Pairwise cosine distance ``1 - cos(a_i, b_j)`` between rows of 2-D ``a`` and ``b``.

    Returns a ``(len(a), len(b))`` tensor.
    """
    a_norm = F.normalize(a, dim=-1)
    b_norm = F.normalize(b, dim=-1)
    return 1 - a_norm @ b_norm.T


def extract_patches(feature_map: torch.Tensor, patch_size: int, stride: int) -> torch.Tensor:
    """Cut a (C, H, W) map into flattened sliding patches.

    Returns:
        Tensor of shape ``(num_patches, C * patch_size**2)``, one row per
        patch, in ``F.unfold`` order.
    """
    patches = F.unfold(feature_map.unsqueeze(0), kernel_size=patch_size, stride=stride)
    # (1, C*patch_size^2, num_patches) -> (num_patches, C*patch_size^2)
    return patches.squeeze(0).transpose(0, 1)


def reassemble_patches(
    patches: torch.Tensor,
    out_shape: Tuple[int, int, int],
    patch_size: int,
    stride: int,
) -> torch.Tensor:
    """Inverse of :func:`extract_patches`: fold patches back into a (C, H, W) map.

    Where patches overlap, the contributions are averaged rather than summed.

    Args:
        patches: ``(num_patches, C * patch_size**2)`` tensor.
        out_shape: target ``(C, H, W)``.
        patch_size, stride: must match the values used for extraction.
    """
    C, H, W = out_shape
    patches_3d = patches.transpose(0, 1).unsqueeze(0)  # (1, C*patch_size^2, num_patches)
    reassembled = F.fold(patches_3d, output_size=(H, W), kernel_size=patch_size, stride=stride)
    # Folding a tensor of ones counts how many patches cover each pixel.
    overlap_count = F.fold(
        torch.ones_like(patches_3d), output_size=(H, W), kernel_size=patch_size, stride=stride
    )
    # clamp_min avoids 0/0 on pixels that no patch covers.
    reassembled = reassembled / overlap_count.clamp_min(1e-8)
    return reassembled.squeeze(0)


def calculate_patch_distance(index1: int, index2: int, grid_size: int, stride: int, patch_size: int) -> float:
    """Euclidean distance, in pixels, between the centres of two patches.

    Patch indices are assumed to be row-major over a grid that is
    ``grid_size`` patches wide.
    """
    row1, col1 = divmod(index1, grid_size)
    row2, col2 = divmod(index2, grid_size)
    half = patch_size / 2
    r_center1, c_center1 = row1 * stride + half, col1 * stride + half
    r_center2, c_center2 = row2 * stride + half, col2 * stride + half
    return math.hypot(r_center2 - r_center1, c_center2 - c_center1)


def gen_nn_map(
    latent,
    src_features,
    tgt_features,
    device,
    kernel_size=3,
    stride=1,
    return_newness=False,
    **kwargs,
):
    """Warp ``latent`` onto the target layout via nearest-neighbour patch matching.

    Both feature maps are cut into ``kernel_size`` patches with ``stride``.
    For every target patch, the source patch with the smallest cosine
    distance is chosen. The corresponding latent patches are gathered in
    target order and folded back, with overlaps averaged.

    Args:
        latent: (C, H, W) tensor, or list of such tensors. Its patch grid must
            contain one patch per target patch; presumably it has the same H, W
            as the features (TODO confirm at call sites).
        src_features: (C, H, W) source feature map.
        tgt_features: (C, H, W) target feature map.
        device: device for the nearest-neighbour index and distance buffers.
        kernel_size: patch size used for matching and reassembly.
        stride: patch stride used for matching and reassembly.
        return_newness: also compute a (1, H, W) "newness" map.
        **kwargs:
            batch_size: number of target patches compared per chunk (default:
                all at once).
            timestep: only used in the visualisation file name.
            visualize: if truthy, save a PCA visualisation of the features.
            newness_method: ``"two_sided"`` (default) or ``"distance"``.

    Returns:
        * tensor input: the aligned latent or, with ``return_newness``,
          ``(aligned_latent, newness)``;
        * list input: a list of aligned latents, with the newness map appended
          as the last element when ``return_newness`` is set.

    Raises:
        ValueError: unknown ``newness_method``.
    """
    batch_size = kwargs.get("batch_size", None)
    timestep = kwargs.get("timestep", None)
    newness_method = str(kwargs.get("newness_method", "two_sided")).lower()
    if return_newness and newness_method not in ("two_sided", "distance"):
        raise ValueError(f"Unknown newness_method: {newness_method!r}")

    if kwargs.get("visualize", False):
        dift_visualization(src_features, tgt_features, filename_out=f"output/feat_colors_{timestep}.png")

    is_list = isinstance(latent, list)
    src_patches = extract_patches(src_features, kernel_size, stride)
    tgt_patches = extract_patches(tgt_features, kernel_size, stride)
    if is_list:
        latent_patches = [extract_patches(l, kernel_size, stride) for l in latent]
    else:
        latent_patches = extract_patches(latent, kernel_size, stride)

    # One entry per *target* patch: index of, and distance to, its best source patch.
    num_tgt = tgt_patches.size(0)
    batch = batch_size or max(num_tgt, 1)  # range() rejects a zero step
    nearest_neighbor_indices = torch.empty(num_tgt, dtype=torch.long, device=device)
    # Cosine distances are fractional; an integer buffer would truncate them.
    nearest_neighbor_distances = torch.empty(num_tgt, dtype=src_patches.dtype, device=device)

    keep_dists = return_newness and newness_method == "two_sided"
    dist_chunks = []
    for start in range(0, num_tgt, batch):
        dists = cos_dist(src_patches, tgt_patches[start:start + batch])  # (num_src, chunk)
        if keep_dists:
            dist_chunks.append(dists)
        min_distances, best_idx = dists.min(0)
        nearest_neighbor_indices[start:start + batch] = best_idx
        nearest_neighbor_distances[start:start + batch] = min_distances

    if is_list:
        # Each latent is reassembled with its own shape, so the entries may
        # have different channel counts.
        aligned_latent = [
            reassemble_patches(patches[nearest_neighbor_indices], lat.shape, kernel_size, stride)
            for patches, lat in zip(latent_patches, latent)
        ]
    else:
        aligned_latent = reassemble_patches(
            latent_patches[nearest_neighbor_indices], latent.shape, kernel_size, stride
        )

    if not return_newness:
        return aligned_latent

    # detect_newness_* presumably come from the star import of
    # model.modules.new_object_detection. They are assumed to return one
    # score per target patch (1-D) - TODO confirm.
    if newness_method == "distance":
        newness = detect_newness_distance(nearest_neighbor_distances, quantile=0.97)
    else:
        # Columns are target patches, so chunks are joined along dim=1 to
        # rebuild the full (num_src, num_tgt) distance matrix.
        dist_matrix = torch.cat(dist_chunks, dim=1)
        newness = detect_newness_two_sided(dist_matrix, k=4)

    ref_shape = latent[0].shape if is_list else latent.shape
    out_shape = (1, ref_shape[1], ref_shape[2])
    newness_patches = newness.reshape(-1, 1)
    if not newness_patches.is_floating_point():
        newness_patches = newness_patches.float()
    # fold expects C * kernel_size**2 values per patch. Spread each patch's
    # scalar score over all of its pixels so kernel_size > 1 works.
    newness_patches = newness_patches.repeat(1, kernel_size * kernel_size)
    newness = reassemble_patches(newness_patches, out_shape, kernel_size, stride)

    if is_list:
        aligned_latent.append(newness)
        return aligned_latent
    return aligned_latent, newness


def dift_visualization(
    src_feature: torch.Tensor,
    tgt_feature: torch.Tensor,
    filename_out: str,
    resize_to: Optional[Tuple[int, int]] = (512, 512),
):
    """Save a side-by-side false-colour image of two feature maps.

    Both maps are flattened to per-pixel vectors and projected jointly to 3
    components with PCA, so colours are comparable between the two panels.
    Each component is then min-max normalised to [0, 1] and rendered as RGB.

    Args:
        src_feature: (C, H_s, W_s) tensor, shown on the left.
        tgt_feature: (C, H_t, W_t) tensor with the same C, shown on the right.
        filename_out: output image path. Parent directories are created.
        resize_to: ``(width, height)`` that each panel is resized to.
            ``None`` keeps the native resolutions.
    """
    C, H_s, W_s = src_feature.shape
    _, H_t, W_t = tgt_feature.shape

    src_flat = src_feature.permute(1, 2, 0).reshape(-1, C)  # (H_s*W_s, C)
    tgt_flat = tgt_feature.permute(1, 2, 0).reshape(-1, C)  # (H_t*W_t, C)
    all_features_np = torch.cat([src_flat, tgt_flat], dim=0).detach().cpu().numpy()

    all_features_3d = PCA(n_components=3).fit_transform(all_features_np)  # (N_total, 3)

    # Per-component min-max normalisation. The epsilon guards against
    # constant components.
    min_vals = all_features_3d.min(axis=0)
    max_vals = all_features_3d.max(axis=0)
    all_features_rgb = (all_features_3d - min_vals) / ((max_vals - min_vals) + 1e-8)

    n_src = H_s * W_s
    src_color_map = all_features_rgb[:n_src].reshape(H_s, W_s, 3)
    tgt_color_map = all_features_rgb[n_src:].reshape(H_t, W_t, 3)

    src_img = Image.fromarray((src_color_map * 255).astype(np.uint8))
    tgt_img = Image.fromarray((tgt_color_map * 255).astype(np.uint8))
    if resize_to is not None:
        src_img = src_img.resize(resize_to, Image.Resampling.LANCZOS)
        tgt_img = tgt_img.resize(resize_to, Image.Resampling.LANCZOS)

    combined_img = Image.new("RGB", (src_img.width + tgt_img.width, max(src_img.height, tgt_img.height)))
    combined_img.paste(src_img, (0, 0))
    combined_img.paste(tgt_img, (src_img.width, 0))

    out_dir = os.path.dirname(filename_out)
    if out_dir:  # os.makedirs("") raises for bare file names
        os.makedirs(out_dir, exist_ok=True)
    combined_img.save(filename_out)
    print(f"Saved visualization to {filename_out}")