# Spaces: Running on Zero
# Running on Zero
import numpy as np | |
from typing import Optional, Tuple, List | |
import torch | |
import torchvision | |
from torch import nn | |
import torch.nn.functional as F | |
from .vit import dino_small8, dino_base8, dino_small, dino_base, dinov2_small, dinov2_base, dino_reg_small, dino_reg_base | |
from .decoder import NoDecoder, SimpleFeaturePyramidDecoder | |
from .dpt_head import DPTHead | |
from .downsampler import PatchSalienceDownsampler, BilinearDownsampler | |
from .upsampler import InterpolatedGT, MultiScaleCropGT, MultiScaleCropGT_kornia | |
from .dim_reduction import OrthogonalLinearDimReduction, MlpDimReduction, NoDimReduction | |
from .visualization import VisualizationModule | |
def build_encoder(backbone: str, image_size: Tuple[int, int], intermediate_features: List[int], key_features: bool, version: str):
    """Instantiate the patch-feature encoder for the given backbone name.

    All currently supported backbones are ViT variants served by DINOv2Encoder;
    unknown names raise NotImplementedError.
    """
    if backbone in ("vit-s", "vit-b", "fit3d-s"):
        return DINOv2Encoder(
            backbone,
            image_size,
            intermediate_features=intermediate_features,
            key_features=key_features,
            version=version,
        )
    raise NotImplementedError
def build_decoder(decoder_arch: str, patch_size: int, image_size: Tuple[int, int], latent_size: int, num_ch_enc: List[int], decoder_out_dim: int):
    """Build the decoder head that maps encoder patch features to dense outputs."""
    if decoder_arch in ("nearest", "bilinear", "bicubic"):
        # Parameter-free decoder: plain interpolation of the feature grid.
        return NoDecoder(image_size,
                         interpolation=decoder_arch,
                         normalize_features=True)
    if decoder_arch == "spf":
        # TODO: SPF with patch size 8 is not implemented yet
        channels_dec = np.array([128, 128, 256, 256, 512])
        return SimpleFeaturePyramidDecoder(
            latent_size=latent_size,
            num_ch_enc=num_ch_enc,
            num_ch_dec=channels_dec,
            d_out=decoder_out_dim,
            scales=range(4),
            use_skips=True,
            device="cuda",
        )
    if decoder_arch == "dpt":
        return DPTHead(
            embed_dims=latent_size,
            post_process_channels=num_ch_enc,
            readout_type="ignore",
            patch_size=patch_size,
            d_out=decoder_out_dim,
            expand_channels=False,
        )
    raise NotImplementedError
def build_downsampler(arch: str, dim: int, patch_size: int):
    """Build the module that downsamples dense predictions to the patch grid."""
    if arch == "featup":
        return PatchSalienceDownsampler(dim, patch_size=patch_size, normalize_features=True)
    if arch == "bilinear":
        return BilinearDownsampler(patch_size=patch_size)
    raise NotImplementedError
def build_gt_upsampling_wrapper(arch: str, gt_encoder: nn.Module, image_size: Tuple[int, int]):
    """Wrap the ground-truth encoder so its features are upsampled to image size."""
    if arch in ("nearest", "bilinear", "bicubic"):
        # Simple interpolation of the GT feature grid.
        return InterpolatedGT(arch, gt_encoder, image_size)
    if arch == "multiscale-crop":
        return MultiScaleCropGT_kornia(gt_encoder, num_views=4, image_size=image_size)
    raise NotImplementedError
def build_dim_reduction(arch: str, full_channels: int, reduced_channels: int):
    """Build the channel dimensionality-reduction module for DINO features."""
    if arch == "none":
        return NoDimReduction(full_channels, reduced_channels)
    if arch == "mlp":
        return MlpDimReduction(full_channels, reduced_channels, latent_channels=128)
    if arch == "orthogonal-linear":
        return OrthogonalLinearDimReduction(full_channels, reduced_channels)
    raise NotImplementedError
