Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from mmcv.cnn import NonLocal2d | |
| from mmseg.registry import MODELS | |
| from .fcn_head import FCNHead | |
| class NLHead(FCNHead): | |
| """Non-local Neural Networks. | |
| This head is the implementation of `NLNet | |
| <https://arxiv.org/abs/1711.07971>`_. | |
| Args: | |
| reduction (int): Reduction factor of projection transform. Default: 2. | |
| use_scale (bool): Whether to scale pairwise_weight by | |
| sqrt(1/inter_channels). Default: True. | |
| mode (str): The nonlocal mode. Options are 'embedded_gaussian', | |
| 'dot_product'. Default: 'embedded_gaussian.'. | |
| """ | |
| def __init__(self, | |
| reduction=2, | |
| use_scale=True, | |
| mode='embedded_gaussian', | |
| **kwargs): | |
| super().__init__(num_convs=2, **kwargs) | |
| self.reduction = reduction | |
| self.use_scale = use_scale | |
| self.mode = mode | |
| self.nl_block = NonLocal2d( | |
| in_channels=self.channels, | |
| reduction=self.reduction, | |
| use_scale=self.use_scale, | |
| conv_cfg=self.conv_cfg, | |
| norm_cfg=self.norm_cfg, | |
| mode=self.mode) | |
| def forward(self, inputs): | |
| """Forward function.""" | |
| x = self._transform_inputs(inputs) | |
| output = self.convs[0](x) | |
| output = self.nl_block(output) | |
| output = self.convs[1](output) | |
| if self.concat_input: | |
| output = self.conv_cat(torch.cat([x, output], dim=1)) | |
| output = self.cls_seg(output) | |
| return output | |