# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)

from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
    """Segmenter: Transformer for Semantic Segmentation.

    This head is the implementation of the mask transformer decoder in
    `Segmenter <https://arxiv.org/abs/2105.05633>`_.

    Args:
        in_channels (int): The number of channels of the input feature map.
        num_layers (int): The depth of the transformer.
        num_heads (int): The number of attention heads.
        embed_dims (int): The number of embedding dimensions.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_std (float): The value of std in weight initialization.
            Default: 0.02.
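
    Example:
        >>> # A minimal usage sketch with assumed Segmenter-Tiny-like sizes;
        >>> # `channels` and `num_classes` are consumed by BaseDecodeHead.
        >>> import torch
        >>> head = SegmenterMaskTransformerHead(
        ...     in_channels=192,
        ...     channels=192,
        ...     num_classes=150,
        ...     num_layers=2,
        ...     num_heads=3,
        ...     embed_dims=192)
        >>> inputs = [torch.randn(2, 192, 32, 32)]
        >>> head(inputs).shape
        torch.Size([2, 150, 32, 32])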
| """ | |

    def __init__(
            self,
            in_channels,
            num_layers,
            num_heads,
            embed_dims,
            mlp_ratio=4,
            drop_path_rate=0.1,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            num_fcs=2,
            qkv_bias=True,
            act_cfg=dict(type='GELU'),
            norm_cfg=dict(type='LN'),
            init_std=0.02,
            **kwargs,
    ):
        super().__init__(in_channels=in_channels, **kwargs)

        # Stochastic depth: drop-path rates grow linearly from 0 to
        # drop_path_rate over the transformer layers.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=mlp_ratio * embed_dims,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    batch_first=True,
                ))

        # Project backbone features to the decoder embedding width and
        # create one learnable embedding per class.
        self.dec_proj = nn.Linear(in_channels, embed_dims)
        self.cls_emb = nn.Parameter(
            torch.randn(1, self.num_classes, embed_dims))
        self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)

        self.decoder_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)[1]
        self.mask_norm = build_norm_layer(
            norm_cfg, self.num_classes, postfix=2)[1]

        self.init_std = init_std

        # Masks come from patch/class-embedding correlation, so the default
        # conv_seg classifier from BaseDecodeHead is unused; drop it.
        delattr(self, 'conv_seg')

    def init_weights(self):
        trunc_normal_(self.cls_emb, std=self.init_std)
        trunc_normal_init(self.patch_proj, std=self.init_std)
        trunc_normal_init(self.classes_proj, std=self.init_std)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=self.init_std, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        b, c, h, w = x.shape
        # Flatten the feature map into a token sequence:
        # (B, C, H, W) -> (B, H*W, C).
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)

        # Project patches to the decoder width, append one learnable class
        # embedding per class, then run joint self-attention over both.
        x = self.dec_proj(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)

        # Split the sequence back into patch tokens and class tokens.
        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        # L2-normalize both so the dot products below are cosine
        # similarities between patches and class embeddings.
        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        # (B, H*W, num_classes) class scores per patch, reshaped to
        # (B, num_classes, H, W) segmentation logits.
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)

        return masks
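

# ---------------------------------------------------------------------------
# Config-style usage sketch (not part of the original file). The values below
# are assumptions mirroring a Segmenter-Tiny-like setup; because the head is
# registered in MODELS, it can be built from a standard MMSegmentation config:
#
#   decode_head=dict(
#       type='SegmenterMaskTransformerHead',
#       in_channels=192,
#       channels=192,
#       num_classes=150,
#       num_layers=2,
#       num_heads=3,
#       embed_dims=192,
#       dropout_ratio=0.0,
#       loss_decode=dict(
#           type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))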