import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, caffe2_xavier_init
from mmcv.ops.merge_cells import ConcatCell

from ..builder import NECKS


@NECKS.register_module()
class NASFCOS_FPN(nn.Module):
| """FPN structure in NASFPN. | |
| Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for | |
| Object Detection <https://arxiv.org/abs/1906.04423>`_ | |
| Args: | |
| in_channels (List[int]): Number of input channels per scale. | |
| out_channels (int): Number of output channels (used at each scale) | |
| num_outs (int): Number of output scales. | |
| start_level (int): Index of the start input backbone level used to | |
| build the feature pyramid. Default: 0. | |
| end_level (int): Index of the end input backbone level (exclusive) to | |
| build the feature pyramid. Default: -1, which means the last level. | |
| add_extra_convs (bool): It decides whether to add conv | |
| layers on top of the original feature maps. Default to False. | |
| If True, its actual mode is specified by `extra_convs_on_inputs`. | |
| conv_cfg (dict): dictionary to construct and config conv layer. | |
| norm_cfg (dict): dictionary to construct and config norm layer. | |
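
    Example:
        >>> # A minimal usage sketch; the channel counts and spatial sizes
        >>> # are illustrative (ResNet-style C2-C5 maps for a 224px image).
        >>> import torch
        >>> in_channels = [256, 512, 1024, 2048]
        >>> scales = [56, 28, 14, 7]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = NASFCOS_FPN(in_channels, 256, num_outs=5).eval()
        >>> outputs = self.forward(inputs)
        >>> for i, out in enumerate(outputs):
        ...     print(f'outputs[{i}].shape = {out.shape}')
        outputs[0].shape = torch.Size([1, 256, 28, 28])
        outputs[1].shape = torch.Size([1, 256, 14, 14])
        outputs[2].shape = torch.Size([1, 256, 7, 7])
        outputs[3].shape = torch.Size([1, 256, 4, 4])
        outputs[4].shape = torch.Size([1, 256, 2, 2])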
| """ | |

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=1,
                 end_level=-1,
                 add_extra_convs=False,
                 conv_cfg=None,
                 norm_cfg=None):
        super(NASFCOS_FPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
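
        # Determine which backbone levels feed the pyramid; with
        # end_level=-1 every level from start_level onwards is used.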
        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        self.adapt_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            adapt_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                stride=1,
                padding=0,
                bias=False,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU', inplace=False))
            self.adapt_convs.append(adapt_conv)

        # C2 is omitted according to the paper
        extra_levels = num_outs - self.backbone_end_level + self.start_level
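        # e.g. num_outs=5 over three used backbone levels leaves two
        # extra levels to synthesize with stride-2 convs below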

        def build_concat_cell(with_input1_conv, with_input2_conv):
            cell_conv_cfg = dict(
                kernel_size=1, padding=0, bias=False, groups=out_channels)
            return ConcatCell(
                in_channels=out_channels,
                out_channels=out_channels,
                with_out_conv=True,
                out_conv_cfg=cell_conv_cfg,
                out_norm_cfg=dict(type='BN'),
                out_conv_order=('norm', 'act', 'conv'),
                with_input1_conv=with_input1_conv,
                with_input2_conv=with_input2_conv,
                input_conv_cfg=conv_cfg,
                input_norm_cfg=norm_cfg,
                upsample_mode='nearest')

        # Denote c3=f0, c4=f1, c5=f2 for convenience
        self.fpn = nn.ModuleDict()
        self.fpn['c22_1'] = build_concat_cell(True, True)
        self.fpn['c22_2'] = build_concat_cell(True, True)
        self.fpn['c32'] = build_concat_cell(True, False)
        self.fpn['c02'] = build_concat_cell(True, False)
        self.fpn['c42'] = build_concat_cell(True, True)
        self.fpn['c36'] = build_concat_cell(True, True)
        self.fpn['c61'] = build_concat_cell(True, True)  # f9
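        # Each key encodes a cell's inputs: forward() reads the two digits
        # after 'c' and merges feats[digit1] with feats[digit2], appending
        # the result, so the seven cells above yield f3..f9 in order.

        # Stride-2 convs extend the pyramid beyond the fused levels
        # (e.g. P6 and P7); only the first conv omits its activation.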
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            extra_act_cfg = None if i == 0 \
                else dict(type='ReLU', inplace=False)
            self.extra_downsamples.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    act_cfg=extra_act_cfg,
                    order=('act', 'norm', 'conv')))

    def forward(self, inputs):
        """Forward function."""
        feats = [
            adapt_conv(inputs[i + self.start_level])
            for i, adapt_conv in enumerate(self.adapt_convs)
        ]
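
        # feats[0..2] are the adapted backbone features (f0=c3, f1=c4,
        # f2=c5); each searched cell below appends its output, giving f3..f9.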
        for (i, module_name) in enumerate(self.fpn):
            idx_1, idx_2 = int(module_name[1]), int(module_name[2])
            res = self.fpn[module_name](feats[idx_1], feats[idx_2])
            feats.append(res)
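
        # Fuse f9, f8 and f7 with f5 and resize each sum to the spatial
        # size of the matching backbone input to produce P3, P4 and P5.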
        ret = []
        for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]):  # add P3, P4, P5
            feats1, feats2 = feats[idx], feats[5]
            feats2_resize = F.interpolate(
                feats2,
                size=feats1.size()[2:],
                mode='bilinear',
                align_corners=False)

            feats_sum = feats1 + feats2_resize
            ret.append(
                F.interpolate(
                    feats_sum,
                    size=inputs[input_idx].size()[2:],
                    mode='bilinear',
                    align_corners=False))

        for submodule in self.extra_downsamples:
            ret.append(submodule(ret[-1]))

        return tuple(ret)

    def init_weights(self):
        """Initialize the weights of module."""
        for module in self.fpn.values():
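            # ConcatCell exposes its output conv as ``out_conv``
            # when built with ``with_out_conv=True``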
            if hasattr(module, 'out_conv'):
                caffe2_xavier_init(module.out_conv.conv)
        for modules in [
                self.adapt_convs.modules(),
                self.extra_downsamples.modules()
        ]:
            for module in modules:
                if isinstance(module, nn.Conv2d):
                    caffe2_xavier_init(module)