# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class JPU(BaseModule):
    """FastFCN: Rethinking Dilated Convolution in the Backbone
    for Semantic Segmentation.

    This Joint Pyramid Upsampling (JPU) neck is the implementation of
    `FastFCN <https://arxiv.org/abs/1903.11816>`_.

    Args:
        in_channels (Tuple[int], optional): The number of input channels
            for each convolution operation before upsampling.
            Default: (512, 1024, 2048).
        mid_channels (int): The number of output channels of JPU.
            Default: 512.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        dilations (tuple[int]): Dilation rates of the Depthwise
            Separable ConvModules. Default: (1, 2, 4, 8).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
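
    Example:
        >>> # A minimal shape sketch; the input sizes are illustrative and
        >>> # assume a ResNet-style backbone with strides 8, 16 and 32.
        >>> import torch
        >>> inputs = (torch.randn(1, 512, 64, 64),
        ...           torch.randn(1, 1024, 32, 32),
        ...           torch.randn(1, 2048, 16, 16))
        >>> jpu = JPU()
        >>> outs = jpu(inputs)
        >>> outs[-1].shape  # len(dilations) * mid_channels = 2048 channels
        torch.Size([1, 2048, 64, 64])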
| """ | |
    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)

        self.dilations = dilations
        self.align_corners = align_corners

        # 3x3 convs that project each selected backbone level to mid_channels.
        self.conv_layers = nn.ModuleList()
        # Parallel dilated depthwise-separable convs applied to the
        # concatenated, upsampled features.
        self.dilation_layers = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            conv_layer = nn.Sequential(
                ConvModule(
                    self.in_channels[i],
                    self.mid_channels,
                    kernel_size=3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.conv_layers.append(conv_layer)
        for i in range(len(dilations)):
            dilation_layer = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=(self.backbone_end_level - self.start_level) *
                    self.mid_channels,
                    out_channels=self.mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=dilations[i],
                    dilation=dilations[i],
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=act_cfg))
            self.dilation_layers.append(dilation_layer)

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels), 'Length of inputs must \
be the same as self.in_channels!'

        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]

        # Upsample all levels to the resolution of the lowest selected level.
        h, w = feats[0].shape[2:]
        for i in range(1, len(feats)):
            feats[i] = resize(
                feats[i],
                size=(h, w),
                mode='bilinear',
                align_corners=self.align_corners)

        feat = torch.cat(feats, dim=1)
        concat_feat = torch.cat([
            self.dilation_layers[i](feat) for i in range(len(self.dilations))
        ],
                                dim=1)

        outs = []

        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
        # the feature map from backbone for auxiliary head. Additionally,
        # outs[0] can also be used for auxiliary head.
        for i in range(self.start_level, self.backbone_end_level - 1):
            outs.append(inputs[i])
        outs.append(concat_feat)
        return tuple(outs)
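
# Usage sketch (illustrative; the config values below are assumptions, not
# part of this module): with the registry decorator above, the neck is
# typically selected from an mmseg config dict. `in_channels` must match the
# channels of the backbone stages fed to the neck; the values below assume
# ResNet-50 stages C3-C5.
#
# neck = dict(
#     type='JPU',
#     in_channels=(512, 1024, 2048),
#     mid_channels=512,
#     start_level=0,
#     end_level=-1,
#     dilations=(1, 2, 4, 8),
#     align_corners=False,
#     norm_cfg=dict(type='BN'))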