Spaces:
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from mmseg.registry import MODELS | |
| from .fcn_head import FCNHead | |
| try: | |
| from mmcv.ops import CrissCrossAttention | |
| except ModuleNotFoundError: | |
| CrissCrossAttention = None | |
# Register the head so it can be built by name from the MODELS registry;
# without this, the imported MODELS is unused and config lookups fail.
@MODELS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_. It is an ``FCNHead`` built with
    exactly two convolutions, with a Criss-Cross Attention module applied
    (repeatedly) between the first and the second convolution.

    Args:
        recurrence (int): Number of times the Criss-Cross Attention module
            is applied. A single module instance is reused on every pass,
            so all passes share weights. Default: 2.
        **kwargs: Forwarded to ``FCNHead``. Do not pass ``num_convs``:
            this head always supplies ``num_convs=2`` itself, and a second
            value would raise ``TypeError``.

    Raises:
        RuntimeError: If ``mmcv.ops.CrissCrossAttention`` could not be
            imported (it is only available in mmcv-full with compiled ops).
    """

    def __init__(self, recurrence=2, **kwargs):
        # Fail early with an actionable message instead of a confusing
        # "'NoneType' object is not callable" further down.
        if CrissCrossAttention is None:
            raise RuntimeError('Please install mmcv-full for '
                               'CrissCrossAttention ops')
        # Two convs: convs[0] runs before attention, convs[1] after it.
        super().__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        # Attention operates on the head's intermediate width
        # (``self.channels``, set up by the parent class).
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Features passed to the decode head; presumably a list
                of multi-level feature maps, from which
                ``self._transform_inputs`` selects what this head consumes.

        Returns:
            The output of ``self.cls_seg`` (presumably per-class
            segmentation logits).
        """
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        # Re-apply the same attention module: each pass lets every position
        # aggregate context along its row and column, and repeating it
        # propagates context further (see the CCNet paper).
        for _ in range(self.recurrence):
            output = self.cca(output)
        output = self.convs[1](output)
        if self.concat_input:
            # Fuse the original input with the attended features; dim=1 is
            # assumed to be the channel axis (NCHW) — confirm upstream.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output