import os
import pathlib
from typing import Dict, List, Optional, Tuple
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import VisionTransformer
from torch import Tensor
from torch.fx import GraphModule
from torchvision.models.feature_extraction import create_feature_extractor
# from ups.utils import normalize
__all__: Tuple[str, ...] = (
"dino_small",
"dino_base",
"dinov2_small",
"dinov2_base",
"dino_reg_small",
"dino_reg_base",
"i_jepa_huge",
"mae_base",
"self_patch_small",
"synclr_base",
"mocov3_base",
"msn_base",
"vmae_large",
)
def _disable_fused_attention(model: nn.Module) -> None:
"""Function disables the use of fused attention in Timm's ViT models. (Don't use anywhere else!)
Args:
model (nn.Module): Timm ViT model.
"""
# Get ViT depth
depth: int = len(model.blocks) # type: ignore
# Disable fused attention for last block
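    # (With fused attention, timm calls F.scaled_dot_product_attention and never materializes the attention
    # matrix, so the corresponding softmax node would be missing from the FX graph.)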
for name, module in model.named_modules():
if "Attention" in str(type(module)):
if str(depth - 1) in name:
module.fused_attn = False # type: ignore
def _load_vit(name: str, image_size: Tuple[int, int] = (224, 224), depth: int = 12) -> nn.Module:
    """Loads a pre-trained ViT model from timm.
    Args:
        name (str): Timm name of the model.
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        depth (int): Depth of the model (currently unused).
    Returns:
        model (nn.Module): ViT model as a nn.Module.
    """
# Load model
model: nn.Module = timm.create_model(name, pretrained=True, img_size=image_size, num_classes=0)
# Force not to use fused attention to access attention maps
# _disable_fused_attention(model)
return model
def _interpolate_positional_embeddings(
positional_embeddings: Tensor,
original_image_size: Tuple[int, int],
target_image_size: Tuple[int, int],
patch_size: int,
num_additional_tokens: int = 1,
) -> Tensor:
"""Function interpolates positional embeddings to a different image size.
Args:
        positional_embeddings (Tensor): Positional embeddings of the shape [1, N, C].
        original_image_size (Tuple[int, int]): Original image size as a tuple.
        target_image_size (Tuple[int, int]): Target image size as a tuple.
        patch_size (int): Utilized patch size.
num_additional_tokens (int): Number of additional tokens used. Default 1 (class token).
Returns:
positional_embeddings_interpolated (Tensor): Interpolated positional embeddings [1, N_new, C].
"""
# Get positional embeddings for image
if num_additional_tokens > 0:
positional_embeddings_add_tokens: Tensor = positional_embeddings[:, :num_additional_tokens]
positional_embeddings_image: Tensor = positional_embeddings[:, num_additional_tokens:]
else:
positional_embeddings_image = positional_embeddings
# Reshape embeddings to 2D
positional_embeddings_image = positional_embeddings_image.view(
1, original_image_size[0] // patch_size, original_image_size[1] // patch_size, -1
)
# Interpolate positional embeddings
positional_embeddings_image = F.interpolate(
positional_embeddings_image.permute(0, 3, 1, 2),
size=(target_image_size[0] // patch_size, target_image_size[1] // patch_size),
mode="bicubic",
align_corners=False,
antialias=False,
).permute(0, 2, 3, 1)
# Stack positional embeddings again
if num_additional_tokens > 0:
positional_embeddings_interpolated: Tensor = torch.cat(
(positional_embeddings_add_tokens, positional_embeddings_image.flatten(1, 2)), dim=1
)
else:
positional_embeddings_interpolated = positional_embeddings_image.flatten(1, 2)
return positional_embeddings_interpolated
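# Usage sketch for the helper above (pos_embed_224 / pos_embed_448 are placeholder names): resizing a
# patch-16 ViT from 224 x 224 to 448 x 448 turns a [1, 1 + 14 * 14, C] embedding into a
# [1, 1 + 28 * 28, C] one, keeping the class-token embedding untouched:
#     pos_embed_448 = _interpolate_positional_embeddings(
#         pos_embed_224,
#         original_image_size=(224, 224),
#         target_image_size=(448, 448),
#         patch_size=16,
#         num_additional_tokens=1,
#     )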
class _ViT(nn.Module):
"""This class wraps Timm's ViT's and always ensures eval mode."""
def __init__(
self,
vit: nn.Module,
patch_size: int,
registers: bool = False,
class_token: bool = False,
        intermediate_features: Optional[List[int]] = None,
) -> None:
"""Constructor method.
Args:
vit (nn.Module): Timm ViT model.
patch_size (int): Patch size utilized.
registers (bool): Set to true if registers are use. Default False.
class_token (bool): Set true if class token is use. Default False.
"""
# Call super constructor
super(_ViT, self).__init__()
# Save parameter
self.patch_size: int = patch_size
self.registers: bool = registers
self.class_token: bool = class_token
# Get ViT depth
depth: int = len(vit.blocks) # type: ignore
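        # Nodes to extract: the output of the final norm (token features) and the key tensor of the last
        # block's attention (exposed as the "getitem_4" node in the FX graph of the timm version used here).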
return_nodes = {
#f"blocks.{depth - 1}.attn.softmax": "attention_maps",
"norm": "features_normalized",
f"blocks.{depth - 1}.attn.getitem_4": "key_features",
}
if intermediate_features is not None:
for idx, feat in enumerate(intermediate_features):
return_nodes[f"blocks.{feat}"] = f"intermediate_features.{idx}"
# Make FX graph module for feature extraction
self.vit: GraphModule = create_feature_extractor(vit, return_nodes)
def forward(self, images: Tensor) -> Dict[str, Tensor]:
"""Forward pass.
Notes:
attention_maps have the shape of [B, num heads, N, N].
features have the shape of [B, N, C]
Class token and register tokens omitted!
Args:
images (Tensor): Images (w/ pix. range of [0, 1]) of the shape [B, 3, H, W].
Returns:
output_dict (Dict[str, Tensor]): Dict of features ("attention_maps" and "features").
"""
# Ensure model is in eval mode
self.vit.eval()
# Normalize images
images_normalized: Tensor = images # normalize(images)
# Perform forward pass
output_dict: Dict[str, Tensor] = self.vit(images_normalized)
# Omit class token (and registers) from attention maps and features
if self.registers:
output_dict["features_normalized"] = output_dict["features_normalized"][:, 5:]
#output_dict["attention_maps"] = output_dict["attention_maps"][..., 5:, 5:]
output_dict["key_features"] = output_dict["key_features"][:, :, 5:]
for output_key in output_dict:
if output_key.startswith('intermediate_features'):
output_dict[output_key] = output_dict[output_key][:, 5:]
elif self.class_token:
output_dict["features_normalized"] = output_dict["features_normalized"][:, 1:]
#output_dict["attention_maps"] = output_dict["attention_maps"][..., 1:, 1:]
output_dict["key_features"] = output_dict["key_features"][:, :, 1:]
for output_key in output_dict:
if output_key.startswith('intermediate_features'):
output_dict[output_key] = output_dict[output_key][:, 1:]
# Normalize features
output_dict["features_normalized"] = F.normalize(output_dict["features_normalized"], p=2, dim=2) # [B, N, C]
return output_dict
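# Shape sketch (e.g. dino_small at 224 x 224, i.e. patch 16, embed dim 384, 6 heads, class token dropped):
# features_normalized -> [B, 196, 384], key_features -> [B, 6, 196, 64], and each requested
# intermediate_features.{idx} -> [B, 196, 384].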
def mae_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained MAE base model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
model (nn.Module): ViT MAE model as a nn.Module.
"""
return _ViT(
_load_vit(name="vit_base_patch16_224.mae", image_size=image_size),
patch_size=16,
class_token=True,
)
def vmae_large(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained video MAE large model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
        model (nn.Module): Video MAE ViT model as a nn.Module.
"""
# Init model
model = VisionTransformer(
img_size=image_size, patch_size=16, num_classes=0, qkv_bias=True, embed_dim=1024, depth=24, num_heads=16
)
# Force not to use fused attention to access attention maps
_disable_fused_attention(model)
# Load and adapt checkpoint
checkpoint: Dict[str, Tensor] = torch.load(
os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/mae_pretrain_vit_large_k700.pth"),
map_location="cpu",
)["model_state"]
checkpoint["pos_embed"] = checkpoint["pos_embed_spatial"] + checkpoint["pos_embed_temporal"].mean(
dim=1, keepdim=True
)
checkpoint["pos_embed"] = torch.cat((checkpoint["pos_embed_class"], checkpoint["pos_embed"]), dim=1)
checkpoint["patch_embed.proj.weight"] = checkpoint["patch_embed.proj.weight"][:, :, 0]
for layer in range(24):
checkpoint[f"blocks.{layer}.attn.qkv.weight"] = torch.cat(
(
checkpoint[f"blocks.{layer}.attn.q.weight"],
checkpoint[f"blocks.{layer}.attn.k.weight"],
checkpoint[f"blocks.{layer}.attn.v.weight"],
),
dim=0,
)
checkpoint[f"blocks.{layer}.attn.qkv.bias"] = torch.cat(
(
checkpoint[f"blocks.{layer}.attn.q.bias"],
checkpoint[f"blocks.{layer}.attn.k.bias"],
checkpoint[f"blocks.{layer}.attn.v.bias"],
),
dim=0,
)
checkpoint = {key: value for key, value in checkpoint.items() if key in model.state_dict().keys()}
# Interpolated positional embeddings
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=16,
num_additional_tokens=1,
)
# Load checkpoint
model.load_state_dict(checkpoint)
return _ViT(model, patch_size=16, class_token=True)
def dino_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO small model (patch size is 16 x 16).
    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Index of intermediate layer features to return.
    Returns:
        model (nn.Module): ViT Dino model as a nn.Module.
    """
return _ViT(
_load_vit(name="vit_small_patch16_224.dino", image_size=image_size),
patch_size=16,
class_token=True,
intermediate_features=intermediate_features,
)
def dino_small8(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO small model (patch size is 8 x 8).
    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Index of intermediate layer features to return.
    Returns:
        model (nn.Module): ViT Dino model as a nn.Module.
    """
return _ViT(
_load_vit(name="vit_small_patch8_224.dino", image_size=image_size),
patch_size=8,
class_token=True,
intermediate_features=intermediate_features,
)
def dino_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO base model (patch size is 16 x 16).
    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Index of intermediate layer features to return.
    Returns:
        model (nn.Module): ViT Dino model as a nn.Module.
    """
return _ViT(
_load_vit(name="vit_base_patch16_224.dino", image_size=image_size),
patch_size=16,
class_token=True,
intermediate_features=intermediate_features,
)
def dino_base8(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO base model (patch size is 8 x 8).
    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Index of intermediate layer features to return.
    Returns:
        model (nn.Module): ViT Dino model as a nn.Module.
    """
return _ViT(
_load_vit(name="vit_base_patch8_224.dino", image_size=image_size),
patch_size=8,
class_token=True,
intermediate_features=intermediate_features,
)
def dinov2_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
"""Builds the pre-trained ViT DINO V2 small model (patch size is 14 x 14).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
intermediate_features (List[int]): Index of intermediate layer features to return.
Returns:
model (nn.Module): ViT Dino model as a nn.Module.
"""
return _ViT(
_load_vit(name="vit_small_patch14_dinov2.lvd142m", image_size=image_size),
patch_size=14,
class_token=True,
intermediate_features=intermediate_features,
)
def dinov2_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
"""Builds the pre-trained ViT DINO V2 base model (patch size is 14 x 14).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
intermediate_features (List[int]): Index of intermediate layer features to return.
Returns:
model (nn.Module): ViT Dino model as a nn.Module.
"""
return _ViT(
_load_vit(name="vit_base_patch14_dinov2.lvd142m", image_size=image_size),
patch_size=14,
class_token=True,
intermediate_features=intermediate_features,
)
def dino_reg_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 (w/ registers) small model (patch size is 14 x 14).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
intermediate_features (List[int]): Index of intermediate layer features to return.
Returns:
model (nn.Module): ViT Dino model as a nn.Module.
"""
return _ViT(
_load_vit(name="vit_small_patch14_reg4_dinov2.lvd142m", image_size=image_size),
patch_size=14,
registers=True,
intermediate_features=intermediate_features,
)
def dino_reg_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: Optional[List[int]] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 (w/ registers) base model (patch size is 14 x 14).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
intermediate_features (List[int]): Index of intermediate layer features to return.
Returns:
model (nn.Module): ViT Dino model as a nn.Module.
"""
return _ViT(
_load_vit(name="vit_base_patch14_reg4_dinov2.lvd142m", image_size=image_size),
patch_size=14,
registers=True,
intermediate_features=intermediate_features,
)
def synclr_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained SynCLR ViT base model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
model (nn.Module): SynCLR ViT model as a nn.Module
"""
# Init model
model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
# Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/synclr_vit_b_16.pth"),
        map_location="cpu",
    )["model"]
checkpoint = {key.replace("module.visual.", ""): value for key, value in checkpoint.items()}
# Interpolated positional embeddings
if image_size != (224, 224):
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=16,
num_additional_tokens=1,
)
# Load checkpoint
model.load_state_dict(checkpoint)
return _ViT(model, patch_size=16, class_token=True)
def mocov3_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained MoCo-V3 ViT base model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
model (nn.Module): MoCo-V3 ViT model as a nn.Module
"""
# Init model
model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
# Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/vit-b-300ep.pth.tar"),
        map_location="cpu",
    )["state_dict"]
checkpoint = {
key.replace("module.momentum_encoder.", ""): value
for key, value in checkpoint.items()
if ("module.momentum_encoder." in key) and ("head." not in key)
}
# Interpolated positional embeddings
if image_size != (224, 224):
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=16,
num_additional_tokens=1,
)
# Load checkpoint
model.load_state_dict(checkpoint)
return _ViT(model, patch_size=16, class_token=True)
def msn_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained MSN ViT base model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
model (nn.Module): MSN ViT model as a nn.Module
"""
# Init model
model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
# Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/vitb16_600ep.pth.tar"),
        map_location="cpu",
    )["target_encoder"]
checkpoint = {
key.replace("module.", ""): value
for key, value in checkpoint.items()
if key.replace("module.", "") in model.state_dict().keys()
}
# Interpolated positional embeddings
if image_size != (224, 224):
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=16,
num_additional_tokens=1,
)
# Load checkpoint
model.load_state_dict(checkpoint)
return _ViT(model, patch_size=16, class_token=True)
def self_patch_small(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained Self-Patch ViT small model (patch size is 16 x 16).
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
model (nn.Module): Self-Patch ViT model as a nn.Module
"""
# Init model
model = VisionTransformer(
img_size=image_size,
patch_size=16,
num_classes=0,
class_token=False,
qkv_bias=True,
global_pool="avg",
embed_dim=384,
depth=12,
num_heads=6,
)
# Force not to use fused attention to access attention maps
_disable_fused_attention(model)
# Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/dino_selfpatch.pth"),
        map_location="cpu",
    )
checkpoint = {
key.replace("module.", ""): value for key, value in checkpoint.items() if key in model.state_dict().keys()
}
# Interpolated positional embeddings
if image_size != (224, 224):
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=16,
num_additional_tokens=0,
)
# Load checkpoint
model.load_state_dict(checkpoint, strict=False)
return _ViT(model, patch_size=16, class_token=False, registers=False)
def i_jepa_huge(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
"""Builds the pre-trained I-JEPA ViT huge model (patch size is 14 x 14).
Notes:
        ViT huge is very large (approx. 630M parameters), so expect high memory use.
Args:
image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
Returns:
        model (nn.Module): I-JEPA ViT model as a nn.Module
"""
# Init model
model = VisionTransformer(
img_size=image_size,
patch_size=14,
num_classes=0,
class_token=False,
global_pool="avg",
qkv_bias=True,
embed_dim=1280,
depth=32,
num_heads=16,
)
# Force not to use fused attention to access attention maps
_disable_fused_attention(model)
# Load and adapt checkpoint
checkpoint: Dict[str, Tensor] = torch.load(
os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/IN22k-vit.h.14-900e.pth.tar"),
map_location="cpu",
)["encoder"]
checkpoint = {key.replace("module.", ""): value for key, value in checkpoint.items()}
# Interpolated positional embeddings
if image_size != (224, 224):
checkpoint["pos_embed"] = _interpolate_positional_embeddings(
checkpoint["pos_embed"],
original_image_size=(224, 224),
target_image_size=image_size,
patch_size=14,
num_additional_tokens=0,
)
# Load checkpoint
model.load_state_dict(checkpoint, strict=False)
return _ViT(model, patch_size=14, class_token=False, registers=False)
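# Minimal smoke test (a sketch: assumes the timm weights for "vit_small_patch16_224.dino" can be downloaded
# and that a CPU forward pass of ViT-S/16 fits in memory).
if __name__ == "__main__":
    vit_model: nn.Module = dino_small(image_size=(224, 224))
    dummy_images: Tensor = torch.rand(1, 3, 224, 224)  # Pixel range [0, 1], as expected by _ViT.forward
    with torch.no_grad():
        outputs: Dict[str, Tensor] = vit_model(dummy_images)
    for output_name, output_tensor in outputs.items():
        print(output_name, tuple(output_tensor.shape))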