# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
                      build_activation_layer, build_norm_layer)
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class DetailBranch(BaseModule):
    """Detail Branch with wide channels and shallow layers to capture low-level
    details and generate high-resolution feature representation.

    Args:
        detail_channels (Tuple[int]): Size of channel numbers of each stage
            in Detail Branch, in paper it has 3 stages.
            Default: (64, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Feature map of Detail Branch.
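
    Example:
        >>> # Illustrative shape check: each of the three stages begins with
        >>> # a stride-2 conv, so a 64x64 input comes out at 1/8 resolution.
        >>> import torch
        >>> branch = DetailBranch(detail_channels=(64, 64, 128), in_channels=3)
        >>> branch(torch.randn(1, 3, 64, 64)).shape
        torch.Size([1, 128, 8, 8])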
| """ | |
| def __init__(self, | |
| detail_channels=(64, 64, 128), | |
| in_channels=3, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| detail_branch = [] | |
| for i in range(len(detail_channels)): | |
| if i == 0: | |
| detail_branch.append( | |
| nn.Sequential( | |
| ConvModule( | |
| in_channels=in_channels, | |
| out_channels=detail_channels[i], | |
| kernel_size=3, | |
| stride=2, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg), | |
| ConvModule( | |
| in_channels=detail_channels[i], | |
| out_channels=detail_channels[i], | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg))) | |
| else: | |
| detail_branch.append( | |
| nn.Sequential( | |
| ConvModule( | |
| in_channels=detail_channels[i - 1], | |
| out_channels=detail_channels[i], | |
| kernel_size=3, | |
| stride=2, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg), | |
| ConvModule( | |
| in_channels=detail_channels[i], | |
| out_channels=detail_channels[i], | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg), | |
| ConvModule( | |
| in_channels=detail_channels[i], | |
| out_channels=detail_channels[i], | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg))) | |
| self.detail_branch = nn.ModuleList(detail_branch) | |
| def forward(self, x): | |
| for stage in self.detail_branch: | |
| x = stage(x) | |
| return x | |


class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
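
    Example:
        >>> # Illustrative shape check: the conv path and the pooling path
        >>> # each downsample twice, so the output is at 1/4 resolution.
        >>> import torch
        >>> stem = StemBlock(in_channels=3, out_channels=16)
        >>> stem(torch.randn(1, 3, 64, 64)).shape
        torch.Size([1, 16, 16, 16])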
| """ | |
| def __init__(self, | |
| in_channels=3, | |
| out_channels=16, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| self.conv_first = ConvModule( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg) | |
| self.convs = nn.Sequential( | |
| ConvModule( | |
| in_channels=out_channels, | |
| out_channels=out_channels // 2, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg), | |
| ConvModule( | |
| in_channels=out_channels // 2, | |
| out_channels=out_channels, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg)) | |
| self.pool = nn.MaxPool2d( | |
| kernel_size=3, stride=2, padding=1, ceil_mode=False) | |
| self.fuse_last = ConvModule( | |
| in_channels=out_channels * 2, | |
| out_channels=out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg) | |
| def forward(self, x): | |
| x = self.conv_first(x) | |
| x_left = self.convs(x) | |
| x_right = self.pool(x) | |
| x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) | |
| return x | |


class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. Default: 1.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
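
    Example:
        >>> # Illustrative shape check: with stride=1 the input itself is the
        >>> # residual, so in_channels must equal out_channels; with stride=2
        >>> # a depthwise separable shortcut matches the output shape.
        >>> import torch
        >>> x = torch.randn(1, 16, 32, 32)
        >>> GELayer(16, 16, exp_ratio=6, stride=1)(x).shape
        torch.Size([1, 16, 32, 32])
        >>> GELayer(16, 32, exp_ratio=6, stride=2)(x).shape
        torch.Size([1, 32, 16, 16])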
| """ | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| exp_ratio=6, | |
| stride=1, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| mid_channel = in_channels * exp_ratio | |
| self.conv1 = ConvModule( | |
| in_channels=in_channels, | |
| out_channels=in_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg) | |
| if stride == 1: | |
| self.dwconv = nn.Sequential( | |
| # ReLU in ConvModule not shown in paper | |
| ConvModule( | |
| in_channels=in_channels, | |
| out_channels=mid_channel, | |
| kernel_size=3, | |
| stride=stride, | |
| padding=1, | |
| groups=in_channels, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg)) | |
| self.shortcut = None | |
| else: | |
| self.dwconv = nn.Sequential( | |
| ConvModule( | |
| in_channels=in_channels, | |
| out_channels=mid_channel, | |
| kernel_size=3, | |
| stride=stride, | |
| padding=1, | |
| groups=in_channels, | |
| bias=False, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=None), | |
| # ReLU in ConvModule not shown in paper | |
| ConvModule( | |
| in_channels=mid_channel, | |
| out_channels=mid_channel, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| groups=mid_channel, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg), | |
| ) | |
| self.shortcut = nn.Sequential( | |
| DepthwiseSeparableConvModule( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| kernel_size=3, | |
| stride=stride, | |
| padding=1, | |
| dw_norm_cfg=norm_cfg, | |
| dw_act_cfg=None, | |
| pw_norm_cfg=norm_cfg, | |
| pw_act_cfg=None, | |
| )) | |
| self.conv2 = nn.Sequential( | |
| ConvModule( | |
| in_channels=mid_channel, | |
| out_channels=out_channels, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| bias=False, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=None, | |
| )) | |
| self.act = build_activation_layer(act_cfg) | |
| def forward(self, x): | |
| identity = x | |
| x = self.conv1(x) | |
| x = self.dwconv(x) | |
| x = self.conv2(x) | |
| if self.shortcut is not None: | |
| shortcut = self.shortcut(identity) | |
| x = x + shortcut | |
| else: | |
| x = x + identity | |
| x = self.act(x) | |
| return x | |


class CEBlock(BaseModule):
    """Context Embedding Block for a large receptive field in Semantic Branch.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
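
    Example:
        >>> # Illustrative shape check: the globally pooled context is
        >>> # broadcast back onto the input, so in_channels must equal
        >>> # out_channels; eval() is used because BatchNorm cannot compute
        >>> # training statistics on the 1x1 pooled map with batch size 1.
        >>> import torch
        >>> ce = CEBlock(in_channels=128, out_channels=128).eval()
        >>> ce(torch.randn(1, 128, 2, 2)).shape
        torch.Size([1, 128, 2, 2])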
| """ | |
| def __init__(self, | |
| in_channels=3, | |
| out_channels=16, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.gap = nn.Sequential( | |
| nn.AdaptiveAvgPool2d((1, 1)), | |
| build_norm_layer(norm_cfg, self.in_channels)[1]) | |
| self.conv_gap = ConvModule( | |
| in_channels=self.in_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg) | |
| # Note: in paper here is naive conv2d, no bn-relu | |
| self.conv_last = ConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg) | |
| def forward(self, x): | |
| identity = x | |
| x = self.gap(x) | |
| x = self.conv_gap(x) | |
| x = identity + x | |
| x = self.conv_last(x) | |
| return x | |


class SemanticBranch(BaseModule):
    """Semantic Branch which is lightweight with narrow channels and deep
    layers to obtain high-level semantic context.

    Args:
        semantic_channels (Tuple[int]): Size of channel numbers of
            various stages in Semantic Branch.
            Default: (16, 32, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        semantic_outs (List[torch.Tensor]): List of several feature maps
            for auxiliary heads (Booster) and Bilateral
            Guided Aggregation Layer.
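
    Example:
        >>> # Illustrative shape check with the default four stages plus the
        >>> # final CEBlock (resolutions 1/4, 1/8, 1/16, 1/32, 1/32); eval()
        >>> # sidesteps BatchNorm over the 1x1 pooled map in CEBlock.
        >>> import torch
        >>> branch = SemanticBranch().eval()
        >>> outs = branch(torch.randn(1, 3, 64, 64))
        >>> [tuple(out.shape) for out in outs]
        [(1, 16, 16, 16), (1, 32, 8, 8), (1, 64, 4, 4), (1, 128, 2, 2), (1, 128, 2, 2)]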
| """ | |
| def __init__(self, | |
| semantic_channels=(16, 32, 64, 128), | |
| in_channels=3, | |
| exp_ratio=6, | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| self.in_channels = in_channels | |
| self.semantic_channels = semantic_channels | |
| self.semantic_stages = [] | |
| for i in range(len(semantic_channels)): | |
| stage_name = f'stage{i + 1}' | |
| self.semantic_stages.append(stage_name) | |
| if i == 0: | |
| self.add_module( | |
| stage_name, | |
| StemBlock(self.in_channels, semantic_channels[i])) | |
| elif i == (len(semantic_channels) - 1): | |
| self.add_module( | |
| stage_name, | |
| nn.Sequential( | |
| GELayer(semantic_channels[i - 1], semantic_channels[i], | |
| exp_ratio, 2), | |
| GELayer(semantic_channels[i], semantic_channels[i], | |
| exp_ratio, 1), | |
| GELayer(semantic_channels[i], semantic_channels[i], | |
| exp_ratio, 1), | |
| GELayer(semantic_channels[i], semantic_channels[i], | |
| exp_ratio, 1))) | |
| else: | |
| self.add_module( | |
| stage_name, | |
| nn.Sequential( | |
| GELayer(semantic_channels[i - 1], semantic_channels[i], | |
| exp_ratio, 2), | |
| GELayer(semantic_channels[i], semantic_channels[i], | |
| exp_ratio, 1))) | |
| self.add_module(f'stage{len(semantic_channels)}_CEBlock', | |
| CEBlock(semantic_channels[-1], semantic_channels[-1])) | |
| self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') | |
| def forward(self, x): | |
| semantic_outs = [] | |
| for stage_name in self.semantic_stages: | |
| semantic_stage = getattr(self, stage_name) | |
| x = semantic_stage(x) | |
| semantic_outs.append(x) | |
| return semantic_outs | |


class BGALayer(BaseModule):
    """Bilateral Guided Aggregation Layer to fuse the complementary
    information from both Detail Branch and Semantic Branch.

    Args:
        out_channels (int): Number of output channels.
            Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        output (torch.Tensor): Output feature map for Segment heads.
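
    Example:
        >>> # Illustrative shape check: the detail feature is expected at 1/8
        >>> # resolution and the semantic feature at 1/32; the fused output
        >>> # keeps the detail resolution.
        >>> import torch
        >>> bga = BGALayer(out_channels=128)
        >>> x_d = torch.randn(1, 128, 8, 8)
        >>> x_s = torch.randn(1, 128, 2, 2)
        >>> bga(x_d, x_s).shape
        torch.Size([1, 128, 8, 8])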
| """ | |
| def __init__(self, | |
| out_channels=128, | |
| align_corners=False, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| self.out_channels = out_channels | |
| self.align_corners = align_corners | |
| self.detail_dwconv = nn.Sequential( | |
| DepthwiseSeparableConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| dw_norm_cfg=norm_cfg, | |
| dw_act_cfg=None, | |
| pw_norm_cfg=None, | |
| pw_act_cfg=None, | |
| )) | |
| self.detail_down = nn.Sequential( | |
| ConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1, | |
| bias=False, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=None), | |
| nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) | |
| self.semantic_conv = nn.Sequential( | |
| ConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| bias=False, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=None)) | |
| self.semantic_dwconv = nn.Sequential( | |
| DepthwiseSeparableConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| dw_norm_cfg=norm_cfg, | |
| dw_act_cfg=None, | |
| pw_norm_cfg=None, | |
| pw_act_cfg=None, | |
| )) | |
| self.conv = ConvModule( | |
| in_channels=self.out_channels, | |
| out_channels=self.out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=1, | |
| inplace=True, | |
| conv_cfg=conv_cfg, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg, | |
| ) | |
| def forward(self, x_d, x_s): | |
| detail_dwconv = self.detail_dwconv(x_d) | |
| detail_down = self.detail_down(x_d) | |
| semantic_conv = self.semantic_conv(x_s) | |
| semantic_dwconv = self.semantic_dwconv(x_s) | |
| semantic_conv = resize( | |
| input=semantic_conv, | |
| size=detail_dwconv.shape[2:], | |
| mode='bilinear', | |
| align_corners=self.align_corners) | |
| fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) | |
| fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) | |
| fuse_2 = resize( | |
| input=fuse_2, | |
| size=fuse_1.shape[2:], | |
| mode='bilinear', | |
| align_corners=self.align_corners) | |
| output = self.conv(fuse_1 + fuse_2) | |
| return output | |


@MODELS.register_module()
class BiSeNetV2(BaseModule):
    """BiSeNetV2: Bilateral Network with Guided Aggregation for
    Real-time Semantic Segmentation.

    This backbone is the implementation of
    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.

    Args:
        in_channels (int): Number of channels of the input image. Default: 3.
        detail_channels (Tuple[int], optional): Channels of each stage
            in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of the paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            expanding channel number of middle channels in Semantic Branch.
            Default: 6.
        bga_channels (int, optional): Number of middle channels in
            Bilateral Guided Aggregation Layer. Default: 128.
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Default: (0, 1, 2, 3, 4).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
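
    Example:
        >>> # Illustrative shape check: the first output is the BGA-fused
        >>> # feature at 1/8 resolution, followed by the four Semantic Branch
        >>> # stage outputs used by the auxiliary heads (Booster); eval()
        >>> # sidesteps BatchNorm over the 1x1 pooled map in CEBlock.
        >>> import torch
        >>> backbone = BiSeNetV2().eval()
        >>> outs = backbone(torch.randn(1, 3, 64, 64))
        >>> [tuple(out.shape) for out in outs]
        [(1, 128, 8, 8), (1, 16, 16, 16), (1, 32, 8, 8), (1, 64, 4, 4), (1, 128, 2, 2)]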
| """ | |
| def __init__(self, | |
| in_channels=3, | |
| detail_channels=(64, 64, 128), | |
| semantic_channels=(16, 32, 64, 128), | |
| semantic_expansion_ratio=6, | |
| bga_channels=128, | |
| out_indices=(0, 1, 2, 3, 4), | |
| align_corners=False, | |
| conv_cfg=None, | |
| norm_cfg=dict(type='BN'), | |
| act_cfg=dict(type='ReLU'), | |
| init_cfg=None): | |
| if init_cfg is None: | |
| init_cfg = [ | |
| dict(type='Kaiming', layer='Conv2d'), | |
| dict( | |
| type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) | |
| ] | |
| super().__init__(init_cfg=init_cfg) | |
| self.in_channels = in_channels | |
| self.out_indices = out_indices | |
| self.detail_channels = detail_channels | |
| self.semantic_channels = semantic_channels | |
| self.semantic_expansion_ratio = semantic_expansion_ratio | |
| self.bga_channels = bga_channels | |
| self.align_corners = align_corners | |
| self.conv_cfg = conv_cfg | |
| self.norm_cfg = norm_cfg | |
| self.act_cfg = act_cfg | |
| self.detail = DetailBranch(self.detail_channels, self.in_channels) | |
| self.semantic = SemanticBranch(self.semantic_channels, | |
| self.in_channels, | |
| self.semantic_expansion_ratio) | |
| self.bga = BGALayer(self.bga_channels, self.align_corners) | |
| def forward(self, x): | |
| # stole refactoring code from Coin Cheung, thanks | |
| x_detail = self.detail(x) | |
| x_semantic_lst = self.semantic(x) | |
| x_head = self.bga(x_detail, x_semantic_lst[-1]) | |
| outs = [x_head] + x_semantic_lst[:-1] | |
| outs = [outs[i] for i in self.out_indices] | |
| return tuple(outs) | |