Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from mmengine.model import BaseModule | |
| from torch import Tensor | |
| from mmseg.registry import MODELS | |
| from mmseg.utils import OptConfigType | |
class BasicBlock(BaseModule):
    """Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    The main path is two 3x3 ``ConvModule`` layers. The second has no
    activation. Its output is added to the identity (or downsampled)
    shortcut and then optionally passed through an output activation.
    Because ``expansion`` is 1, the block outputs ``channels`` channels.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first 3x3 convolution. Default: 1.
        downsample (nn.Module, optional): Downsample operation applied to the
            identity branch so that it matches the main-path output for the
            residual addition. It is needed whenever ``stride != 1`` or
            ``in_channels != channels``. Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at the
            last of the block. If falsy (e.g. None), no output activation is
            applied. Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        # NOTE(review): the dict defaults above are shared by every instance.
        # This block only forwards them. That is safe as long as ConvModule
        # and MODELS.build copy their config arguments instead of mutating
        # them — confirm.
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # No activation here: the activation (if any) is applied only after
        # the residual addition in forward().
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.downsample = downsample
        # `self.act` is only registered when an output activation is
        # requested. forward() checks for its presence with hasattr().
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)

    def forward(self, x: Tensor) -> Tensor:
        """Run the residual block.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``conv2(conv1(x)) + shortcut(x)``, passed through the
            output activation when one was configured.
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        # Compare against None explicitly. Truthiness of an nn.Module is not
        # a reliable presence test (e.g. nn.Sequential defines __len__).
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        if hasattr(self, 'act'):
            out = self.act(out)
        return out
class Bottleneck(BaseModule):
    """Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    The main path is a 1x1 -> 3x3 -> 1x1 stack of ``ConvModule`` layers. The
    last 1x1 conv expands the width to ``channels * expansion`` and has no
    activation. Its output is added to the identity (or downsampled)
    shortcut and then optionally passed through an output activation.

    Args:
        in_channels (int): Input channels.
        channels (int): Width of the inner (bottleneck) convolutions. The
            block outputs ``channels * expansion`` (i.e. ``channels * 2``)
            channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module, optional): Downsample operation applied to the
            identity branch so that it matches the main-path output for the
            residual addition. It is needed whenever ``stride != 1`` or
            ``in_channels != channels * expansion``. Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at
            the last of the block. If falsy (e.g. None), no output
            activation is applied. Default: None.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 2

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        # NOTE(review): the dict defaults above are shared by every instance.
        # This block only forwards them. That is safe as long as ConvModule
        # and MODELS.build copy their config arguments instead of mutating
        # them — confirm.
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Expand to the block's output width. There is no activation before
        # the residual addition.
        self.conv3 = ConvModule(
            channels,
            channels * self.expansion,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        # `self.act` is only registered when an output activation is
        # requested. forward() checks for its presence with hasattr().
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Run the bottleneck residual block.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``conv3(conv2(conv1(x))) + shortcut(x)``, passed through
            the output activation when one was configured.
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        # Compare against None explicitly. Truthiness of an nn.Module is not
        # a reliable presence test (e.g. nn.Sequential defines __len__).
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        if hasattr(self, 'act'):
            out = self.act(out)
        return out