import os
import pathlib
from typing import Dict, List, Tuple

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import VisionTransformer
from torch import Tensor
from torch.fx import GraphModule
from torchvision.models.feature_extraction import create_feature_extractor

# from ups.utils import normalize

__all__: Tuple[str, ...] = (
    "dino_small",
    "dino_base",
    "dinov2_small",
    "dinov2_base",
    "dino_reg_small",
    "dino_reg_base",
    "i_jepa_huge",
    "mae_base",
    "self_patch_small",
    "synclr_base",
    "mocov3_base",
    "msn_base",
    "vmae_large",
)

def _disable_fused_attention(model: nn.Module) -> None:
    """Disables fused attention in the last block of a timm ViT model. (Don't use anywhere else!)

    Args:
        model (nn.Module): Timm ViT model.
    """
    # Get ViT depth
    depth: int = len(model.blocks)  # type: ignore
    # Disable fused attention for the last block only
    for name, module in model.named_modules():
        if "Attention" in str(type(module)):
            if str(depth - 1) in name:
                module.fused_attn = False  # type: ignore

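# Usage sketch (hedged): assumes a recent timm version whose Attention modules expose a
# fused_attn flag; not executed at import time.
#   vit = timm.create_model("vit_small_patch16_224.dino", pretrained=False)
#   _disable_fused_attention(vit)
#   assert vit.blocks[-1].attn.fused_attn is False
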
def _load_vit(name: str, image_size: Tuple[int, int] = (224, 224), depth: int = 12) -> nn.Module:
    """Loads a pre-trained ViT model from timm.

    Args:
        name (str): Timm name of the model.
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        depth (int): Depth of the model (currently unused).

    Returns:
        model (nn.Module): ViT model as a nn.Module.
    """
    # Load model
    model: nn.Module = timm.create_model(name, pretrained=True, img_size=image_size, num_classes=0)
    # Force not to use fused attention to access attention maps
    # _disable_fused_attention(model)
    return model

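# Usage sketch (hedged): timm resamples the pre-trained positional embeddings itself when
# img_size differs from the pre-training resolution, so no manual interpolation is needed here.
#   dino = _load_vit("vit_small_patch16_224.dino", image_size=(448, 448))
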
def _interpolate_positional_embeddings(
    positional_embeddings: Tensor,
    original_image_size: Tuple[int, int],
    target_image_size: Tuple[int, int],
    patch_size: int,
    num_additional_tokens: int = 1,
) -> Tensor:
    """Interpolates positional embeddings to a different image size.

    Args:
        positional_embeddings (Tensor): Positional embeddings of the shape [1, N, C].
        original_image_size (Tuple[int, int]): Original image size as a tuple.
        target_image_size (Tuple[int, int]): Target image size as a tuple.
        patch_size (int): Patch size utilized.
        num_additional_tokens (int): Number of additional tokens used. Default 1 (class token).

    Returns:
        positional_embeddings_interpolated (Tensor): Interpolated positional embeddings [1, N_new, C].
    """
    # Split off positional embeddings of additional tokens (class token, registers)
    if num_additional_tokens > 0:
        positional_embeddings_add_tokens: Tensor = positional_embeddings[:, :num_additional_tokens]
        positional_embeddings_image: Tensor = positional_embeddings[:, num_additional_tokens:]
    else:
        positional_embeddings_image = positional_embeddings
    # Reshape embeddings to 2D
    positional_embeddings_image = positional_embeddings_image.view(
        1, original_image_size[0] // patch_size, original_image_size[1] // patch_size, -1
    )
    # Interpolate positional embeddings
    positional_embeddings_image = F.interpolate(
        positional_embeddings_image.permute(0, 3, 1, 2),
        size=(target_image_size[0] // patch_size, target_image_size[1] // patch_size),
        mode="bicubic",
        align_corners=False,
        antialias=False,
    ).permute(0, 2, 3, 1)
    # Stack positional embeddings again
    if num_additional_tokens > 0:
        positional_embeddings_interpolated: Tensor = torch.cat(
            (positional_embeddings_add_tokens, positional_embeddings_image.flatten(1, 2)), dim=1
        )
    else:
        positional_embeddings_interpolated = positional_embeddings_image.flatten(1, 2)
    return positional_embeddings_interpolated

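# Shape sketch (hedged example, not part of the module API): a ViT-S/16 embedding grid resized
# from 224x224 (14x14 patches) to 336x336 (21x21 patches).
#   pos = torch.zeros(1, 1 + 14 * 14, 384)  # [CLS] token + 196 patch tokens
#   pos_336 = _interpolate_positional_embeddings(
#       pos, original_image_size=(224, 224), target_image_size=(336, 336),
#       patch_size=16, num_additional_tokens=1,
#   )
#   # pos_336.shape == (1, 1 + 21 * 21, 384)
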
class _ViT(nn.Module):
    """This class wraps timm ViTs and always ensures eval mode."""

    def __init__(
        self,
        vit: nn.Module,
        patch_size: int,
        registers: bool = False,
        class_token: bool = False,
        intermediate_features: List[int] = None,
    ) -> None:
        """Constructor method.

        Args:
            vit (nn.Module): Timm ViT model.
            patch_size (int): Patch size utilized.
            registers (bool): Set to True if register tokens are used. Default False.
            class_token (bool): Set to True if a class token is used. Default False.
            intermediate_features (List[int]): Indices of intermediate blocks whose features are returned. Default None.
        """
        # Call super constructor
        super(_ViT, self).__init__()
        # Save parameters
        self.patch_size: int = patch_size
        self.registers: bool = registers
        self.class_token: bool = class_token
        # Get ViT depth
        depth: int = len(vit.blocks)  # type: ignore
        return_nodes = {
            # f"blocks.{depth - 1}.attn.softmax": "attention_maps",
            "norm": "features_normalized",
            f"blocks.{depth - 1}.attn.getitem_4": "key_features",
        }
        if intermediate_features is not None:
            for idx, feat in enumerate(intermediate_features):
                return_nodes[f"blocks.{feat}"] = f"intermediate_features.{idx}"
        # Make FX graph module for feature extraction
        self.vit: GraphModule = create_feature_extractor(vit, return_nodes)

    def forward(self, images: Tensor) -> Dict[str, Tensor]:
        """Forward pass.

        Notes:
            features_normalized has the shape [B, N, C] (L2-normalized over the channel dimension).
            key_features has the shape [B, num heads, N, head dim].
            intermediate_features.{i} (if requested) have the shape [B, N, C].
            Class token and register tokens are omitted!

        Args:
            images (Tensor): Images (w/ pix. range of [0, 1]) of the shape [B, 3, H, W].

        Returns:
            output_dict (Dict[str, Tensor]): Dict of features ("features_normalized", "key_features", ...).
        """
        # Ensure model is in eval mode
        self.vit.eval()
        # Normalize images
        images_normalized: Tensor = images  # normalize(images)
        # Perform forward pass
        output_dict: Dict[str, Tensor] = self.vit(images_normalized)
        # Omit class token (and register tokens) from the features
        if self.registers:
            # Class token + 4 register tokens
            output_dict["features_normalized"] = output_dict["features_normalized"][:, 5:]
            # output_dict["attention_maps"] = output_dict["attention_maps"][..., 5:, 5:]
            output_dict["key_features"] = output_dict["key_features"][:, :, 5:]
            for output_key in output_dict:
                if output_key.startswith("intermediate_features"):
                    output_dict[output_key] = output_dict[output_key][:, 5:]
        elif self.class_token:
            output_dict["features_normalized"] = output_dict["features_normalized"][:, 1:]
            # output_dict["attention_maps"] = output_dict["attention_maps"][..., 1:, 1:]
            output_dict["key_features"] = output_dict["key_features"][:, :, 1:]
            for output_key in output_dict:
                if output_key.startswith("intermediate_features"):
                    output_dict[output_key] = output_dict[output_key][:, 1:]
        # Normalize features
        output_dict["features_normalized"] = F.normalize(output_dict["features_normalized"], p=2, dim=2)  # [B, N, C]
        return output_dict

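# Usage sketch (hedged; shapes assume a 224x224 input and ViT-S/16, i.e. 196 patch tokens and an
# embedding dimension of 384):
#   backbone = _ViT(_load_vit("vit_small_patch16_224.dino"), patch_size=16, class_token=True)
#   out = backbone(torch.rand(2, 3, 224, 224))
#   out["features_normalized"].shape  # [2, 196, 384]
#   out["key_features"].shape         # [2, num heads, 196, head dim], keys of the last attention block
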
def mae_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained MAE base model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): ViT MAE model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_base_patch16_224.mae", image_size=image_size),
        patch_size=16,
        class_token=True,
    )

def vmae_large(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained video MAE large model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): Video MAE ViT model as a nn.Module.
    """
    # Init model
    model = VisionTransformer(
        img_size=image_size, patch_size=16, num_classes=0, qkv_bias=True, embed_dim=1024, depth=24, num_heads=16
    )
    # Force not to use fused attention to access attention maps
    _disable_fused_attention(model)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/mae_pretrain_vit_large_k700.pth"),
        map_location="cpu",
    )["model_state"]
    # Add the averaged temporal positional embeddings to the spatial ones and prepend the class-token embedding
    checkpoint["pos_embed"] = checkpoint["pos_embed_spatial"] + checkpoint["pos_embed_temporal"].mean(
        dim=1, keepdim=True
    )
    checkpoint["pos_embed"] = torch.cat((checkpoint["pos_embed_class"], checkpoint["pos_embed"]), dim=1)
    # Convert the 3D (video) patch embedding into a 2D one by taking the first temporal slice
    checkpoint["patch_embed.proj.weight"] = checkpoint["patch_embed.proj.weight"][:, :, 0]
    # Fuse the separate q/k/v projections into the single qkv projection expected by timm
    for layer in range(24):
        checkpoint[f"blocks.{layer}.attn.qkv.weight"] = torch.cat(
            (
                checkpoint[f"blocks.{layer}.attn.q.weight"],
                checkpoint[f"blocks.{layer}.attn.k.weight"],
                checkpoint[f"blocks.{layer}.attn.v.weight"],
            ),
            dim=0,
        )
        checkpoint[f"blocks.{layer}.attn.qkv.bias"] = torch.cat(
            (
                checkpoint[f"blocks.{layer}.attn.q.bias"],
                checkpoint[f"blocks.{layer}.attn.k.bias"],
                checkpoint[f"blocks.{layer}.attn.v.bias"],
            ),
            dim=0,
        )
    # Drop checkpoint entries not present in the model
    checkpoint = {key: value for key, value in checkpoint.items() if key in model.state_dict().keys()}
    # Interpolate positional embeddings
    checkpoint["pos_embed"] = _interpolate_positional_embeddings(
        checkpoint["pos_embed"],
        original_image_size=(224, 224),
        target_image_size=image_size,
        patch_size=16,
        num_additional_tokens=1,
    )
    # Load checkpoint
    model.load_state_dict(checkpoint)
    return _ViT(model, patch_size=16, class_token=True)

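# Remapping sketch (hedged): the released video MAE checkpoint stores separate q/k/v projections
# per block, while timm's VisionTransformer uses a single fused qkv projection, so the three
# weight matrices are stacked along the output dimension.
#   q_w, k_w, v_w = (torch.randn(1024, 1024) for _ in range(3))
#   qkv_w = torch.cat((q_w, k_w, v_w), dim=0)  # [3072, 1024], matches attn.qkv.weight
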
def dino_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO small model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINO model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_small_patch16_224.dino", image_size=image_size),
        patch_size=16,
        class_token=True,
        intermediate_features=intermediate_features,
    )

def dino_small8(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO small model (patch size is 8 x 8).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINO model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_small_patch8_224.dino", image_size=image_size),
        patch_size=8,
        class_token=True,
        intermediate_features=intermediate_features,
    )

def dino_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO base model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINO model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_base_patch16_224.dino", image_size=image_size),
        patch_size=16,
        class_token=True,
        intermediate_features=intermediate_features,
    )

def dino_base8(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINO base model (patch size is 8 x 8).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINO model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_base_patch8_224.dino", image_size=image_size),
        patch_size=8,
        class_token=True,
        intermediate_features=intermediate_features,
    )

def dinov2_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 small model (patch size is 14 x 14).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINOv2 model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_small_patch14_dinov2.lvd142m", image_size=image_size),
        patch_size=14,
        class_token=True,
        intermediate_features=intermediate_features,
    )

def dinov2_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 base model (patch size is 14 x 14).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINOv2 model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_base_patch14_dinov2.lvd142m", image_size=image_size),
        patch_size=14,
        class_token=True,
        intermediate_features=intermediate_features,
    )

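# Usage sketch (hedged): requesting intermediate block features by block index. With a 224x224
# input and patch size 14 there are 16 x 16 = 256 patch tokens per output.
#   backbone = dinov2_base(image_size=(224, 224), intermediate_features=[3, 7, 11])
#   out = backbone(torch.rand(1, 3, 224, 224))
#   # out now also contains "intermediate_features.0" ... "intermediate_features.2", each [1, 256, 768]
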
def dino_reg_small(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 (w/ registers) small model (patch size is 14 x 14).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINOv2 (w/ registers) model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_small_patch14_reg4_dinov2.lvd142m", image_size=image_size),
        patch_size=14,
        registers=True,
        intermediate_features=intermediate_features,
    )

def dino_reg_base(image_size: Tuple[int, int] = (224, 224), intermediate_features: List[int] = None) -> nn.Module:
    """Builds the pre-trained ViT DINOv2 (w/ registers) base model (patch size is 14 x 14).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).
        intermediate_features (List[int]): Indices of intermediate layer features to return.

    Returns:
        model (nn.Module): ViT DINOv2 (w/ registers) model as a nn.Module.
    """
    return _ViT(
        _load_vit(name="vit_base_patch14_reg4_dinov2.lvd142m", image_size=image_size),
        patch_size=14,
        registers=True,
        intermediate_features=intermediate_features,
    )

def synclr_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained SynCLR ViT base model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): SynCLR ViT model as a nn.Module.
    """
    # Init model
    model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/synclr_vit_b_16.pth")
    )["model"]
    # Strip the visual-encoder prefix from the checkpoint keys
    checkpoint = {key.replace("module.visual.", ""): value for key, value in checkpoint.items()}
    # Interpolate positional embeddings
    if image_size != (224, 224):
        checkpoint["pos_embed"] = _interpolate_positional_embeddings(
            checkpoint["pos_embed"],
            original_image_size=(224, 224),
            target_image_size=image_size,
            patch_size=16,
            num_additional_tokens=1,
        )
    # Load checkpoint
    model.load_state_dict(checkpoint)
    return _ViT(model, patch_size=16, class_token=True)

def mocov3_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained MoCo-V3 ViT base model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): MoCo-V3 ViT model as a nn.Module.
    """
    # Init model
    model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/vit-b-300ep.pth.tar")
    )["state_dict"]
    # Keep only the momentum-encoder weights (w/o the projection head) and strip the prefix
    checkpoint = {
        key.replace("module.momentum_encoder.", ""): value
        for key, value in checkpoint.items()
        if ("module.momentum_encoder." in key) and ("head." not in key)
    }
    # Interpolate positional embeddings
    if image_size != (224, 224):
        checkpoint["pos_embed"] = _interpolate_positional_embeddings(
            checkpoint["pos_embed"],
            original_image_size=(224, 224),
            target_image_size=image_size,
            patch_size=16,
            num_additional_tokens=1,
        )
    # Load checkpoint
    model.load_state_dict(checkpoint)
    return _ViT(model, patch_size=16, class_token=True)

def msn_base(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained MSN ViT base model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): MSN ViT model as a nn.Module.
    """
    # Init model
    model = _load_vit(name="vit_base_patch16_224", image_size=image_size)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/vitb16_600ep.pth.tar")
    )["target_encoder"]
    # Strip the DDP prefix and keep only keys present in the model
    checkpoint = {
        key.replace("module.", ""): value
        for key, value in checkpoint.items()
        if key.replace("module.", "") in model.state_dict().keys()
    }
    # Interpolate positional embeddings
    if image_size != (224, 224):
        checkpoint["pos_embed"] = _interpolate_positional_embeddings(
            checkpoint["pos_embed"],
            original_image_size=(224, 224),
            target_image_size=image_size,
            patch_size=16,
            num_additional_tokens=1,
        )
    # Load checkpoint
    model.load_state_dict(checkpoint)
    return _ViT(model, patch_size=16, class_token=True)

def self_patch_small(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained Self-Patch ViT small model (patch size is 16 x 16).

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): Self-Patch ViT model as a nn.Module.
    """
    # Init model
    model = VisionTransformer(
        img_size=image_size,
        patch_size=16,
        num_classes=0,
        class_token=False,
        qkv_bias=True,
        global_pool="avg",
        embed_dim=384,
        depth=12,
        num_heads=6,
    )
    # Force not to use fused attention to access attention maps
    _disable_fused_attention(model)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/dino_selfpatch.pth")
    )
    checkpoint = {
        key.replace("module.", ""): value for key, value in checkpoint.items() if key in model.state_dict().keys()
    }
    # Interpolate positional embeddings (no class token, so no additional tokens)
    if image_size != (224, 224):
        checkpoint["pos_embed"] = _interpolate_positional_embeddings(
            checkpoint["pos_embed"],
            original_image_size=(224, 224),
            target_image_size=image_size,
            patch_size=16,
            num_additional_tokens=0,
        )
    # Load checkpoint
    model.load_state_dict(checkpoint, strict=False)
    return _ViT(model, patch_size=16, class_token=False, registers=False)

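# Note (hedged): Self-Patch uses no class token, so its positional embeddings cover patch tokens
# only (num_additional_tokens=0) and the wrapper keeps all tokens (class_token=False).
#   model = self_patch_small(image_size=(256, 256))
#   model(torch.rand(1, 3, 256, 256))["features_normalized"].shape  # [1, 256, 384]
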
def i_jepa_huge(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Builds the pre-trained I-JEPA ViT huge model (patch size is 14 x 14).

    Notes:
        ViT huge is very large...

    Args:
        image_size (Tuple[int, int]): Image size to be used. Default is (224, 224).

    Returns:
        model (nn.Module): I-JEPA ViT model as a nn.Module.
    """
    # Init model
    model = VisionTransformer(
        img_size=image_size,
        patch_size=14,
        num_classes=0,
        class_token=False,
        global_pool="avg",
        qkv_bias=True,
        embed_dim=1280,
        depth=32,
        num_heads=16,
    )
    # Force not to use fused attention to access attention maps
    _disable_fused_attention(model)
    # Load and adapt checkpoint
    checkpoint: Dict[str, Tensor] = torch.load(
        os.path.join(pathlib.Path(__file__).parent.resolve(), "checkpoints/IN22k-vit.h.14-900e.pth.tar"),
        map_location="cpu",
    )["encoder"]
    # Strip the DDP prefix from the checkpoint keys
    checkpoint = {key.replace("module.", ""): value for key, value in checkpoint.items()}
    # Interpolate positional embeddings (no class token, so no additional tokens)
    if image_size != (224, 224):
        checkpoint["pos_embed"] = _interpolate_positional_embeddings(
            checkpoint["pos_embed"],
            original_image_size=(224, 224),
            target_image_size=image_size,
            patch_size=14,
            num_additional_tokens=0,
        )
    # Load checkpoint
    model.load_state_dict(checkpoint, strict=False)
    return _ViT(model, patch_size=14, class_token=False, registers=False)

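# Usage sketch (hedged): every builder returns an eval-mode wrapper mapping images in [0, 1] to a
# dict of patch-level features; non-square sizes work as long as both sides are divisible by the
# patch size.
#   model = dino_small(image_size=(224, 448))
#   out = model(torch.rand(1, 3, 224, 448))
#   out["features_normalized"].shape  # [1, (224 // 16) * (448 // 16), 384] = [1, 392, 384]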