class DINOv2Module(nn.Module):
    """Combines a (possibly frozen) DINO encoder with a decoder head and the
    machinery for building training targets.

    Two modes are supported:
      * "downsample-prediction": dense predictions are downsampled to the
        ground-truth patch grid (uses ``self.downsampler``).
      * "upsample-gt": ground-truth DINO features are upsampled to the
        prediction resolution (uses ``self.gt_wrapper``).
    """

    def __init__(self,
                 mode: str,  # downsample-prediction, upsample-gt
                 decoder_arch: str,  # nearest, bilinear, bicubic, spf, dpt
                 upsampler_arch: Optional[str],  # nearest, bilinear, bicubic, multiscale-crop
                 downsampler_arch: Optional[str],  # featup, bilinear
                 encoder_arch: str,  # vit-s, vit-b, fit3d-s
                 encoder_freeze: bool,
                 flip_avg_gt: bool,  # average GT features of image and its horizontal flip
                 dim_reduction_arch: str,  # none, mlp, orthogonal-linear
                 num_ch_enc: np.ndarray,  # was annotated np.array (a function, not a type)
                 intermediate_features: List[int],
                 decoder_out_dim: int,
                 dino_pca_dim: int,
                 image_size: Tuple[int, int],
                 key_features: bool,
                 dino_version: str,  # v1, v2, reg, fit3d
                 separate_gt_version: Optional[str],  # v1, v2, reg, fit3d, None (reuses encoder)
                 ):
        super().__init__()
        self.encoder = build_encoder(encoder_arch, image_size, intermediate_features, key_features, dino_version)
        self.flip_avg_gt = flip_avg_gt
        # When no separate GT encoder is configured, the shared encoder doubles
        # as the ground-truth source and must therefore stay frozen.
        if encoder_freeze or separate_gt_version is None:
            self.encoder_frozen = True
            for p in self.encoder.parameters(True):
                p.requires_grad = False
        else:
            self.encoder_frozen = False
        self.decoder = build_decoder(decoder_arch,
                                     self.encoder.patch_size,
                                     image_size,
                                     self.encoder.latent_size,
                                     num_ch_enc,
                                     decoder_out_dim)
        if separate_gt_version is None:
            self.gt_encoder = self.encoder
        else:
            # Separate GT encoder: no intermediate features, always frozen.
            self.gt_encoder = build_encoder(encoder_arch, image_size, [], key_features, separate_gt_version)
            for p in self.gt_encoder.parameters(True):
                p.requires_grad = False
        # General way of creating loss
        if mode == "downsample-prediction":
            assert upsampler_arch is None
            self.downsampler = build_downsampler(downsampler_arch, self.gt_encoder.latent_size, self.gt_encoder.patch_size)
            self.gt_wrapper = None
        elif mode == "upsample-gt":
            assert downsampler_arch is None
            self.downsampler = None
            self.gt_wrapper = build_gt_upsampling_wrapper(upsampler_arch, self.gt_encoder, image_size)
        else:
            raise NotImplementedError
        self.extra_outs = 0
        self.latent_size = decoder_out_dim
        self.dino_pca_dim = dino_pca_dim
        self.dim_reduction = build_dim_reduction(dim_reduction_arch, self.encoder.latent_size, dino_pca_dim)
        self.visualization = VisualizationModule(self.encoder.latent_size)

    def forward(self, x, ground_truth=False):
        """Decode features for ``x``; with ``ground_truth=True``, return the
        no-grad GT feature maps (a list), optionally flip-averaged."""
        if ground_truth:
            with torch.no_grad():
                if self.gt_wrapper is not None:
                    gt_0 = self.gt_wrapper(x)
                    if self.flip_avg_gt:
                        # Flip back before summing so both feature maps align
                        # spatially; renormalize after averaging.
                        gt_flipped = self.gt_wrapper(x.flip([-1]))
                        gt_avg = [F.normalize(gt_flipped[i].flip([-1]) + gt_0[i], dim=1) for i in range(len(gt_0))]
                        return gt_avg
                    else:
                        return gt_0
                else:
                    # No upsampling wrapper: use the final encoder feature map only.
                    gt_0 = self.gt_encoder(x)[-1]
                    if self.flip_avg_gt:
                        gt_flipped = self.gt_encoder(x.flip([-1]))[-1]
                        gt_avg = F.normalize(gt_flipped.flip([-1]) + gt_0, dim=1)
                        return [gt_avg]
                    else:
                        return [gt_0]
        else:
            if self.encoder_frozen:
                with torch.no_grad():
                    patch_features = self.encoder(x)
            else:
                patch_features = self.encoder(x)
            return self.decoder(patch_features)

    def downsample(self, x, mode="patch"):
        # Only meaningful in "downsample-prediction" mode; returns None otherwise.
        if self.downsampler is None:
            return None
        else:
            return self.downsampler(x, mode)

    def expand_dim(self, features):
        return self.dim_reduction.transform_expand(features)

    def fit_visualization(self, features, refit=True):
        return self.visualization.fit_pca(features, refit)

    def transform_visualization(self, features, norm=False, from_dim=0):
        return self.visualization.transform_pca(features, norm, from_dim)

    def fit_transform_kmeans_visualization(self, features):
        return self.visualization.fit_transform_kmeans_batch(features)

    @classmethod
    def from_conf(cls, conf):
        """Alternate constructor from a config object.

        BUG FIX: the @classmethod decorator was missing, so calling
        DINOv2Module.from_conf(conf) bound the config object to ``cls``.
        """
        return cls(
            mode=conf.mode,
            decoder_arch=conf.decoder_arch,
            upsampler_arch=conf.get("upsampler_arch", None),
            downsampler_arch=conf.get("downsampler_arch", None),
            encoder_arch=conf.encoder_arch,
            encoder_freeze=conf.encoder_freeze,
            flip_avg_gt=conf.get("flip_avg_gt", False),
            dim_reduction_arch=conf.dim_reduction_arch,
            num_ch_enc=conf.get("num_ch_enc", None),
            intermediate_features=conf.get("intermediate_features", []),
            decoder_out_dim=conf.decoder_out_dim,
            dino_pca_dim=conf.dino_pca_dim,
            image_size=conf.image_size,
            key_features=conf.key_features,
            dino_version=conf.get("version", "reg"),
            separate_gt_version=conf.get("separate_gt_version", None)
        )
