Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Tuple, Union | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer | |
| from torch import Tensor | |
| from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
| from mmseg.models.losses import accuracy | |
| from mmseg.models.utils import resize | |
| from mmseg.registry import MODELS | |
| from mmseg.utils import OptConfigType, SampleList | |
# NOTE(review): `MODELS` is imported at the top of this file, but no
# `@MODELS.register_module()` decorator is visible on this class. Confirm
# whether registration happens elsewhere before adding one (registering the
# same name twice would be an import-time side effect).
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Two branches are built: a main ("context") branch applied to
    ``c5_feat`` and an auxiliary ("spatial") branch applied to ``c3_feat``.
    The auxiliary branch is only evaluated in training mode.

    Args:
        in_channels (int): Number of input channels of the main head. The
            auxiliary head is built for ``in_channels // 2`` channels.
        channels (int): Number of channels produced by each head before
            its classifier.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        # Main branch; its classifier (`self.cls_seg`) is expected to be
        # provided by BaseDecodeHead, which is not shown in this file.
        self.head = self._make_base_head(self.in_channels, self.channels)

        # Auxiliary branch: consumes a feature map with half as many
        # channels and has its own 1x1 classifier.
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        """Initialize the weights of every conv and batch-norm submodule.

        ``nn.Conv2d`` weights get Kaiming-normal initialization
        (``fan_out``, ReLU gain); ``nn.BatchNorm2d`` layers get weight 1
        and bias 0. Other module types are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Compute segmentation logits.

        Args:
            inputs (Tensor | Tuple[Tensor]): In training mode a pair
                ``(c3_feat, c5_feat)``; otherwise a single feature that is
                passed straight to the main head.

        Returns:
            Tensor | Tuple[Tensor]: In training mode the pair
            ``(context_logit, spatial_logit)``; otherwise only the logits
            of the main branch.
        """
        if not self.training:
            # Inference uses the main branch only.
            return self.cls_seg(self.head(inputs))

        c3_feat, c5_feat = inputs
        context_logit = self.cls_seg(self.head(c5_feat))
        spatial_logit = self.aux_cls_seg(self.aux_head(c3_feat))
        return context_logit, spatial_logit

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        """Build the feature block shared by the main and auxiliary heads.

        The block is a 3x3 ``ConvModule`` configured with the ordering
        norm -> act -> conv, followed by a further normalization layer and
        activation built from the same configs.

        Args:
            in_channels (int): Number of input channels of the block.
            channels (int): Number of output channels of the block.

        Returns:
            nn.Sequential: The assembled block.
        """
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            # Only element [1] of the builder's return value is the layer.
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]
        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        """Compute the losses of both branches.

        Both logits are bilinearly resized to the spatial size of the
        stacked ground truth before the losses are evaluated.

        Args:
            seg_logits (Tuple[Tensor]): Pair
                ``(context_logit, spatial_logit)`` as returned by
                :meth:`forward` in training mode.
            batch_data_samples (SampleList): Data samples from which the
                ground-truth labels are stacked.

        Returns:
            dict: ``loss_context`` (from ``self.loss_decode[0]``),
            ``loss_spatial`` (from ``self.loss_decode[1]``) and
            ``acc_seg`` (accuracy of the context logits).
        """
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)
        target_size = seg_label.shape[2:]

        context_logit = resize(
            context_logit,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)
        # Drop dim 1 of the stacked labels before computing the losses.
        seg_label = seg_label.squeeze(1)

        # `self.loss_decode` must be indexable with at least two entries:
        # index 0 supervises the context branch, index 1 the spatial one.
        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        # Accuracy is reported for the main (context) branch only.
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)
        return loss