Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
""" | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ...core import register | |
from .common import FrozenBatchNorm2d, get_activation | |
__all__ = ["PResNet"] | |
# Number of residual blocks in each of the four stages, keyed by network depth.
ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    # ResNet-152 would use [3, 8, 36, 3]; it is left disabled and has no
    # pretrained checkpoint listed in the URL table below.
}

# Pretrained checkpoint URLs (weights converted from Paddle), keyed by depth.
# NOTE: the name is misspelled but kept as-is for backward compatibility.
donwload_url = {
    18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth",
    34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth",
    50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
    101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth",
}

# Correctly spelled alias of ``donwload_url`` (same dict object) for new code.
download_url = donwload_url
class ConvNormLayer(nn.Module):
    """Conv2d followed by BatchNorm2d and an activation.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        kernel_size: Convolution kernel size. It must be an int when ``padding``
            is None, because it is used in the default-padding arithmetic.
        stride: Convolution stride.
        padding: Explicit padding. When None, ``(kernel_size - 1) // 2`` is used.
        bias: Whether the convolution has a bias term.
        act: Activation spec, passed unchanged to ``get_activation``. What
            ``act=None`` yields is decided by ``get_activation``; presumably an
            identity, to be confirmed in ``.common``.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding=padding, bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = get_activation(act)

    def forward(self, x):
        """Apply conv -> norm -> activation to ``x``."""
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)
class BasicBlock(nn.Module):
    """Residual block with two 3x3 ConvNormLayers (used for depths below 50).

    Args:
        ch_in: Number of input channels.
        ch_out: Output channels. ``expansion`` is 1, so this is also the
            block's output width.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
        shortcut: If True, the input is added back unchanged. Otherwise a 1x1
            projection (``self.short``) is built to match channels and stride.
        act: Activation spec for the first conv and for the post-addition
            activation. None makes the post-addition activation ``nn.Identity``.
        variant: With ``"d"`` and ``stride == 2``, the projection downsamples via
            2x2 average pooling followed by a stride-1 1x1 conv. Otherwise it is
            a strided 1x1 conv.
    """

    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()
        self.shortcut = shortcut
        if not shortcut:
            # Built before the main branch so submodule creation and registration
            # order stays the same.
            self.short = self._projection(ch_in, ch_out, stride, variant)
        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
        self.act = get_activation(act) if act is not None else nn.Identity()

    @staticmethod
    def _projection(ch_in, ch_out, stride, variant):
        """Build the 1x1 projection shortcut for a channel and/or stride change."""
        if variant != "d" or stride != 2:
            return ConvNormLayer(ch_in, ch_out, 1, stride)
        pooled_conv = [
            ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
            ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)),
        ]
        return nn.Sequential(OrderedDict(pooled_conv))

    def forward(self, x):
        """Return ``act(branch2b(branch2a(x)) + shortcut(x))``."""
        residual = self.branch2b(self.branch2a(x))
        skip = x if self.shortcut else self.short(x)
        return self.act(residual + skip)
