# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init
from mmdet.registry import MODELS


class PyramidPoolingModule(nn.Module):
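    """Pyramid Pooling Module (PPM).

    Pools the input to several fixed grid sizes, projects each pooled map
    with a 1x1 conv, upsamples back to the input resolution and fuses the
    results with the input through a 1x1 bottleneck conv.
    """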
def __init__(self,
in_channels,
channels=512,
sizes=(1, 2, 3, 6),
act_cfg=dict(type='ReLU')):
super().__init__()
self.stages = nn.ModuleList(
[self._make_stage(in_channels, channels, size) for size in sizes])
self.bottleneck = nn.Conv2d(in_channels + len(sizes) * channels,
in_channels, 1)
        self.act = MODELS.build(act_cfg)

    def _make_stage(self, features, out_features, size):
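        # adaptive-pool the features to a fixed (size, size) grid, then
        # project them with a 1x1 conv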
prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = nn.Conv2d(features, out_features, 1)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
h, w = feats.size(2), feats.size(3)
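        # run each pooled branch, upsample it back to (h, w) and append the
        # original features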
priors = [
F.interpolate(
input=self.act(stage(feats)),
size=(h, w),
mode='bilinear',
align_corners=False) for stage in self.stages
] + [feats]
out = self.act(self.bottleneck(torch.cat(priors, 1)))
        return out


@MODELS.register_module()
class InstanceContextEncoder(nn.Module):
"""
Instance Context Encoder
1. construct feature pyramids from ResNet
2. enlarge receptive fields (ppm)
3. multi-scale fusion
"""
def __init__(self,
in_channels,
out_channels=256,
with_ppm=True,
act_cfg=dict(type='ReLU')):
super().__init__()
self.num_channels = out_channels
self.in_channels = in_channels
self.with_ppm = with_ppm
fpn_laterals = []
fpn_outputs = []
for in_channel in reversed(self.in_channels):
lateral_conv = nn.Conv2d(in_channel, self.num_channels, 1)
output_conv = nn.Conv2d(
self.num_channels, self.num_channels, 3, padding=1)
caffe2_xavier_init(lateral_conv)
caffe2_xavier_init(output_conv)
fpn_laterals.append(lateral_conv)
fpn_outputs.append(output_conv)
self.fpn_laterals = nn.ModuleList(fpn_laterals)
self.fpn_outputs = nn.ModuleList(fpn_outputs)
        # pyramid pooling module (PPM) applied to the coarsest feature level
if self.with_ppm:
self.ppm = PyramidPoolingModule(
self.num_channels, self.num_channels // 4, act_cfg=act_cfg)
        # final fusion: concatenate the three pyramid levels (the 1x1 conv
        # assumes len(in_channels) == 3) and project back to num_channels
        self.fusion = nn.Conv2d(self.num_channels * 3, self.num_channels, 1)
        kaiming_init(self.fusion)

    def forward(self, features):
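        # process the pyramid from the coarsest (lowest-resolution) level first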
features = features[::-1]
prev_features = self.fpn_laterals[0](features[0])
if self.with_ppm:
prev_features = self.ppm(prev_features)
outputs = [self.fpn_outputs[0](prev_features)]
for feature, lat_conv, output_conv in zip(features[1:],
self.fpn_laterals[1:],
self.fpn_outputs[1:]):
lat_features = lat_conv(feature)
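            # top-down pathway: upsample the coarser map (assumes a stride-2
            # gap between adjacent levels) and add it to the lateral features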
top_down_features = F.interpolate(
prev_features, scale_factor=2.0, mode='nearest')
prev_features = lat_features + top_down_features
outputs.insert(0, output_conv(prev_features))
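        # resize every level to the finest resolution and fuse with a 1x1 conv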
size = outputs[0].shape[2:]
features = [outputs[0]] + [
F.interpolate(x, size, mode='bilinear', align_corners=False)
for x in outputs[1:]
]
features = self.fusion(torch.cat(features, dim=1))
return features
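

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # mmdet/mmengine are installed, C3-C5 inputs with ResNet-50 channel sizes
    # (512, 1024, 2048) and a stride-2 gap between levels; the shapes below
    # are illustrative only.
    encoder = InstanceContextEncoder(in_channels=[512, 1024, 2048])
    feats = [
        torch.rand(2, 512, 80, 80),
        torch.rand(2, 1024, 40, 40),
        torch.rand(2, 2048, 20, 20),
    ]
    out = encoder(feats)
    print(out.shape)  # expected: torch.Size([2, 256, 80, 80])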