# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Factory functions to build and load ViT models.


from __future__ import annotations

import logging
import types
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional

import timm
import torch
import torch.nn as nn

from .vit import (
    forward_features_eva_fixed,
    make_vit_b16_backbone,
    resize_patch_embed,
    resize_vit,
)

LOGGER = logging.getLogger(__name__)


ViTPreset = Literal[
    "dinov2l16_384",
]
@dataclass
class ViTConfig:
    """Configuration for ViT."""

    in_chans: int
    embed_dim: int

    img_size: int = 384
    patch_size: int = 16

    # In case we need to rescale the backbone when loading from timm.
    timm_preset: Optional[str] = None
    timm_img_size: int = 384
    timm_patch_size: int = 16

    # The following 2 parameters are only used by DPT. See dpt_factory.py.
    encoder_feature_layer_ids: Optional[List[int]] = None
    """The layers in the Beit/ViT used to construct encoder features for DPT."""
    encoder_feature_dims: Optional[List[int]] = None
    """The dimension of features of encoder layers from Beit/ViT features for DPT."""
VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
    "dinov2l16_384": ViTConfig(
        in_chans=3,
        embed_dim=1024,
        encoder_feature_layer_ids=[5, 11, 17, 23],
        encoder_feature_dims=[256, 512, 1024, 1024],
        img_size=384,
        patch_size=16,
        timm_preset="vit_large_patch14_dinov2",
        timm_img_size=518,
        timm_patch_size=14,
    ),
}
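
# To register an additional backbone, extend ViTPreset and add a matching
# entry to VIT_CONFIG_DICT. A hypothetical sketch (the timm preset name,
# embed_dim, and layer ids below are illustrative, not part of this repo):
#
#     "dinov2b16_384": ViTConfig(
#         in_chans=3,
#         embed_dim=768,
#         encoder_feature_layer_ids=[2, 5, 8, 11],
#         encoder_feature_dims=[96, 192, 384, 768],
#         img_size=384,
#         patch_size=16,
#         timm_preset="vit_base_patch14_dinov2",
#         timm_img_size=518,
#         timm_patch_size=14,
#     ),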
def create_vit(
    preset: ViTPreset,
    use_pretrained: bool = False,
    checkpoint_uri: str | None = None,
    use_grad_checkpointing: bool = False,
) -> nn.Module:
    """Create and load a ViT backbone module.

    Args:
    ----
        preset: The ViT preset to load the pre-defined config.
        use_pretrained: Load pretrained weights if True, default is False.
        checkpoint_uri: Checkpoint to load the weights from.
        use_grad_checkpointing: Use gradient checkpointing.

    Returns:
    -------
        A Torch ViT backbone module.

    """
    config = VIT_CONFIG_DICT[preset]

    img_size = (config.img_size, config.img_size)
    patch_size = (config.patch_size, config.patch_size)
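
    # EVA02 presets use the patched forward_features from .vit; other presets
    # are created with dynamic_img_size=True so the timm backbone accepts
    # inputs that differ from its training resolution.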
| if "eva02" in preset: | |
| model = timm.create_model(config.timm_preset, pretrained=use_pretrained) | |
| model.forward_features = types.MethodType(forward_features_eva_fixed, model) | |
| else: | |
| model = timm.create_model( | |
| config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True | |
| ) | |
    model = make_vit_b16_backbone(
        model,
        encoder_feature_dims=config.encoder_feature_dims,
        encoder_feature_layer_ids=config.encoder_feature_layer_ids,
        vit_features=config.embed_dim,
        use_grad_checkpointing=use_grad_checkpointing,
    )
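
    # Rescale the backbone when the target preset differs from the timm
    # checkpoint, e.g. dinov2: patch 14 @ 518 px (37x37 patches) is remapped
    # to patch 16 @ 384 px (24x24 patches).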
    if config.patch_size != config.timm_patch_size:
        model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
    if config.img_size != config.timm_img_size:
        model.model = resize_vit(model.model, img_size=img_size)
    if checkpoint_uri is not None:
        state_dict = torch.load(checkpoint_uri, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=False
        )

        if len(unexpected_keys) != 0:
            raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
        if len(missing_keys) != 0:
            raise KeyError(f"Keys are missing when loading vit: {missing_keys}")

    LOGGER.info(model)
    return model.model
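

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: build the DINOv2
    # backbone without pretrained weights and run a dummy forward pass.
    # Assumes the returned timm ViT exposes forward_features and accepts
    # inputs at the configured 384 px resolution.
    logging.basicConfig(level=logging.INFO)

    backbone = create_vit("dinov2l16_384", use_pretrained=False)
    dummy = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        features = backbone.forward_features(dummy)
    LOGGER.info("forward_features output shape: %s", tuple(features.shape))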