Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from mmengine.utils import is_tuple_of | |
| from mmseg.registry import MODELS | |
| from ..utils import resize | |
| from .decode_head import BaseDecodeHead | |
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    The deepest input feature map is projected by a 1x1 conv and gated by a
    pooled, sigmoid-activated attention branch. The result is then fused
    with the shallower feature maps one at a time, from deep to shallow.

    Args:
        branch_channels (tuple[int]): The number of output channels of
            each branch. Must hold exactly one entry fewer than
            ``in_channels``. Default: (32, 64).
        **kwargs: Passed through unchanged to ``BaseDecodeHead.__init__``.

    Raises:
        ValueError: If ``input_transform`` is not ``'multiple_select'``.
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        # Every input except the last (deepest) one gets its own branch.
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        # convs[i]    : 1x1 projection of input i to branch_channels[i].
        # conv_ups[i] : 1x1 fusion of the running feature (self.channels)
        #               concatenated with the projected input i.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_ch in enumerate(branch_channels):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(self.in_channels[i], branch_ch, 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_ch,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # NOTE(review): the pooling window is hard-coded, so the deepest
        # feature map must be at least 49x49 -- presumably tuned for one
        # specific input resolution; confirm before using other sizes.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                # image_pool consumes the same tensor as aspp_conv (the last
                # input), so it must use the last entry of in_channels. A
                # fixed index of 2 is only right when there are exactly
                # three inputs.
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Multi-level features; they are first passed through
                ``self._transform_inputs`` and afterwards indexed per level.

        Returns:
            The result of ``self.cls_seg`` applied to the fused feature map.
        """
        inputs = self._transform_inputs(inputs)
        deepest = inputs[-1]

        # Gate the projected deepest feature with the pooled attention map,
        # resized back to the spatial size of the deepest feature.
        attention = resize(
            self.image_pool(deepest),
            size=deepest.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.aspp_conv(deepest) * attention
        x = self.conv_up_input(x)

        # Fuse with the shallower inputs, walking from deep to shallow.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)