'''
Implementation of ViTSTR based on timm VisionTransformer.

TODO:
1) distilled deit backbone
2) base deit backbone

Copyright 2021 Rowel Atienza
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import logging
import torch.utils.model_zoo as model_zoo

from copy import deepcopy
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models import create_model

_logger = logging.getLogger(__name__)
__all__ = [
    'vitstr_tiny_patch16_224',
    'vitstr_small_patch16_224',
    'vitstr_base_patch16_224',
    # 'vitstr_tiny_distilled_patch16_224',
    # 'vitstr_small_distilled_patch16_224',
    # 'vitstr_base_distilled_patch16_224',
]

def create_vitstr(num_tokens, model=None, checkpoint_path=''):
    vitstr = create_model(
        model,
        pretrained=True,
        num_classes=num_tokens,
        checkpoint_path=checkpoint_path)

    # might need to run to get zero init head for transfer learning
    vitstr.reset_classifier(num_classes=num_tokens)
    return vitstr
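
# Illustrative usage sketch (not part of the original module). create_model
# resolves `model` by its registered timm name, which is why the factory
# functions below carry @register_model. num_tokens=96 is a hypothetical
# charset size; pretrained=True downloads the DeiT weights on first use.
def _example_create_vitstr():
    vitstr = create_vitstr(num_tokens=96, model='vitstr_tiny_patch16_224')
    logits = vitstr(torch.randn(1, 1, 224, 224))  # expected: (1, 25, 96)
    return logits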

class ViTSTR(VisionTransformer):
    '''
    ViTSTR is basically a ViT that uses DeiT weights.
    The head is modified to predict a sequence of characters for STR.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def patch_embed_func(self):
        return self.patch_embed
    def forward_features(self, x):
        B = x.shape[0]
        # input x shape: (B, 1, 224, 224)
        x = self.patch_embed(x)
        # after patch embedding: (B, 196, embed_dim)
        # patch size is 16x16, so a 224x224 image gives 14x14 = 196 patches;
        # embed_dim is 192/384/768 for tiny/small/base
        # self.cls_token shape: (1, 1, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # self.pos_embed shape: (1, 197, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        # every block preserves the shape: (B, 197, embed_dim)
        x = self.norm(x)
        return x
    def forward(self, x, seqlen=25):
        x = self.forward_features(x)
        # keep the first seqlen tokens: (batch, seqlen, embed_dim)
        x = x[:, :seqlen]
        b, s, e = x.size()
        # apply the shared head to every time step at once
        x = x.reshape(b * s, e)
        x = self.head(x).view(b, s, self.num_classes)
        return x
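
# Illustrative shape trace (not part of the original module): run a single
# grayscale image through the tiny variant. num_classes=96 is a hypothetical
# charset size; pretrained=False avoids any download.
def _example_forward_shapes():
    model = vitstr_tiny_patch16_224(pretrained=False, num_classes=96)
    x = torch.randn(1, 1, 224, 224)    # (batch, channels, H, W)
    feats = model.forward_features(x)  # (1, 197, 192): [CLS] + 14*14 patches
    out = model(x, seqlen=25)          # (1, 25, 96): first 25 tokens -> head
    return feats.shape, out.shape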

def load_pretrained(model, cfg=None, num_classes=1000, in_chans=1, filter_fn=None, strict=True):
    '''
    Loads a pretrained checkpoint.
    Adapted from an older version of timm.
    '''
    if cfg is None:
        cfg = getattr(model, 'default_cfg')
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    state_dict = model_zoo.load_url(cfg['url'], progress=True, map_location='cpu')
    if "model" in state_dict.keys():
        state_dict = state_dict["model"]
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
        key = conv1_name + '.weight'
        if key in state_dict.keys():
            _logger.info('(%s) key found in state_dict' % key)
            conv1_weight = state_dict[conv1_name + '.weight']
        else:
            _logger.info('(%s) key NOT found in state_dict' % key)
            return
        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + '.weight'] = conv1_weight

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    print("Loading pre-trained vision transformer weights from %s ..." % cfg['url'])
    model.load_state_dict(state_dict, strict=strict)
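
# Illustrative check (not part of the original module): summing an RGB
# patch-embedding kernel over its channel dim, as the in_chans == 1 branch
# above does, matches applying the RGB kernel to a gray image replicated
# across three channels. Sizes here follow the tiny variant.
def _example_rgb_to_gray_conv():
    w_rgb = torch.randn(192, 3, 16, 16)      # (out_ch, in_ch=3, kH, kW)
    w_gray = w_rgb.sum(dim=1, keepdim=True)  # (192, 1, 16, 16)
    gray = torch.randn(1, 1, 224, 224)
    rgb = gray.repeat(1, 3, 1, 1)            # gray replicated into 3 channels
    y_rgb = nn.functional.conv2d(rgb, w_rgb, stride=16)
    y_gray = nn.functional.conv2d(gray, w_gray, stride=16)
    assert torch.allclose(y_rgb, y_gray, atol=1e-3)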

def _conv_filter(state_dict, patch_size=16):
    """Convert patch embedding weight from manual patchify + linear proj to conv."""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict
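
# Illustrative sketch (not part of the original module): a DeiT checkpoint
# stores the patch projection as an (embed_dim, 3*16*16) linear weight;
# _conv_filter reshapes it into the (embed_dim, 3, 16, 16) conv kernel that
# timm's PatchEmbed expects. Sizes follow the tiny variant.
def _example_conv_filter():
    state_dict = {'patch_embed.proj.weight': torch.randn(192, 3 * 16 * 16)}
    converted = _conv_filter(state_dict, patch_size=16)
    assert converted['patch_embed.proj.weight'].shape == (192, 3, 16, 16)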

@register_model
def vitstr_tiny_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    model = ViTSTR(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        # url='https://github.com/roatienza/public/releases/download/v0.1-deit-tiny/deit_tiny_patch16_224-a1311bcf.pth'
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

@register_model
def vitstr_small_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    model = ViTSTR(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        # url="https://github.com/roatienza/public/releases/download/v0.1-deit-small/deit_small_patch16_224-cd65a155.pth"
        url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

@register_model
def vitstr_base_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    model = ViTSTR(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        # url='https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth'
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model
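
# Illustrative sketch (not part of the original module): build the base model
# without downloading weights, then resize its head to a hypothetical 96-token
# charset with reset_classifier before running a batch of two gray images.
def _example_factory_usage():
    model = vitstr_base_patch16_224(pretrained=False)
    model.reset_classifier(num_classes=96)
    out = model(torch.randn(2, 1, 224, 224))  # (2, 25, 96)
    return out.shape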

# below is work in progress
@register_model
def vitstr_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    # kwargs['distilled'] = True
    model = ViTSTR(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model

@register_model
def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs):
    kwargs['in_chans'] = 1
    kwargs['distilled'] = True
    model = ViTSTR(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth"
    )
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter)
    return model