Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule | |
| from mmseg.registry import MODELS | |
| from ..utils import resize | |
| from .aspp_head import ASPPHead, ASPPModule | |
class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module whose dilated branches use depthwise separable convs.

    The parent ``ASPPModule`` is constructed first from ``kwargs``. Every
    branch whose dilation rate is greater than 1 is then replaced by a 3x3
    ``DepthwiseSeparableConvModule`` that keeps that branch's dilation.
    Branches with a rate of 1 or less are left exactly as the parent built
    them.

    NOTE(review): relies on ``ASPPModule`` supporting item assignment
    (``self[i] = ...``), e.g. by being an ``nn.ModuleList`` — confirm.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Select only the atrous (dilated) branches; rate-1 branches stay.
        atrous_branches = [(index, rate)
                           for index, rate in enumerate(self.dilations)
                           if rate > 1]
        for index, rate in atrous_branches:
            # padding == dilation makes a 3x3 conv with the default stride
            # preserve the spatial size of its input.
            self[index] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
@MODELS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of the c1 decoder. If it is
            0, no decoder will be used. Must be non-negative.
        c1_channels (int): The intermediate channels of the c1 decoder. Only
            takes effect when ``c1_in_channels > 0``.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super().__init__(**kwargs)
        assert c1_in_channels >= 0
        # Replace the ASPP branches built by ``ASPPHead`` with depthwise
        # separable ones, reusing the same channel/norm/act configuration.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            # 1x1 projection of ``inputs[0]`` (see ``forward``).
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            # forward() concatenates the c1 features onto the ASPP output.
            sep_in_channels = self.channels + c1_channels
        else:
            self.c1_bottleneck = None
            # Without the c1 decoder nothing is concatenated in forward(),
            # so c1_channels must not widen the first separable conv;
            # otherwise forward() fails with a channel mismatch.
            sep_in_channels = self.channels
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                sep_in_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features. ``_transform_inputs`` selects the
                feature(s) fed to ASPP, and ``inputs[0]`` is fed to the c1
                decoder when it is enabled. (Presumably a list/tuple of NCHW
                feature maps — confirm against the backbone.)

        Returns:
            Tensor: The output of ``cls_seg`` applied to the fused features.
        """
        x = self._transform_inputs(inputs)
        # Image-level pooling branch, resized back to x's spatial size so it
        # can be concatenated with the ASPP branches.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            # Match the c1 feature resolution before channel-wise concat.
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output