Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer | |
| from mmengine.model import BaseModule | |
| from torch import Tensor | |
| from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
| from mmseg.models.losses import accuracy | |
| from mmseg.models.utils import resize | |
| from mmseg.registry import MODELS | |
| from mmseg.utils import OptConfigType, SampleList | |
class BasePIDHead(BaseModule):
    """Building block shared by the branch heads of ``PIDHead``.

    The input passes through a 3x3 ``ConvModule`` configured with the
    ``('norm', 'act', 'conv')`` order, then through a second normalization
    layer and activation built for ``channels`` features, and finally
    through an optional classifier handed in by the caller.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of channels produced by the conv layer.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict or list[dict], optional): Init config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        # Settings of the pre-activation 3x3 convolution block.
        conv_kwargs = dict(
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'))
        self.conv = ConvModule(in_channels, channels, **conv_kwargs)

        # ``build_norm_layer`` yields a (name, layer) pair; only the layer
        # is kept. It normalizes the conv output, hence ``channels``.
        self.norm = build_norm_layer(norm_cfg, num_features=channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
        """Run the head and, if given, the classifier on top of it.

        Args:
            x (Tensor): Input tensor.
            cls_seg (nn.Module, optional): Classification layer applied to
                the activated features. Skipped when ``None``.

        Returns:
            Tensor: Classifier output, or the activated features when no
            classifier is supplied.
        """
        feat = self.act(self.norm(self.conv(x)))
        if cls_seg is None:
            return feat
        return cls_seg(feat)
# NOTE(review): ``MODELS`` is imported at the top of the file, but no
# ``@MODELS.register_module()`` decorator is visible on this class. Confirm
# whether registration was intended.
class PIDHead(BaseDecodeHead):
    """Decode head for PIDNet.

    Three ``BasePIDHead`` branches are built: ``i_head`` on ``in_channels``
    inputs, and ``p_head`` / ``d_head`` on ``in_channels // 2`` inputs.
    ``p_head`` and ``i_head`` are classified into segmentation logits, while
    ``d_head`` is classified into a single-channel map that is used as the
    boundary prediction in :meth:`loss_by_feat`.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
        self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
        self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
                                  act_cfg)
        # NOTE(review): ``act_cfg`` is not forwarded here, so this branch
        # always uses BasePIDHead's default activation. Confirm this is
        # intentional before changing it.
        self.d_head = BasePIDHead(
            in_channels // 2,
            in_channels // 4,
            norm_cfg,
        )
        self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        # One output channel; ``loss_by_feat`` applies a sigmoid to it.
        self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)

    def init_weights(self):
        """Initialize every ``Conv2d`` and ``BatchNorm2d`` in the head.

        Conv weights get Kaiming-normal init (``fan_out``, ReLU gain);
        batch-norm weights are set to 1 and biases to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                # ``weight``/``bias`` are None when the layer was built with
                # ``affine=False``; initializing None would raise.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self, inputs: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]
    ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        """Forward function.

        Args:
            inputs (Tensor | tuple[Tensor]): In training mode, a tuple of
                three tensors ``(p_feat, i_feat, d_feat)``. Otherwise the
                value is passed to ``i_head`` as-is, so it must be the
                integral-branch feature map.

        Returns:
            Tensor | tuple[Tensor]: In training mode, the tuple
            ``(p_seg_logit, i_seg_logit, d_seg_logit)``. Otherwise only the
            segmentation logit of the integral branch.
        """
        if not self.training:
            return self.i_head(inputs, self.cls_seg)

        x_p, x_i, x_d = inputs
        x_p = self.p_head(x_p, self.p_cls_seg)
        x_i = self.i_head(x_i, self.cls_seg)
        x_d = self.d_head(x_d, self.d_cls_seg)
        return x_p, x_i, x_d

    def _stack_batch_gt(
            self, batch_data_samples: SampleList) -> Tuple[Tensor, Tensor]:
        """Stack per-sample ground truth along a new batch dimension.

        Args:
            batch_data_samples (SampleList): Samples providing
                ``gt_sem_seg.data`` and ``gt_edge_map.data``.

        Returns:
            tuple[Tensor, Tensor]: Stacked semantic labels and stacked edge
            maps, in that order.
        """
        gt_sem_segs = torch.stack(
            [sample.gt_sem_seg.data for sample in batch_data_samples], dim=0)
        gt_edge_segs = torch.stack(
            [sample.gt_edge_map.data for sample in batch_data_samples],
            dim=0)
        return gt_sem_segs, gt_edge_segs

    def _resize_logit(self, logit: Tensor, size) -> Tensor:
        """Bilinearly resize ``logit`` to the spatial ``size`` of a label."""
        return resize(
            input=logit,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)

    def loss_by_feat(self, seg_logits: Tuple[Tensor, Tensor, Tensor],
                     batch_data_samples: SampleList) -> dict:
        """Compute the losses of the three branches.

        ``self.loss_decode`` is indexed at positions 0-3, so four loss
        modules must be configured, in this order: P-branch semantic loss,
        I-branch semantic loss, boundary loss, boundary-aware semantic loss.

        Args:
            seg_logits (tuple[Tensor]): ``(p_logit, i_logit, d_logit)`` as
                returned by :meth:`forward` in training mode.
            batch_data_samples (SampleList): Samples carrying the semantic
                and edge ground truth.

        Returns:
            dict: Entries ``loss_sem_p``, ``loss_sem_i``, ``loss_bd``,
            ``loss_sem_bd`` and ``acc_seg``.
        """
        loss = dict()
        p_logit, i_logit, d_logit = seg_logits
        sem_label, bd_label = self._stack_batch_gt(batch_data_samples)

        # Labels are treated as 4-D here: the target size is read from dims
        # 2 onward and dim 1 is squeezed away below.
        p_logit = self._resize_logit(p_logit, sem_label.shape[2:])
        i_logit = self._resize_logit(i_logit, sem_label.shape[2:])
        d_logit = self._resize_logit(d_logit, bd_label.shape[2:])
        sem_label = sem_label.squeeze(1)
        bd_label = bd_label.squeeze(1)

        loss['loss_sem_p'] = self.loss_decode[0](
            p_logit, sem_label, ignore_index=self.ignore_index)
        # NOTE(review): losses 1 and 3 are called without ``ignore_index``;
        # presumably they carry their own ignore value, which should agree
        # with ``self.ignore_index`` used for the filler below. Confirm
        # against the loss configuration.
        loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
        loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)

        # Keep the semantic label only where the boundary branch's sigmoid
        # output exceeds 0.8; every other pixel is set to the ignore value.
        filler = torch.full_like(sem_label, self.ignore_index)
        sem_bd_label = torch.where(
            torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
        loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)

        loss['acc_seg'] = accuracy(
            i_logit, sem_label, ignore_index=self.ignore_index)
        return loss