#!/usr/bin/env python3
"""
Copyright (c) Meta Platforms, Inc. All Rights Reserved
https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py
"""
import ml_collections


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
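
# Note (added comment, not part of the original file): hidden_size / num_heads
# gives the per-head dimension. For ViT-B/16 that is 768 / 12 = 64; ViT-L/16
# below also uses 64 (1024 / 16), while ViT-H/14 uses 80 (1280 / 16).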


def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    del config.patches.size
    config.patches.grid = (14, 14)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1
    return config
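
# Added note (an assumption, not in the original file): in the hybrid model
# the patch embedding consumes a ResNet feature map rather than raw pixels,
# which is why 'grid' replaces 'size' here. A (14, 14) grid corresponds to a
# 224x224 input downsampled 16x by the R50 backbone (224 / 16 = 14).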


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    return config


def get_b8_config():
    """Returns the ViT-B/8 configuration."""
    config = get_b16_config()
    config.patches.size = (8, 8)
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
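

# Usage sketch (added, not part of the original file): a hypothetical registry
# mapping model names to the factory functions above, showing how a caller
# might select a variant and read its fields. The name strings are
# illustrative assumptions, not an API defined by this module.
if __name__ == '__main__':
    CONFIGS = {
        'testing': get_testing,
        'ViT-B_16': get_b16_config,
        'ViT-B_32': get_b32_config,
        'ViT-B_8': get_b8_config,
        'ViT-L_16': get_l16_config,
        'ViT-L_32': get_l32_config,
        'ViT-H_14': get_h14_config,
        'R50-ViT-B_16': get_r50_b16_config,
    }
    config = CONFIGS['ViT-B_16']()
    print(config.hidden_size)            # 768
    print(config.transformer.num_heads)  # 12
    print(config.patches.size)           # (16, 16)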