class BottleNeck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (used for depths >= 50).

    Args:
        ch_in: Number of input channels.
        ch_out: Inner width. The block outputs ``ch_out * expansion`` channels.
        stride: Downsampling stride. For variant ``"a"`` it is applied in the
            first 1x1 conv; otherwise in the 3x3 conv.
        shortcut: If True, the input is added back unchanged. Otherwise a 1x1
            projection (``self.short``) is built.
        act: Activation spec for the first two convs and for the post-addition
            activation. None makes the post-addition activation ``nn.Identity``.
        variant: With ``"d"`` and ``stride == 2``, the projection downsamples via
            2x2 average pooling followed by a stride-1 1x1 conv.
    """

    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()
        stride1, stride2 = (stride, 1) if variant == "a" else (1, stride)
        ch_expanded = ch_out * self.expansion

        self.branch2a = ConvNormLayer(ch_in, ch_out, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, stride2, act=act)
        self.branch2c = ConvNormLayer(ch_out, ch_expanded, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            # Registered after the main branch, matching the original order.
            self.short = self._projection(ch_in, ch_expanded, stride, variant)
        self.act = get_activation(act) if act is not None else nn.Identity()

    @staticmethod
    def _projection(ch_in, ch_expanded, stride, variant):
        """Build the 1x1 projection shortcut producing ``ch_expanded`` channels."""
        if variant != "d" or stride != 2:
            return ConvNormLayer(ch_in, ch_expanded, 1, stride)
        pooled_conv = [
            ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
            ("conv", ConvNormLayer(ch_in, ch_expanded, 1, 1)),
        ]
        return nn.Sequential(OrderedDict(pooled_conv))

    def forward(self, x):
        """Return ``act(branch2c(branch2b(branch2a(x))) + shortcut(x))``."""
        residual = self.branch2c(self.branch2b(self.branch2a(x)))
        skip = x if self.shortcut else self.short(x)
        return self.act(residual + skip)
class Blocks(nn.Module):
    """One stage made of ``count`` residual blocks of type ``block``, applied in sequence.

    The first block:
      - receives ``ch_in`` channels;
      - uses a projection shortcut (``shortcut=False``);
      - uses stride 2, except in stage 2, which keeps stride 1.

    Every later block receives ``ch_out * block.expansion`` channels, uses an
    identity shortcut, and uses stride 1.
    """

    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
        super().__init__()
        first_stride = 1 if stage_num == 2 else 2
        self.blocks = nn.ModuleList()
        for idx in range(count):
            is_first = idx == 0
            self.blocks.append(
                block(
                    ch_in if is_first else ch_out * block.expansion,
                    ch_out,
                    stride=first_stride if is_first else 1,
                    shortcut=not is_first,
                    variant=variant,
                    act=act,
                )
            )

    def forward(self, x):
        """Feed ``x`` through every block of the stage, in order."""
        for stage_block in self.blocks:
            x = stage_block(x)
        return x
class PResNet(nn.Module):
    """ResNet backbone (Paddle-style variants) returning a list of multi-scale feature maps.

    Args:
        depth: Network depth; must be a key of ``ResNet_cfg`` (currently 18,
            34, 50 or 101). Depths >= 50 use ``BottleNeck`` blocks, smaller
            depths use ``BasicBlock``.
        variant: Block variant, forwarded to every residual block. ``"c"`` and
            ``"d"`` also replace the single 7x7 stem conv with three 3x3 convs.
        num_stages: Number of residual stages to build (at most 4).
        return_idx: Indices of the stages whose outputs ``forward`` returns.
            It also selects ``out_channels`` and ``out_strides``. A copy is
            stored, so the caller's list (or the default) is never shared.
        act: Activation spec used in the stem and in every block.
        freeze_at: When >= 0, set ``requires_grad = False`` on the stem and on
            the first ``min(freeze_at, num_stages)`` stages.
        freeze_norm: When True, replace every ``nn.BatchNorm2d`` with
            ``FrozenBatchNorm2d``.
        pretrained: Controls weight loading:
            - ``False`` (or any falsy value): skip loading.
            - ``True``: download the default checkpoint for ``depth``.
            - A string starting with ``http://`` or ``https://``: download
              that URL.
            - Any other path: load it with ``torch.load``.

    Raises:
        KeyError: If ``depth`` has no entry in ``ResNet_cfg``.
    """

    def __init__(
        self,
        depth,
        variant="d",
        num_stages=4,
        return_idx=[0, 1, 2, 3],  # never mutated; copied into self.return_idx
        act="relu",
        freeze_at=-1,
        freeze_norm=True,
        pretrained=False,
    ):
        super().__init__()
        if depth not in ResNet_cfg:
            raise KeyError(
                f"Unsupported PResNet depth {depth!r}; expected one of {sorted(ResNet_cfg)}"
            )
        block_nums = ResNet_cfg[depth]
        ch_in = 64

        # Stem: variants "c"/"d" use three 3x3 convs instead of one 7x7 conv.
        # In both cases the first conv has stride 2.
        if variant in ["c", "d"]:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]
        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (name, ConvNormLayer(cin, cout, k, s, act=act))
                    for cin, cout, k, s, name in conv_def
                ]
            )
        )

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock
        _out_channels = [block.expansion * v for v in ch_out_list]
        # The stem conv (stride 2) plus the max-pool in forward (stride 2) give
        # stride 4 for the first stage. Each later stage halves the resolution.
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            # Stages are numbered from 2. Blocks keeps stage 2 at stride 1.
            stage_num = i + 2
            self.res_layers.append(
                Blocks(
                    block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant
                )
            )
            ch_in = _out_channels[i]

        self.return_idx = list(return_idx)
        self.out_channels = [_out_channels[_i] for _i in self.return_idx]
        self.out_strides = [_out_strides[_i] for _i in self.return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        if freeze_norm:
            self._freeze_norm(self)

        # Load after the norm swap: _freeze_norm creates new modules, so the
        # checkpoint must populate those final modules.
        if pretrained:
            self._load_pretrained(pretrained, depth)

    @staticmethod
    def _is_url(source):
        """Return True if ``source`` (str or path-like) is an http(s) URL."""
        return str(source).startswith(("http://", "https://"))

    def _load_pretrained(self, pretrained, depth):
        """Load a state dict from the default URL, a caller-given URL, or a local file."""
        if isinstance(pretrained, bool):
            state = torch.hub.load_state_dict_from_url(
                donwload_url[depth], map_location="cpu", model_dir="weight"
            )
        elif self._is_url(pretrained):
            # Use the caller's URL rather than the per-depth default.
            state = torch.hub.load_state_dict_from_url(
                str(pretrained), map_location="cpu", model_dir="weight"
            )
        else:
            # NOTE: torch.load is pickle-based; only load trusted checkpoints.
            state = torch.load(pretrained, map_location="cpu")
        self.load_state_dict(state)
        print(f"Load PResNet{depth} state_dict")

    def _freeze_parameters(self, m: nn.Module):
        """Set ``requires_grad = False`` on every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace ``nn.BatchNorm2d`` modules with ``FrozenBatchNorm2d``.

        Returns the replacement when ``m`` itself is a BatchNorm2d. Otherwise
        returns ``m``, with any BatchNorm2d descendants swapped in place.
        """
        if isinstance(m, nn.BatchNorm2d):
            # NOTE(review): the replacement is built from num_features only.
            # The BN's affine params, running stats and eps are not copied.
            # __init__ calls this before loading any checkpoint; confirm that
            # FrozenBatchNorm2d's defaults match a fresh BatchNorm2d.
            return FrozenBatchNorm2d(m.num_features)
        for name, child in m.named_children():
            frozen = self._freeze_norm(child)
            if frozen is not child:
                setattr(m, name, frozen)
        return m

    def forward(self, x):
        """Run the stem and all stages; return the outputs of stages in ``return_idx``."""
        x = F.max_pool2d(self.conv1(x), kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs