# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock: process the cls_token in the ViT backbone output
    and rearrange the feature vectors into feature maps.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): Output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)
        assert readout_type in ['ignore', 'add', 'project']
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList([
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                act_cfg=None,
            ) for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])
        if self.readout_type == 'project':
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        Linear(2 * in_channels, in_channels),
                        build_activation_layer(dict(type='GELU'))))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == 'project':
                # Concatenate the cls_token to every patch token, then fuse
                # the pair back to the original embedding dimension.
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == 'add':
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                # 'ignore' readout: drop the cls_token entirely.
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out
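

# A minimal smoke-test sketch of ReassembleBlocks, assuming a ViT-Base
# backbone (embed dim 768, patch size 16) on a 224x224 input, i.e. a 14x14
# token grid per stage. The function name and dummy shapes are illustrative
# only, not part of any mmseg API.
def _demo_reassemble_blocks():
    block = ReassembleBlocks(
        in_channels=768,
        out_channels=[96, 192, 384, 768],
        readout_type='project',
        patch_size=16)
    # Four stages, each a (feature_map, cls_token) pair at 1/16 resolution.
    feats = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768))
             for _ in range(4)]
    outs = block(feats)
    # Expected shapes after the resize layers (strides 4, 8, 16, 32):
    # [1, 96, 56, 56], [1, 192, 28, 28], [1, 384, 14, 14], [1, 768, 7, 7].
    for out in outs:
        print(out.shape)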


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, a pre-activated residual unit.

    Args:
        in_channels (int): Number of channels in the input feature map.
        act_cfg (dict): Dictionary to construct and config the activation
            layer.
        norm_cfg (dict): Dictionary to construct and config the norm layer.
        stride (int): Stride of the first block. Default: 1.
        dilation (int): Dilation rate for the conv layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))
        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_
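

# A minimal sketch showing that the pre-activated unit is shape-preserving
# with the default stride/dilation, so the residual addition is always
# valid. Channel count and input size are arbitrary illustrative values.
def _demo_pre_act_residual_conv_unit():
    unit = PreActResidualConvUnit(
        in_channels=64, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN'))
    x = torch.randn(1, 64, 32, 32)
    y = unit(x)
    assert y.shape == x.shape  # the residual unit keeps the spatial size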


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merges feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for the normalization layer.
        expand (bool): Whether to expand the channels in the post process
            block (when True, the projection halves the channels).
            Default: False.
        align_corners (bool): align_corners setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(
                    inputs[1],
                    size=(x.shape[2], x.shape[3]),
                    mode='bilinear',
                    align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.project(x)
        return x
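

# A minimal sketch of the two call patterns DPTHead uses: one input (no
# skip) and two inputs (the second is fused through res_conv_unit1). Either
# way the output is upsampled 2x by the trailing bilinear resize. Shapes
# below are illustrative assumptions.
def _demo_feature_fusion_block():
    fusion = FeatureFusionBlock(
        in_channels=64, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN'))
    coarse = torch.randn(1, 64, 8, 8)
    skip = torch.randn(1, 64, 8, 8)
    out1 = fusion(coarse)        # [1, 64, 16, 16]
    out2 = fusion(coarse, skip)  # [1, 64, 16, 16]
    print(out1.shape, out2.shape)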


@MODELS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is the implementation of `DPT
    <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the post
            process block. Default: False.
        act_cfg (dict): The activation config for the residual conv units.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims,
                                                  post_process_channels,
                                                  readout_type, patch_size)

        # Cast to int: math.pow returns a float, which nn.Conv2d rejects
        # as a channel count.
        self.post_process_channels = [
            int(channel * math.pow(2, i)) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(
                ConvModule(
                    channel,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=None,
                    bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(
                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
        # The deepest stage has no skip connection, so its first residual
        # unit is never used.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        # Fuse from the deepest (lowest resolution) stage to the shallowest,
        # upsampling 2x at every fusion block.
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.cls_seg(out)
        return out
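

# A minimal end-to-end sketch, assuming constructor arguments in the style
# of the mmsegmentation DPT configs and a ViT-Base backbone on a 224x224
# image (14x14 patch grid), with each stage returning a
# (feature_map, cls_token) pair. Values such as num_classes=150 are
# illustrative assumptions.
def _demo_dpt_head():
    head = DPTHead(
        in_channels=(768, 768, 768, 768),
        channels=256,
        num_classes=150,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type='project',
        input_transform='multiple_select',
        in_index=(0, 1, 2, 3))
    inputs = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768))
              for _ in range(4)]
    logits = head(inputs)
    # Reassemble yields strides 4/8/16/32; four fusion blocks each upsample
    # 2x, so the logits come out at 1/2 input resolution: [1, 150, 112, 112].
    print(logits.shape)


if __name__ == '__main__':
    _demo_reassemble_blocks()
    _demo_pre_act_residual_conv_unit()
    _demo_feature_fusion_block()
    _demo_dpt_head()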