# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math

import fvcore.nn.weight_init as weight_init
import torch.nn.functional as F
from torch import nn

from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.modeling.backbone.resnet import build_resnet_backbone


class LastLevelP6P7_P5(nn.Module):
    """
    Extra FPN top block that produces two additional levels, P6 and P7.

    Unlike a variant that starts from the bottom-up C5 feature, this block
    declares ``in_feature = "p5"`` and is constructed with the FPN output
    channel count as its input width, i.e. it is meant to be fed the P5
    output of the FPN.

    P6 is a 3x3, stride-2, padding-1 convolution of the input; P7 is the same
    kind of convolution applied to ``relu(P6)``.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: number of channels of the input feature map.
            out_channels: number of channels of both P6 and P7.
        """
        super().__init__()
        # Attributes read by the enclosing FPN (presumably: how many levels
        # this block adds and which feature it wants as input -- see the
        # detectron2 FPN documentation for the exact contract).
        self.num_levels = 2
        self.in_feature = "p5"
        # Positional Conv2d args: kernel_size=3, stride=2, padding=1.
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in (self.p6, self.p7):
            weight_init.c2_xavier_fill(module)

    def forward(self, c5):
        """
        Compute the two extra levels.

        Args:
            c5: input feature map. The parameter name is historical and is
                kept for interface compatibility; given
                ``in_feature = "p5"`` the tensor passed here is expected to
                be the P5 feature rather than C5.

        Returns:
            list: ``[p6, p7]``. ``p6`` is returned before the ReLU; the ReLU
            is applied only to the input of the P7 convolution.
        """
        p6 = self.p6(c5)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]


def _build_resnet_fpn_backbone(cfg, input_shape, top_block_cls=None):
    """Build a ResNet bottom-up network wrapped in an FPN with an optional top block.

    Args:
        cfg: a detectron2 CfgNode; ``cfg.MODEL.FPN.{IN_FEATURES, OUT_CHANNELS,
            NORM, FUSE_TYPE}`` are read here.
        input_shape: forwarded unchanged to ``build_resnet_backbone``.
        top_block_cls: optional callable invoked as
            ``top_block_cls(out_channels, out_channels)`` to create the FPN
            top block. ``None`` means the FPN gets ``top_block=None``.

    Returns:
        The constructed ``FPN`` instance.
    """
    fpn_cfg = cfg.MODEL.FPN
    # Build the bottom-up network first and the top block second, so modules
    # are constructed (and randomly initialised) in the same order as when
    # each builder inlined this code.
    bottom_up = build_resnet_backbone(cfg, input_shape)
    out_channels = fpn_cfg.OUT_CHANNELS
    top_block = None
    if top_block_cls is not None:
        top_block = top_block_cls(out_channels, out_channels)
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=out_channels,
        norm=fpn_cfg.NORM,
        top_block=top_block,
        fuse_type=fpn_cfg.FUSE_TYPE,
    )


@BACKBONE_REGISTRY.register()
def build_p67_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet-FPN backbone with a :class:`LastLevelP6P7_P5` top block,
    which adds the P6 and P7 levels.

    Args:
        cfg: a detectron2 CfgNode
        input_shape: shape specification forwarded to the ResNet builder.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    return _build_resnet_fpn_backbone(cfg, input_shape, LastLevelP6P7_P5)


@BACKBONE_REGISTRY.register()
def build_p35_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet-FPN backbone without any top block (``top_block=None``),
    so no levels are added beyond those the FPN itself produces.

    Args:
        cfg: a detectron2 CfgNode
        input_shape: shape specification forwarded to the ResNet builder.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    return _build_resnet_fpn_backbone(cfg, input_shape)