def _normalize_input(x):
    """Map inputs from [-1, 1] to [0, 1], then apply ImageNet normalization."""
    rescaled = x / 2 + 0.5
    imagenet_norm = torchvision.transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    return imagenet_norm(rescaled)
class DINOv2Encoder(nn.Module):
    """Wrapper around DINO / DINOv2 / FiT3D ViT backbones.

    forward() returns a list of (B, C, H/ps, W/ps) feature grids: one per
    requested intermediate layer (sorted by key), followed by the final
    normalized (or key) features.
    """

    def __init__(self, backbone, image_size, intermediate_features, key_features, version):
        super().__init__()
        self.image_size = image_size
        if version in ["fit3d", "v2", "reg"]:
            # "Internal" patch size 14 is resized to "External" patch size 16 for decoder!
            # Inputs are rescaled by 14/16 so the resulting patch grid matches a
            # 16-pixel stride at the original resolution.
            self.patch_size = 16
            adjusted_image_size = (image_size[0] * 14 // self.patch_size, image_size[1] * 14 // self.patch_size)
            self.resize_tf = torchvision.transforms.Resize(size=adjusted_image_size,
                                                           interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        elif version == "v1":
            self.patch_size = 8
            adjusted_image_size = self.image_size
            self.resize_tf = None
        elif version == "v1_16":
            self.patch_size = 16
            adjusted_image_size = self.image_size
            self.resize_tf = None
        else:
            raise NotImplementedError()
        self.key_features = key_features
        self.backbone = backbone
        self.version = version
        self.model, self.latent_size = self.load_model(backbone, version, adjusted_image_size, intermediate_features)

    def forward(self, x):
        """Encode a batch of images into patch-feature grids (see class docstring)."""
        x = _normalize_input(x)
        if self.resize_tf:
            x = self.resize_tf(x)
        output_dict = self.model(x)
        if self.version == "fit3d":
            # The FiT3D hub model does not return a feature dict itself; the
            # forward hooks registered in load_model() fill model.output_dict.
            output_dict = self.model.output_dict
        inter_keys = [output_key for output_key in output_dict if output_key.startswith("intermediate_features.")]
        result = []
        for inter_key in sorted(inter_keys):
            output = output_dict[inter_key].transpose(-1, -2)  # (L, B, C_dino, H*W)
            output_grid = output.view(*output.size()[:-1],
                                      x.size(-2) // self.model.patch_size,
                                      x.size(-1) // self.model.patch_size)
            result.append(output_grid)
        if self.key_features:
            output = output_dict['key_features'].transpose(-1, -2).flatten(1, 2)
            output = F.normalize(output, dim=1)
        else:
            output = output_dict['features_normalized'].transpose(-1, -2)
            output = F.normalize(output, dim=1)
        output_grid = output.view(*output.size()[:-1],
                                  x.size(-2) // self.model.patch_size,
                                  x.size(-1) // self.model.patch_size)
        result.append(output_grid)
        return result

    def load_model(self, backbone, version, image_size, intermediate_features):
        """Instantiate the requested backbone.

        Returns (model, latent_size). Raises NotImplementedError for any
        unsupported (version, backbone) combination.
        """
        if version == "fit3d":
            if backbone == "vit-s":
                model_name = "dinov2_reg_small_fine"
            elif backbone == "vit-b":
                model_name = "dinov2_reg_base_fine"
            else:
                raise NotImplementedError()

            def get_features(model, key):
                # Hook factory: stores block outputs in model.output_dict under `key`.
                def hook(blk, input, output):
                    # Drops the first 5 tokens — presumably CLS + 4 register
                    # tokens (TODO confirm) — keeping only patch tokens.
                    model.output_dict[key] = output[:, 5:]
                return hook

            model = torch.hub.load("ywyue/FiT3D", model_name).to("cuda")
            # f-string had no placeholders; plain literal is equivalent.
            model.norm.register_forward_hook(get_features(model, "features_normalized"))
            for i, _blk in enumerate(model.blocks):
                if i in intermediate_features:
                    _blk.register_forward_hook(get_features(model, f"intermediate_features.{i}"))
            model.output_dict = {}
            model.patch_size = 14
        elif version == "v1" and backbone == "vit-s":
            model = dino_small8(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "v1" and backbone == "vit-b":
            model = dino_base8(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "v1_16" and backbone == "vit-s":
            model = dino_small(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "v1_16" and backbone == "vit-b":
            model = dino_base(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "v2" and backbone == "vit-s":
            model = dinov2_small(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "v2" and backbone == "vit-b":
            model = dinov2_base(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "reg" and backbone == "vit-s":
            model = dino_reg_small(image_size=image_size, intermediate_features=intermediate_features)
        elif version == "reg" and backbone == "vit-b":
            model = dino_reg_base(image_size=image_size, intermediate_features=intermediate_features)
        else:
            raise NotImplementedError()
        if backbone == "vit-s":
            latent_size = 384
        elif backbone == "vit-b":
            latent_size = 768
        else:
            # BUG FIX: latent_size was previously unbound for unknown backbones,
            # raising UnboundLocalError instead of a clear NotImplementedError.
            raise NotImplementedError()
        return model, latent_size