import torch.nn as nn
import torch.nn.functional as F

from fewshot_data.model.base.conv4d import CenterPivotConv4d as Conv4d


class HPNLearner(nn.Module):
    def __init__(self, inch):
        super(HPNLearner, self).__init__()

        def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4):
            assert len(out_channels) == len(kernel_sizes) == len(spt_strides)

            building_block_layers = []
            for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
                inch = in_channel if idx == 0 else out_channels[idx - 1]
                ksz4d = (ksz,) * 4
                # Stride is applied only to the last two spatial dims (hence 'spt_strides')
                str4d = (1, 1) + (stride,) * 2
                pad4d = (ksz // 2,) * 4

                # Each block: center-pivot 4D conv -> GroupNorm -> ReLU
                building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d))
                building_block_layers.append(nn.GroupNorm(group, outch))
                building_block_layers.append(nn.ReLU(inplace=True))

            return nn.Sequential(*building_block_layers)

        outch1, outch2, outch3 = 16, 64, 128

        # Squeezing building blocks
        self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2])
        self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2])
        self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2])

        # Mixing building blocks
        self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
        self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])

        # Decoder layers
        self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True),
                                      nn.ReLU(),
                                      nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
                                      nn.ReLU())

        self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
                                      nn.ReLU(),
                                      nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))

    def interpolate_support_dims(self, hypercorr, spatial_size=None):
        # Fold the last two spatial dims (hb, wb) into the batch dim and bilinearly
        # resize the remaining spatial pair (ha, wa) of the 6D tensor to spatial_size.
        bsz, ch, ha, wa, hb, wb = hypercorr.size()
        hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa)
        hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True)
        o_ha, o_wa = spatial_size
        hypercorr = hypercorr.view(bsz, hb, wb, ch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous()
        return hypercorr

    def forward(self, hypercorr_pyramid):

        # Encode hypercorrelations from each layer (Squeezing building blocks)
        hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0])
        hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1])
        hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2])

        # Propagate encoded 4D-tensor (Mixing building blocks)
        hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2])
        hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3
        hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43)

        hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2])
        hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2
        hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432)

        # Collapse the last two spatial dims by averaging to obtain a 2D feature map
        bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size()
        hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1)

        # Decode the encoded tensor into a 2-channel logit mask
        hypercorr_decoded = self.decoder1(hypercorr_encoded)
        upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2
        hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
        logit_mask = self.decoder2(hypercorr_decoded)

        return logit_mask
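

# Minimal usage sketch (not part of the original module): the pyramid channel counts
# [3, 6, 4] and the 13/25/50 feature resolutions below are illustrative assumptions for
# a ResNet50-style backbone; in HSNet the pyramid comes from the multilayer correlation
# between query and support features. Each level is a 6D hypercorrelation tensor of
# shape (bsz, ch, H_q, W_q, H_s, W_s), with the support dims squeezed by the encoders.
if __name__ == '__main__':
    import torch

    hpn = HPNLearner(inch=[3, 6, 4])

    bsz = 2
    hypercorr_pyramid = [
        torch.rand(bsz, 3, 13, 13, 13, 13),  # deepest level
        torch.rand(bsz, 6, 25, 25, 25, 25),  # intermediate level
        torch.rand(bsz, 4, 50, 50, 50, 50),  # shallowest level
    ]

    logit_mask = hpn(hypercorr_pyramid)
    # Expected: (bsz, 2, 100, 100); the caller upsamples this to the input image size.
    print(logit_mask.shape)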