"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import FrozenBatchNorm2d, get_activation

__all__ = ["PResNet"]

# Number of residual blocks in each of the four stages, keyed by network depth.
ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    # 152: [3, 8, 36, 3],
}

# Default pretrained checkpoint URL for each supported depth.
# NOTE(review): the identifier is misspelled ("donwload"); it is left unchanged
# because code outside this file may import it by this name.
donwload_url = {
    18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth",
    34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth",
    50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
    101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth",
}


class ConvNormLayer(nn.Module):
    """Conv2d followed by BatchNorm2d and an activation.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        kernel_size: Convolution kernel size (an int; used to derive padding).
        stride: Convolution stride.
        padding: Explicit padding. When ``None``, ``(kernel_size - 1) // 2``
            is used.
        bias: Whether the convolution has a bias term.
        act: Activation spec forwarded to ``get_activation``.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = get_activation(act)

    def forward(self, x):
        """Apply convolution, normalization and activation in that order."""
        return self.act(self.norm(self.conv(x)))


def _make_shortcut(ch_in, ch_out, stride, variant):
    """Build the projection branch used when a block has no identity shortcut.

    For variant ``"d"`` with ``stride == 2`` the downsampling is done by a
    2x2 average pool followed by a stride-1 1x1 conv; otherwise a single 1x1
    conv with the requested stride is used. Submodule names (``pool``,
    ``conv``) are part of the state-dict keys and must not change.
    """
    if variant == "d" and stride == 2:
        return nn.Sequential(
            OrderedDict(
                [
                    ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                    ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)),
                ]
            )
        )
    return ConvNormLayer(ch_in, ch_out, 1, stride)


class BasicBlock(nn.Module):
    """Two-layer (3x3, 3x3) residual block.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        stride: Stride of the first 3x3 conv (and of the projection branch).
        shortcut: If true, the input is added unchanged; otherwise it goes
            through a projection branch first.
        act: Activation spec; ``None`` disables the final activation.
        variant: Architecture variant; ``"d"`` selects the pooled projection
            when ``stride == 2``.
    """

    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()

        self.shortcut = shortcut

        if not shortcut:
            self.short = _make_shortcut(ch_in, ch_out, stride, variant)

        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        """Return ``act(branch(x) + shortcut(x))``."""
        out = self.branch2b(self.branch2a(x))
        short = x if self.shortcut else self.short(x)
        return self.act(out + short)


class BottleNeck(nn.Module):
    """Three-layer (1x1, 3x3, 1x1) residual block with 4x channel expansion.

    Args:
        ch_in: Number of input channels.
        ch_out: Bottleneck width; the block outputs ``ch_out * expansion``
            channels.
        stride: Spatial stride of the block.
        shortcut: If true, the input is added unchanged; otherwise it goes
            through a projection branch first.
        act: Activation spec; ``None`` disables the final activation.
        variant: ``"a"`` puts the stride on the first 1x1 conv, every other
            value puts it on the 3x3 conv; ``"d"`` additionally selects the
            pooled projection when ``stride == 2``.
    """

    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()

        if variant == "a":
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        width = ch_out

        self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
        self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            self.short = _make_shortcut(ch_in, ch_out * self.expansion, stride, variant)

        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        """Return ``act(branch(x) + shortcut(x))``."""
        out = self.branch2c(self.branch2b(self.branch2a(x)))
        short = x if self.shortcut else self.short(x)
        return self.act(out + short)


class Blocks(nn.Module):
    """One network stage: ``count`` residual blocks applied in sequence.

    Args:
        block: Block class (``BasicBlock`` or ``BottleNeck``).
        ch_in: Number of channels entering the stage.
        ch_out: Base channel count passed to each block.
        count: Number of blocks in the stage.
        stage_num: Stage number; stage 2 does not downsample.
        act: Activation spec passed to each block.
        variant: Architecture variant passed to each block.
    """

    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(count):
            is_first = i == 0
            self.blocks.append(
                block(
                    ch_in,
                    ch_out,
                    # Only the first block of a stage downsamples, and never in stage 2.
                    stride=2 if is_first and stage_num != 2 else 1,
                    # The first block changes the channel count, so it needs a projection.
                    shortcut=not is_first,
                    variant=variant,
                    act=act,
                )
            )

            if is_first:
                # Remaining blocks consume the expanded output of the first one.
                ch_in = ch_out * block.expansion

    def forward(self, x):
        """Run the input through every block of the stage."""
        out = x
        for block in self.blocks:
            out = block(out)
        return out


@register()
class PResNet(nn.Module):
    """ResNet backbone returning feature maps from selected stages.

    Args:
        depth: Network depth; must be a key of ``ResNet_cfg``.
        variant: Architecture variant. ``"c"`` and ``"d"`` use a stem of three
            3x3 convs; other values use a single 7x7 conv.
        num_stages: Number of residual stages to build.
        return_idx: Indices of the stages whose outputs ``forward`` returns.
        act: Activation spec used throughout the network.
        freeze_at: When ``>= 0``, the stem and the first
            ``min(freeze_at, num_stages)`` stages have ``requires_grad``
            disabled.
        freeze_norm: When true, every ``nn.BatchNorm2d`` is replaced by
            ``FrozenBatchNorm2d``.
        pretrained: ``True`` downloads the default checkpoint for ``depth``;
            a string containing ``"http"`` is downloaded as a URL; any other
            non-empty string is loaded as a local checkpoint path.

    Attributes:
        out_channels: Channel count of each returned feature map.
        out_strides: Stride (4, 8, 16 or 32) of each returned feature map.
    """

    def __init__(
        self,
        depth,
        variant="d",
        num_stages=4,
        return_idx=[0, 1, 2, 3],
        act="relu",
        freeze_at=-1,
        freeze_norm=True,
        pretrained=False,
    ):
        super().__init__()

        block_nums = ResNet_cfg[depth]
        ch_in = 64
        if variant in ["c", "d"]:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]

        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (name, ConvNormLayer(cin, cout, k, s, act=act))
                    for cin, cout, k, s, name in conv_def
                ]
            )
        )

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        _out_channels = [block.expansion * v for v in ch_out_list]
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            stage_num = i + 2
            self.res_layers.append(
                Blocks(
                    block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant
                )
            )
            ch_in = _out_channels[i]

        # Copy so instances never share (and cannot mutate) the default list
        # object declared in the signature.
        self.return_idx = list(return_idx)
        self.out_channels = [_out_channels[_i] for _i in self.return_idx]
        self.out_strides = [_out_strides[_i] for _i in self.return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            if isinstance(pretrained, bool) or "http" in pretrained:
                # ``True`` selects the default checkpoint for this depth; an
                # explicit URL string is honored instead of being ignored.
                url = donwload_url[depth] if isinstance(pretrained, bool) else pretrained
                state = torch.hub.load_state_dict_from_url(
                    url, map_location="cpu", model_dir="weight"
                )
            else:
                # NOTE(review): torch.load unpickles the file; only point this
                # at trusted checkpoints.
                state = torch.load(pretrained, map_location="cpu")
            self.load_state_dict(state)
            print(f"Load PResNet{depth} state_dict")

    def _freeze_parameters(self, m: nn.Module):
        """Disable gradients for every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace ``nn.BatchNorm2d`` modules with ``FrozenBatchNorm2d``.

        Returns the replacement when ``m`` itself is a BatchNorm2d, otherwise
        ``m`` with its children swapped in place.
        """
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def forward(self, x):
        """Return the outputs of the stages listed in ``return_idx``, in stage order."""
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs