# Copyright (c) 2025 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. # LICENSE is in incl_licenses directory. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import argparse import gc import math import os import random from collections import defaultdict from typing import Any, Dict, Iterable, List, Tuple # import cv2 import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F from datasets import load_dataset, load_dataset_builder from datasets.distributed import split_dataset_by_node from einops import rearrange from PIL import Image from torch import nn from torch.utils.data import DataLoader from torchvision.utils import make_grid from tqdm import tqdm # from common import rank_print, load_model, get_standard_transform, collate # # try: # import wandb # except ImportError: # wandb = None LAYER_STATS = dict() @torch.inference_mode() def main(rank: int = 0, world_size: int = 1): """ Computes the RankMe (http://arxiv.org/abs/2210.02885) and LiDAR (http://arxiv.org/abs/2312.04000) estimates of the rank of the produced embeddings. While RADIO doesn't train in a multi-view setting which is an assumption of LiDAR, the metric does integrate an important concept of the invariance of the summary features to different view/augmentations of the same image. """ local_rank = rank % torch.cuda.device_count() torch.cuda.set_device(local_rank) cv2.setNumThreads(1) device = torch.device("cuda", local_rank) parser = argparse.ArgumentParser(description="Compute SSL embedding rank estimates") parser.add_argument("-v", "--model-version", default="radio_v2", help="Which radio model to load.") parser.add_argument("-d", "--dataset", default="imagenet-1k", help="The name of the dataset to classify") parser.add_argument("--split", default="validation", help="The dataset split to use.") parser.add_argument("-n", default=10, type=int, help="The number of samples to load") parser.add_argument( "-r", "--resolution", nargs="+", type=int, default=None, help="The input image resolution." " If one value is specified, the shortest dimension is resized to this." " If two, the image is center cropped." " If not specified, center cropped 378px is used." " Default: The RADIO model's preferred resolution.", ) parser.add_argument( "--resize-multiple", type=int, default=None, help="Resize images with dimensions a multiple of this value." " This should be equal to the patch size of a ViT (e.g. RADIOv1)", ) parser.add_argument( "--batch-size", type=int, default=16, help="The batch size. If the input is variable sized, then this argument becomes a maximum.", ) parser.add_argument("--workers", default=8, type=int, help="Number of loader workers to use") parser.add_argument( "--vitdet-window-size", default=None, type=int, help="Enable ViTDet at the specific window size" ) parser.add_argument("--output-dir", default="vis_denoise", type=str) parser.add_argument("--adaptor-name", default=None, type=str, help="Generate features from a teacher adaptor") args, _ = parser.parse_known_args() torch.manual_seed(42 + rank) np.random.seed(42 + rank) random.seed(42 + rank) rank_print("Loading model...") model, preprocessor, info = load_model( args.model_version, vitdet_window_size=args.vitdet_window_size, adaptor_name=args.adaptor_name ) model.to(device=device).eval() if isinstance(preprocessor, nn.Module): preprocessor.to(device).eval() rank_print("Done") rank_print("Loading dataset...") ds_builder = load_dataset_builder(args.dataset, trust_remote_code=True) if args.resolution is None: args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width) patch_size = model.patch_size if args.resize_multiple is None: args.resize_multiple = getattr(model, "min_resolution_step", model.patch_size) transform = get_standard_transform(args.resolution, args.resize_multiple) dataset = ds_builder.as_dataset(split=args.split) dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers)) dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) dataset = dataset.map( lambda ex: dict(image=transform(ex["image"]), label=torch.as_tensor(ex["label"], dtype=torch.int64)) ) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, collate_fn=collate, pin_memory=args.workers > 0, drop_last=False, ) rank_print("Done") rank_print(f"Description: {ds_builder.info.description}") dirs = dict( orig=os.path.join(args.output_dir, "orig"), viz=os.path.join(args.output_dir, "viz"), sbs=os.path.join(args.output_dir, "sbs"), ) for d in dirs.values(): os.makedirs(d, exist_ok=True) ctr = 0 for batches in loader: if ctr >= args.n: break for images, _ in batches: images = images.to(device=device, non_blocking=True) all_feat = [] with torch.autocast(device.type, dtype=torch.bfloat16): p_images = preprocessor(images) output = model(p_images) if args.adaptor_name: all_feat = [ output["backbone"].features, output[args.adaptor_name].features, ] else: all_feat = [output[1]] all_feat = torch.stack(all_feat, dim=1) num_rows = images.shape[-2] // patch_size num_cols = images.shape[-1] // patch_size all_feat = rearrange(all_feat, "b m (h w) c -> b m h w c", h=num_rows, w=num_cols).float() for i, feats in enumerate(all_feat): colored = [] for features in feats: color = get_pca_map(features, images.shape[-2:]) colored.append(color) orig = cv2.cvtColor(images[i].permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2BGR) cv2.imwrite(f'{dirs["orig"]}/vis_{ctr}.jpg', orig * 255) cv2.imwrite(f'{dirs["viz"]}/vis_{ctr}.jpg', colored[-1] * 255) op = np.concatenate([orig] + colored, axis=1) * 255 cv2.imwrite(f'{dirs["sbs"]}/vis_{ctr}.jpg', op) ctr += 1 def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): # features: (N, C) # m: a hyperparam controlling how many std dev outside for outliers assert len(features.shape) == 2, "features should be (N, C)" reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] colors = features @ reduction_mat if remove_first_component: colors_min = colors.min(dim=0).values colors_max = colors.max(dim=0).values tmp_colors = (colors - colors_min) / (colors_max - colors_min) fg_mask = tmp_colors[..., 0] < 0.2 reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] colors = features @ reduction_mat else: fg_mask = torch.ones_like(colors[:, 0]).bool() d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) mdev = torch.median(d, dim=0).values s = d / mdev try: rins = colors[fg_mask][s[:, 0] < m, 0] gins = colors[fg_mask][s[:, 1] < m, 1] bins = colors[fg_mask][s[:, 2] < m, 2] rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) except: rins = colors gins = colors bins = colors rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) def get_pca_map( feature_map: torch.Tensor, img_size, interpolation="bicubic", return_pca_stats=False, pca_stats=None, ): """ feature_map: (1, h, w, C) is the feature map of a single image. """ feature_map = feature_map.float() if feature_map.shape[0] != 1: # make it (1, h, w, C) feature_map = feature_map[None] if pca_stats is None: reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1])) else: reduct_mat, color_min, color_max = pca_stats pca_color = feature_map @ reduct_mat pca_color = (pca_color - color_min) / (color_max - color_min) pca_color = F.interpolate( pca_color.permute(0, 3, 1, 2), size=img_size, mode=interpolation, ).permute(0, 2, 3, 1) pca_color = pca_color.clamp(0, 1) pca_color = pca_color.cpu().numpy().squeeze(0) if return_pca_stats: return pca_color, (reduct_mat, color_min, color_max) return pca_color def get_scale_map( scalar_map: torch.Tensor, img_size, interpolation="nearest", ): """ scalar_map: (1, h, w, C) is the feature map of a single image. """ if scalar_map.shape[0] != 1: scalar_map = scalar_map[None] scalar_map = (scalar_map - scalar_map.min()) / (scalar_map.max() - scalar_map.min() + 1e-6) scalar_map = F.interpolate( scalar_map.permute(0, 3, 1, 2), size=img_size, mode=interpolation, ).permute(0, 2, 3, 1) # cmap = plt.get_cmap("viridis") # scalar_map = cmap(scalar_map)[..., :3] # make it 3 channels scalar_map = torch.cat([scalar_map] * 3, dim=-1) scalar_map = scalar_map.cpu().numpy().squeeze(0) return scalar_map def get_similarity_map(features: torch.Tensor, img_size=(224, 224)): """ compute the similarity map of the central patch to the rest of the image """ assert len(features.shape) == 4, "features should be (1, C, H, W)" H, W, C = features.shape[1:] center_patch_feature = features[0, H // 2, W // 2, :] center_patch_feature_normalized = center_patch_feature / center_patch_feature.norm() center_patch_feature_normalized = center_patch_feature_normalized.unsqueeze(1) # Reshape and normalize the entire feature tensor features_flat = features.view(-1, C) features_normalized = features_flat / features_flat.norm(dim=1, keepdim=True) similarity_map_flat = features_normalized @ center_patch_feature_normalized # Reshape the flat similarity map back to the spatial dimensions (H, W) similarity_map = similarity_map_flat.view(H, W) # Normalize the similarity map to be in the range [0, 1] for visualization similarity_map = (similarity_map - similarity_map.min()) / (similarity_map.max() - similarity_map.min()) # we don't want the center patch to be the most similar similarity_map[H // 2, W // 2] = -1.0 similarity_map = ( F.interpolate( similarity_map.unsqueeze(0).unsqueeze(0), size=img_size, mode="bilinear", ) .squeeze(0) .squeeze(0) ) similarity_map_np = similarity_map.cpu().numpy() negative_mask = similarity_map_np < 0 colormap = plt.get_cmap("turbo") # Apply the colormap directly to the normalized similarity map and multiply by 255 to get RGB values similarity_map_rgb = colormap(similarity_map_np)[..., :3] similarity_map_rgb[negative_mask] = [1.0, 0.0, 0.0] return similarity_map_rgb def get_cluster_map( feature_map: torch.Tensor, img_size, num_clusters=10, ) -> torch.Tensor: kmeans = KMeans(n_clusters=num_clusters, distance=CosineSimilarity, verbose=False) if feature_map.shape[0] != 1: # make it (1, h, w, C) feature_map = feature_map[None] labels = kmeans.fit_predict(feature_map.reshape(1, -1, feature_map.shape[-1])).float() labels = ( F.interpolate(labels.reshape(1, *feature_map.shape[:-1]), size=img_size, mode="nearest").squeeze().cpu().numpy() ).astype(int) cmap = plt.get_cmap("rainbow", num_clusters) cluster_map = cmap(labels)[..., :3] return cluster_map.reshape(img_size[0], img_size[1], 3) if __name__ == "__main__": rank = 0 world_size = 1 # if 'WORLD_SIZE' in os.environ: # dist.init_process_group(backend='nccl') # rank = dist.get_rank() # world_size = dist.get_world_size() main(rank, world_size)