Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from mmseg.registry import MODELS | |
| from ..utils import resize | |
| from .decode_head import BaseDecodeHead | |
| from .psp_head import PPM | |
class UPerHead(BaseDecodeHead):
    """Decode head of UPerNet (Unified Perceptual Parsing for Scene
    Understanding).

    Reference: `UPerNet <https://arxiv.org/abs/1807.10221>`_.

    The last selected input feature goes through a pooling pyramid module
    followed by a 3x3 bottleneck conv; every other input feature goes
    through a 1x1 lateral conv. The levels are merged top-down, refined by
    3x3 convs, resized to the spatial size of the first level, concatenated
    and fused by a final 3x3 conv.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        **kwargs: Forwarded unchanged to the parent class constructor
            (together with ``input_transform='multiple_select'``).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        # Conv/norm/activation configs shared by every module built below.
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        top_channels = self.in_channels[-1]

        # PSP Module: pyramid pooling on the last feature, then a 3x3 conv
        # that squeezes [feature + pooled branches] down to self.channels.
        self.psp_modules = PPM(
            pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
            **layer_cfgs)
        self.bottleneck = ConvModule(
            top_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfgs)

        # FPN Module: one (lateral 1x1, output 3x3) conv pair per level.
        # The pair is created level by level, lateral first, so the order in
        # which submodules are constructed stays the same as before.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:  # skip the top layer
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    inplace=False,
                    **layer_cfgs))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    inplace=False,
                    **layer_cfgs))

        # Fuses the channel-wise concatenation of all levels.
        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfgs)

    def psp_forward(self, inputs):
        """Forward function of PSP module.

        Args:
            inputs (list[Tensor]): Multi-level features; only the last one
                is used.

        Returns:
            Tensor: Output of ``self.bottleneck`` applied to the last
                feature concatenated (along dim 1) with the outputs of
                ``self.psp_modules``.
        """
        top_feat = inputs[-1]
        pooled_feats = self.psp_modules(top_feat)
        stacked = torch.cat([top_feat, *pooled_feats], dim=1)
        return self.bottleneck(stacked)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        inputs = self._transform_inputs(inputs)

        # Lateral branches for every level but the last; the last level is
        # provided by the PSP branch.
        laterals = []
        for level, lateral_conv in enumerate(self.lateral_convs):
            laterals.append(lateral_conv(inputs[level]))
        laterals.append(self.psp_forward(inputs))
        num_levels = len(laterals)

        # Top-down path: walk from the last level towards the first, adding
        # each level (resized to its neighbour's size) onto the one below.
        for upper in range(num_levels - 1, 0, -1):
            lower = upper - 1
            upsampled = resize(
                laterals[upper],
                size=laterals[lower].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            laterals[lower] = laterals[lower] + upsampled

        # Refine every level except the last; the PSP feature is appended
        # as-is.
        fpn_outs = [
            *(self.fpn_convs[level](laterals[level])
              for level in range(num_levels - 1)),
            laterals[-1],
        ]

        # Bring every level to the spatial size of the first one, then fuse.
        target_size = fpn_outs[0].shape[2:]
        aligned = [fpn_outs[0]]
        aligned.extend(
            resize(
                level_out,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for level_out in fpn_outs[1:])
        return self.fpn_bottleneck(torch.cat(aligned, dim=1))

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))