Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from mmseg.registry import MODELS | |
| from ..utils import resize | |
| from .decode_head import BaseDecodeHead | |
class PPM(nn.ModuleList):
    """Pooling Pyramid Module (PPM) used in PSPNet.

    One branch is created per entry of ``pool_scales``. Each branch is an
    ``nn.AdaptiveAvgPool2d(scale)`` followed by a 1x1 ``ConvModule`` mapping
    ``in_channels`` to ``channels``. :meth:`forward` runs every branch and
    resizes each result back to the spatial size of the input.

    Args:
        pool_scales (tuple[int]): Output sizes for ``nn.AdaptiveAvgPool2d``;
            one branch is built per entry, in order.
        in_channels (int): Input channels.
        channels (int): Output channels of every branch, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): ``align_corners`` argument forwarded to
            ``resize`` (in the spirit of ``F.interpolate``).
        **kwargs: Extra keyword arguments forwarded to every ``ConvModule``.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        # nn.ModuleList must be initialised before any attribute is set.
        super().__init__()

        # Keep the configuration around as plain attributes.
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # Branches are created in pool_scales order, so their ModuleList
        # indices match the order of the scales.
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs)) for scale in pool_scales
        ]
        self.extend(branches)

    def forward(self, x):
        """Apply every pyramid branch to ``x`` and upsample the results.

        Args:
            x (Tensor): Input feature map. ``x.size()[2:]`` is used as the
                resize target (presumably ``(N, C, H, W)`` — TODO confirm).

        Returns:
            list[Tensor]: One bilinearly resized output per branch, in branch
            order.
        """
        target_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners) for branch in self
        ]
# Register the head in the ``MODELS`` registry (imported at the top of this
# file, otherwise unused) so it can be built from a config by its type name.
@MODELS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network decode head.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_. The transformed input is
    passed through a :class:`PPM`, the PPM outputs are concatenated with the
    input along dim 1, fused by a 3x3 ``bottleneck`` ``ConvModule`` and
    finally classified by ``cls_seg``.

    Args:
        pool_scales (tuple[int] | list[int]): Pooling scales used in Pooling
            Pyramid Module. Default: (1, 2, 3, 6).
        **kwargs: Forwarded unchanged to :class:`BaseDecodeHead`.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        # in_channels / channels / *_cfg / align_corners are presumably set
        # by BaseDecodeHead.__init__ — not visible here.
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # The bottleneck sees the input (in_channels) concatenated with one
        # ``channels``-wide map per pooling scale.
        self.bottleneck = ConvModule(
            self.in_channels + len(self.pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        # _transform_inputs is presumably provided by BaseDecodeHead and
        # selects/merges the levels this head consumes.
        x = self._transform_inputs(inputs)
        # Keep the un-pooled input as the first slice of the concatenation.
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        feats = self.bottleneck(psp_outs)
        return feats

    def forward(self, inputs):
        """Compute features with :meth:`_forward_feature`, then apply
        ``self.cls_seg`` to them.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Output of ``self.cls_seg``.
        """
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output