# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F

from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer


@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License

    Args:
        dataset_name (str|None): The name of the dataset to which
            the data belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (List[str]|str): The prompt template(s) used for the labels.
            Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of transformer model. Default: 512.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads in transformer.
            Default: 8.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim in
            transformer. Default: 4.
        output_dims (int): Dim of output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in cache.
            Default: True.
        cat_bg (bool): Whether to add background embedding. Default: True.
        norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 dataset_name: str = None,
                 vocabulary: List[str] = None,
                 templates: str = 'vild',
                 total_vocab_size: int = 49408,
                 context_length: int = 77,
                 embed_dims: int = 512,
                 num_layers: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: int = 4,
                 output_dims: int = 512,
                 cache_feature: bool = True,
                 cat_bg: bool = True,
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: dict = None):
        super().__init__(init_cfg)
        if isinstance(templates, List):
            self.templates = templates
        else:
            self.templates = get_predefined_templates(templates)

        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder requires either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "'dataset_name' and 'vocabulary' are mutually exclusive"
        self.dataset_name = dataset_name
        self.vocabulary = vocabulary
        self.num_pos = context_length
        self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, embed_dims))
        self.text_projection = nn.Parameter(
            torch.empty(embed_dims, output_dims))
        # Learnable temperature, initialised to log(1 / 0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.transformer = ModuleList()
        # Causal mask shared by all layers; a non-persistent buffer follows
        # the module across devices but is not saved in checkpoints.
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
        for i in range(num_layers):
            self.transformer.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=True),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=dict(type='QuickGELU')),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        self.ln_final = build_norm_layer(
            norm_cfg, embed_dims, postfix='_final')[1]

        self.cache_feature = cache_feature
        if self.cache_feature:
            self.cache = {}

        self._freeze()

        self.cat_bg = cat_bg
        if self.cat_bg:
            # Created after `_freeze()`, so the background embedding stays
            # trainable while the CLIP text weights remain frozen.
            self.bg_embed = nn.Parameter(
                torch.randn(1, self.text_projection.shape[1]))

    def build_attention_mask(self):
        """Create the causal attention mask.

        PyTorch uses an additive attention mask, so blocked positions are
        filled with ``-inf`` and allowed positions with ``0``.
        """
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero elsewhere
        return mask

    def _freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def init_weights(self):
        if self.cat_bg:
            nn.init.normal_(
                self.bg_embed,
                std=self.bg_embed.shape[1]**-0.5,
            )
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
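            # Illustrative only: such an init_cfg might look like
            #   init_cfg = dict(type='Pretrained_Part',
            #                   checkpoint='path/to/full_model.pth')
            # where the checkpoint path is a placeholder. Keys prefixed with
            # 'text_encoder.' are extracted and remapped into this module
            # below.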
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            state_dict = checkpoint.copy()
            para_prefix = 'text_encoder'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                state_dict.pop(k)
                if para_prefix in k:
                    # Strip the 'text_encoder.' prefix so the key matches
                    # this module's own parameter names.
                    state_dict[k[prefix_len:]] = v

            load_state_dict(self, state_dict, strict=False, logger=None)
        else:
            super().init_weights()

    def encode_text(self, text, normalize=False):
        """Encode tokenized class prompts.

        Args:
            text (torch.Tensor): Token ids of shape [batch_size, n_ctx].
            normalize (bool): Whether to L2-normalize the output embeddings.
                Default: False.

        Returns:
            torch.Tensor: Text embeddings of shape [batch_size, output_dims].
        """
        embed_device = self.token_embedding.weight.device
        x = self.token_embedding(
            text.to(embed_device))  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(query=x, attn_masks=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]),
              text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def template_encode(self, vocabulary):
        """Prompt engineering."""
        text_embed_bucket = []
        for template in self.templates:
            text_inputs = tokenizer.tokenize(
                [template.format(noun) for noun in vocabulary])
            text_embed = self.encode_text(text_inputs, normalize=True)
            text_embed_bucket.append(text_embed)
        text_embed = torch.stack(text_embed_bucket).mean(dim=0)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        return text_embed

    def forward(self):
        """Forward function."""
        if self.dataset_name is None:  # encoding vocabulary directly
            class_names = self.vocabulary
            if self.cache_feature:
                new_classes = [
                    word for word in class_names if word not in self.cache
                ]
                if len(new_classes) > 0:
                    class_embeds = self.template_encode(new_classes)
                    self.cache.update(dict(zip(new_classes, class_embeds)))
                class_embeds = torch.stack(
                    [self.cache[word] for word in class_names])
            else:
                class_embeds = self.template_encode(class_names)

        else:  # encoding the classes of the dataset
            class_names = get_classes(self.dataset_name)
            if class_names[0] == 'background':
                class_names = class_names[1:]
            if self.cache_feature:
                if self.dataset_name not in self.cache:
                    class_embeds = self.template_encode(class_names)
                    self.cache[self.dataset_name] = class_embeds
                else:
                    class_embeds = self.cache[self.dataset_name]
            else:
                class_embeds = self.template_encode(class_names)

        if self.cat_bg:
            # Append the learnable background embedding as an extra class and
            # re-normalize, since bg_embed is not unit-length by construction.
            class_embeds = torch.cat([class_embeds, self.bg_embed])
            class_embeds = F.normalize(class_embeds, p=2, dim=-1)
        return self.logit_scale.exp() * class_embeds
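

# `QuickGELU` is referenced above through `act_cfg=dict(type='QuickGELU')`, so
# it is defined and registered here to make the FFN activation buildable from
# the registry.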
@MODELS.register_module()
class QuickGELU(nn.Module):
    # From https://github.com/openai/CLIP/blob/main/clip/model.py
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
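

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes an
# environment with mmsegmentation installed; the dataset name below is only an
# example, and all CLIP weights are randomly initialised unless `init_cfg`
# points to a real checkpoint.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    text_encoder = CLIPTextEncoder(
        dataset_name='cityscapes',  # or pass `vocabulary=['road', 'car', ...]`
        templates='vild',
        cat_bg=True)
    text_encoder.init_weights()
    with torch.no_grad():
        class_embeds = text_encoder()  # (num_classes + 1, output_dims)
    print(class_embeds.shape)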