import numpy as np
from typing import Optional, Tuple, List
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from .vit import dino_small8, dino_base8, dino_small, dino_base, dinov2_small, dinov2_base, dino_reg_small, dino_reg_base
from .decoder import NoDecoder, SimpleFeaturePyramidDecoder
from .dpt_head import DPTHead
from .downsampler import PatchSalienceDownsampler, BilinearDownsampler
from .upsampler import InterpolatedGT, MultiScaleCropGT, MultiScaleCropGT_kornia
from .dim_reduction import OrthogonalLinearDimReduction, MlpDimReduction, NoDimReduction
from .visualization import VisualizationModule
def build_encoder(backbone: str, image_size: Tuple[int, int], intermediate_features: List[int], key_features: bool, version: str):
match backbone:
case "vit-s" | "vit-b" | "fit3d-s":
return DINOv2Encoder(backbone,
image_size,
intermediate_features=intermediate_features,
key_features=key_features,
version=version)
case _:
raise NotImplementedError
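# A minimal usage sketch for the encoder factory (argument values are
# illustrative, not repo defaults; "reg" selects DINOv2 with registers):
#
#   encoder = build_encoder("vit-s", (224, 224), intermediate_features=[],
#                           key_features=False, version="reg")
#   feats = encoder(images)  # list of (B, 384, H/16, W/16) feature grids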
def build_decoder(decoder_arch: str, patch_size: int, image_size: Tuple[int, int], latent_size: int, num_ch_enc: List[int], decoder_out_dim: int):
match decoder_arch:
case "nearest" | "bilinear" | "bicubic":
return NoDecoder(image_size,
interpolation=decoder_arch,
normalize_features=True)
case "spf":
# TODO: SPF with patch size 8 is not implemented yet
num_ch_dec = np.array([128, 128, 256, 256, 512])
scales = range(4)
return SimpleFeaturePyramidDecoder(latent_size=latent_size,
num_ch_enc=num_ch_enc,
num_ch_dec=num_ch_dec,
d_out=decoder_out_dim,
scales=scales,
use_skips=True,
device="cuda")
case "dpt":
return DPTHead(embed_dims=latent_size,
post_process_channels=num_ch_enc,
readout_type="ignore",
patch_size=patch_size,
d_out=decoder_out_dim,
expand_channels=False)
case _:
raise NotImplementedError
def build_downsampler(arch: str, dim: int, patch_size: int):
match arch:
case "featup":
return PatchSalienceDownsampler(dim, patch_size=patch_size, normalize_features=True)
case "bilinear":
return BilinearDownsampler(patch_size=patch_size)
case _:
raise NotImplementedError
def build_gt_upsampling_wrapper(arch: str, gt_encoder: nn.Module, image_size: Tuple[int, int]):
match arch:
case "nearest" | "bilinear" | "bicubic":
return InterpolatedGT(arch, gt_encoder, image_size)
case "multiscale-crop":
return MultiScaleCropGT_kornia(gt_encoder, num_views=4, image_size=image_size)
case _:
raise NotImplementedError
def build_dim_reduction(arch: str, full_channels: int, reduced_channels: int):
match arch:
case "none":
return NoDimReduction(full_channels, reduced_channels)
case "mlp":
return MlpDimReduction(full_channels, reduced_channels, latent_channels=128)
case "orthogonal-linear":
return OrthogonalLinearDimReduction(full_channels, reduced_channels)
case _:
raise NotImplementedError
class DINOv2Module(nn.Module):
def __init__(self,
mode: str, # downsample-prediction, upsample-gt
                 decoder_arch: str,  # nearest, bilinear, bicubic, spf, dpt
                 upsampler_arch: Optional[str],  # nearest, bilinear, bicubic, multiscale-crop
                 downsampler_arch: Optional[str],  # featup, bilinear
encoder_arch: str, # vit-s, vit-b
encoder_freeze: bool,
flip_avg_gt: bool,
dim_reduction_arch: str, # orthogonal-linear, mlp
                 num_ch_enc: np.ndarray,
intermediate_features: List[int],
decoder_out_dim: int,
dino_pca_dim: int,
image_size: Tuple[int, int],
key_features: bool,
dino_version: str, # v1, v2, reg, fit3d
separate_gt_version: Optional[str], # v1, v2, reg, fit3d, None (reuses encoder)
):
super().__init__()
self.encoder = build_encoder(encoder_arch, image_size, intermediate_features, key_features, dino_version)
self.flip_avg_gt = flip_avg_gt
        # When the GT encoder is shared with the main encoder
        # (separate_gt_version is None), the encoder must stay frozen so the
        # ground-truth features cannot drift during training.
        if encoder_freeze or separate_gt_version is None:
            self.encoder_frozen = True
            for p in self.encoder.parameters(True):
                p.requires_grad = False
        else:
            self.encoder_frozen = False
self.decoder = build_decoder(decoder_arch,
self.encoder.patch_size,
image_size,
self.encoder.latent_size,
num_ch_enc,
decoder_out_dim)
if separate_gt_version is None:
self.gt_encoder = self.encoder
else:
self.gt_encoder = build_encoder(encoder_arch, image_size, [], key_features, separate_gt_version)
for p in self.gt_encoder.parameters(True):
p.requires_grad = False
        # Two supervision modes: either downsample the prediction to the DINO
        # patch grid, or upsample the ground-truth features to full resolution.
if mode == "downsample-prediction":
assert upsampler_arch is None
self.downsampler = build_downsampler(downsampler_arch, self.gt_encoder.latent_size, self.gt_encoder.patch_size)
self.gt_wrapper = None
elif mode == "upsample-gt":
assert downsampler_arch is None
self.downsampler = None
self.gt_wrapper = build_gt_upsampling_wrapper(upsampler_arch, self.gt_encoder, image_size)
else:
raise NotImplementedError
self.extra_outs = 0
self.latent_size = decoder_out_dim
self.dino_pca_dim = dino_pca_dim
self.dim_reduction = build_dim_reduction(dim_reduction_arch, self.encoder.latent_size, dino_pca_dim)
self.visualization = VisualizationModule(self.encoder.latent_size)
def forward(self, x, ground_truth=False):
if ground_truth:
with torch.no_grad():
if self.gt_wrapper is not None:
gt_0 = self.gt_wrapper(x)
if self.flip_avg_gt:
gt_flipped = self.gt_wrapper(x.flip([-1]))
gt_avg = [F.normalize(gt_flipped[i].flip([-1]) + gt_0[i], dim=1) for i in range(len(gt_0))]
return gt_avg
else:
return gt_0
else:
gt_0 = self.gt_encoder(x)[-1]
if self.flip_avg_gt:
gt_flipped = self.gt_encoder(x.flip([-1]))[-1]
gt_avg = F.normalize(gt_flipped.flip([-1]) + gt_0, dim=1)
return [gt_avg]
else:
return [gt_0]
else:
if self.encoder_frozen:
with torch.no_grad():
patch_features = self.encoder(x)
else:
patch_features = self.encoder(x)
return self.decoder(patch_features)
def downsample(self, x, mode="patch"):
if self.downsampler is None:
return None
else:
return self.downsampler(x, mode)
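
    # A hedged sketch of how the two modes pair prediction and ground truth
    # (the loss itself lives outside this module; cosine distance is only an
    # illustrative choice):
    #
    #   pred = module(x)                    # decoded high-res features
    #   gt = module(x, ground_truth=True)   # list of frozen DINO features
    #   if module.downsampler is not None:  # "downsample-prediction" mode
    #       pred = module.downsample(pred)  # reduce to the DINO patch grid
    #   loss = 1 - F.cosine_similarity(pred, gt[-1], dim=1).mean()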
def expand_dim(self, features):
return self.dim_reduction.transform_expand(features)
def fit_visualization(self, features, refit=True):
return self.visualization.fit_pca(features, refit)
def transform_visualization(self, features, norm=False, from_dim=0):
return self.visualization.transform_pca(features, norm, from_dim)
def fit_transform_kmeans_visualization(self, features):
return self.visualization.fit_transform_kmeans_batch(features)
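
    # Typical visualization flow (a sketch; the PCA projection to a few
    # channels for display is handled by VisualizationModule):
    #
    #   module.fit_visualization(feats)              # fit PCA on features
    #   vis = module.transform_visualization(feats)  # project for display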
@classmethod
def from_conf(cls, conf):
return cls(
mode=conf.mode,
decoder_arch=conf.decoder_arch,
upsampler_arch=conf.get("upsampler_arch", None),
downsampler_arch=conf.get("downsampler_arch", None),
encoder_arch=conf.encoder_arch,
encoder_freeze=conf.encoder_freeze,
flip_avg_gt=conf.get("flip_avg_gt", False),
dim_reduction_arch=conf.dim_reduction_arch,
num_ch_enc=conf.get("num_ch_enc", None),
intermediate_features=conf.get("intermediate_features", []),
decoder_out_dim=conf.decoder_out_dim,
dino_pca_dim=conf.dino_pca_dim,
image_size=conf.image_size,
key_features=conf.key_features,
dino_version=conf.get("version", "reg"),
separate_gt_version=conf.get("separate_gt_version", None)
)
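
# A hedged example of building the module from a config (assumes an
# OmegaConf-style object; keys mirror from_conf above, values are purely
# illustrative):
#
#   conf = OmegaConf.create({
#       "mode": "upsample-gt", "decoder_arch": "bilinear",
#       "upsampler_arch": "bilinear", "encoder_arch": "vit-s",
#       "encoder_freeze": True, "dim_reduction_arch": "none",
#       "decoder_out_dim": 384, "dino_pca_dim": 64,
#       "image_size": [224, 224], "key_features": False,
#   })
#   module = DINOv2Module.from_conf(conf)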
def _normalize_input(x):
    # Map inputs from [-1, 1] to [0, 1], then apply ImageNet normalization.
    norm_tf = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    return norm_tf(x / 2 + 0.5)
class DINOv2Encoder(nn.Module):
def __init__(self, backbone, image_size, intermediate_features, key_features, version):
super().__init__()
self.image_size = image_size
if version in ["fit3d", "v2", "reg"]:
# "Internal" patch size 14 is resized to "External" patch size 16 for decoder!
self.patch_size = 16
adjusted_image_size = (image_size[0] * 14 // self.patch_size, image_size[1] * 14 // self.patch_size)
self.resize_tf = torchvision.transforms.Resize(size=adjusted_image_size,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
elif version == "v1":
self.patch_size = 8
adjusted_image_size = self.image_size
self.resize_tf = None
elif version == "v1_16":
self.patch_size = 16
adjusted_image_size = self.image_size
self.resize_tf = None
else:
raise NotImplementedError()
self.key_features = key_features
self.backbone = backbone
self.version = version
self.model, self.latent_size = self.load_model(backbone, version, adjusted_image_size, intermediate_features)
def forward(self, x):
x = _normalize_input(x)
if self.resize_tf:
x = self.resize_tf(x)
output_dict = self.model(x)
if self.version == "fit3d":
output_dict = self.model.output_dict
inter_keys = [output_key for output_key in output_dict if output_key.startswith("intermediate_features.")]
result = []
for inter_key in sorted(inter_keys):
output = output_dict[inter_key].transpose(-1, -2) # (L, B, C_dino, H*W)
output_grid = output.view(*output.size()[:-1],
x.size(-2) // self.model.patch_size,
x.size(-1) // self.model.patch_size)
result.append(output_grid)
if self.key_features:
output = output_dict['key_features'].transpose(-1, -2).flatten(1, 2)
output = F.normalize(output, dim=1)
else:
output = output_dict['features_normalized'].transpose(-1, -2)
output = F.normalize(output, dim=1)
output_grid = output.view(*output.size()[:-1],
x.size(-2) // self.model.patch_size,
x.size(-1) // self.model.patch_size)
result.append(output_grid)
return result
def load_model(self, backbone, version, image_size, intermediate_features):
if version == "fit3d":
if backbone == "vit-s":
model_name = "dinov2_reg_small_fine"
elif backbone == "vit-b":
model_name = "dinov2_reg_base_fine"
else:
raise NotImplementedError()
            def get_features(model, key):
                def hook(blk, input, output):
                    # Drop the CLS token and the 4 register tokens (5 in total).
                    model.output_dict[key] = output[:, 5:]
                return hook
            model = torch.hub.load("ywyue/FiT3D", model_name).to("cuda")
            model.norm.register_forward_hook(get_features(model, "features_normalized"))
for i, _blk in enumerate(model.blocks):
if i in intermediate_features:
_blk.register_forward_hook(get_features(model, f"intermediate_features.{i}"))
model.output_dict = {}
model.patch_size = 14
elif version == "v1" and backbone == "vit-s":
model = dino_small8(image_size=image_size, intermediate_features=intermediate_features)
elif version == "v1" and backbone == "vit-b":
model = dino_base8(image_size=image_size, intermediate_features=intermediate_features)
elif version == "v1_16" and backbone == "vit-s":
model = dino_small(image_size=image_size, intermediate_features=intermediate_features)
elif version == "v1_16" and backbone == "vit-b":
model = dino_base(image_size=image_size, intermediate_features=intermediate_features)
elif version == "v2" and backbone == "vit-s":
model = dinov2_small(image_size=image_size, intermediate_features=intermediate_features)
elif version == "v2" and backbone == "vit-b":
model = dinov2_base(image_size=image_size, intermediate_features=intermediate_features)
elif version == "reg" and backbone == "vit-s":
model = dino_reg_small(image_size=image_size, intermediate_features=intermediate_features)
elif version == "reg" and backbone == "vit-b":
model = dino_reg_base(image_size=image_size, intermediate_features=intermediate_features)
else:
raise NotImplementedError()
if backbone == "vit-s":
latent_size = 384
elif backbone == "vit-b":
latent_size = 768
return model, latent_size
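

# A minimal smoke test of the module (illustrative only: values are
# assumptions, not repo defaults; requires a CUDA device and network access
# for torch.hub to fetch the DINO weights).
if __name__ == "__main__":
    module = DINOv2Module(
        mode="upsample-gt",
        decoder_arch="bilinear",
        upsampler_arch="bilinear",
        downsampler_arch=None,
        encoder_arch="vit-s",
        encoder_freeze=True,
        flip_avg_gt=False,
        dim_reduction_arch="none",
        num_ch_enc=None,
        intermediate_features=[],
        decoder_out_dim=384,
        dino_pca_dim=64,
        image_size=(224, 224),
        key_features=False,
        dino_version="reg",
        separate_gt_version=None,
    ).to("cuda")
    x = torch.rand(2, 3, 224, 224, device="cuda") * 2 - 1  # inputs expected in [-1, 1]
    pred = module(x)                   # decoded feature map(s)
    gt = module(x, ground_truth=True)  # frozen DINO ground-truth features
    print([t.shape for t in gt])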