# File size: 4,532 Bytes
# 57746f1
r""" Hypercorrelation Squeeze Network """
from functools import reduce
from operator import add
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torchvision.models import vgg
from fewshot_data.model.base.feature import extract_feat_vgg, extract_feat_res
from fewshot_data.model.base.correlation import Correlation
from fewshot_data.model.learner import HPNLearner
class HypercorrSqueezeNetwork(nn.Module):
    """Few-shot segmentation model built on a frozen pretrained backbone.

    Intermediate backbone features of a query image and of a masked support
    image are correlated layer by layer (``Correlation.multilayer_correlation``)
    and the resulting stack is decoded by ``HPNLearner`` into query-mask logits.

    Args:
        backbone: One of ``'vgg16'``, ``'resnet50'`` or ``'resnet101'``
            (torchvision models created with ``pretrained=True``).
        use_original_imgsize: If True, ``forward`` returns the learner's logits
            without resizing and ``predict_mask_nshot`` upsamples them to the
            batch's ``org_query_imsize`` instead.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, use_original_imgsize):
        super().__init__()

        # 1. Backbone network initialization
        self.backbone_type = backbone
        self.use_original_imgsize = use_original_imgsize
        if backbone == 'vgg16':
            self.backbone = vgg.vgg16(pretrained=True)
            self.feat_ids = [17, 19, 21, 24, 26, 28, 30]
            self.extract_feats = extract_feat_vgg
            nbottlenecks = [2, 2, 3, 3, 3, 1]
        elif backbone == 'resnet50':
            self.backbone = resnet.resnet50(pretrained=True)
            self.feat_ids = list(range(4, 17))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 6, 3]
        elif backbone == 'resnet101':
            self.backbone = resnet.resnet101(pretrained=True)
            self.feat_ids = list(range(4, 34))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 23, 3]
        else:
            # ValueError is still an Exception, so existing broad handlers keep working.
            raise ValueError('Unavailable backbone: %s' % backbone)

        self.bottleneck_ids, self.lids, self.stack_ids = self._layer_indices(nbottlenecks)
        self.backbone.eval()
        # The learner receives the bottleneck counts of the three deepest
        # layer groups, deepest group first.
        self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])))
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    @staticmethod
    def _layer_indices(nbottlenecks):
        """Derive per-feature bookkeeping from per-layer-group block counts.

        Args:
            nbottlenecks: Number of feature blocks in each layer group, e.g.
                ``[3, 4, 6, 3]`` for ResNet-50.

        Returns:
            ``(bottleneck_ids, lids, stack_ids)``:
              * ``bottleneck_ids[k]`` is feature ``k``'s index within its group;
              * ``lids[k]`` is feature ``k``'s 1-based group id;
              * ``stack_ids`` (a tensor) holds the cumulative feature counts of
                the three deepest groups, deepest first; it is the value passed
                to ``Correlation.multilayer_correlation``.
        """
        bottleneck_ids = [idx for count in nbottlenecks for idx in range(count)]
        lids = [gid for gid, count in enumerate(nbottlenecks, start=1) for _ in range(count)]
        # bincount()[0] is always 0 because group ids are 1-based; after the flip
        # it lands last, so the first three running sums cover the three
        # deepest groups.
        stack_ids = torch.tensor(lids).bincount().flip(0).cumsum(dim=0)[:3]
        return bottleneck_ids, lids, stack_ids

    def forward(self, query_img, support_img, support_mask):
        """Predict query-mask logits from a single support image/mask pair.

        Args:
            query_img: Query image batch (presumably ``(B, 3, H, W)`` -- confirm).
            support_img: Support image batch aligned with ``query_img``.
            support_mask: Support mask batch; ``mask_feature`` unsqueezes it at
                dim 1, so presumably ``(B, H, W)``.

        Returns:
            Logits produced by ``hpn_learner`` (two channels along dim 1, as
            assumed by ``compute_objective``). Unless ``use_original_imgsize``
            is set, they are bilinearly resized to the query's spatial size.
        """
        # Feature extraction, masking and correlation run without autograd;
        # only the hpn_learner call below is differentiated.
        with torch.no_grad():
            query_feats = self.extract_feats(
                query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.extract_feats(
                support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.mask_feature(support_feats, support_mask.clone())
            corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)

        logit_mask = self.hpn_learner(corr)
        if not self.use_original_imgsize:
            # The logits describe the query, so match the query's size (this is
            # identical to the previous support-size resize when both match).
            logit_mask = F.interpolate(logit_mask, query_img.size()[2:], mode='bilinear', align_corners=True)
        return logit_mask

    def mask_feature(self, features, support_mask):
        """Weight every support feature map by the support mask.

        The mask is bilinearly resized to each feature map's spatial size and
        multiplied in. ``features`` is modified in place and also returned.
        """
        # Channel axis and float dtype are loop-invariant; prepare them once.
        mask = support_mask.unsqueeze(1).float()
        for idx, feature in enumerate(features):
            resized = F.interpolate(mask, feature.size()[2:], mode='bilinear', align_corners=True)
            features[idx] = feature * resized
        return features

    def predict_mask_nshot(self, batch, nshot):
        """Predict a query mask by voting over ``nshot`` support sets.

        Args:
            batch: Dict with ``'query_img'``, ``'support_imgs'`` and
                ``'support_masks'`` (support tensors indexed by shot on dim 1),
                plus ``'org_query_imsize'`` when ``use_original_imgsize`` is set.
            nshot: Number of support shots to evaluate; must be >= 1.

        Returns:
            For ``nshot == 1``, the integer argmax mask of the single prediction.
            Otherwise, a float 0/1 mask in which a pixel is 1 when its
            foreground votes are at least half of the sample's maximum vote
            count (floored at 1).

        Raises:
            ValueError: If ``nshot`` < 1.
        """
        if nshot < 1:
            raise ValueError('nshot must be >= 1, got %s' % nshot)

        logit_mask_agg = 0
        for s_idx in range(nshot):
            logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx])
            if self.use_original_imgsize:
                # Target size is built as (entry [1], entry [0]) of org_query_imsize --
                # presumably stored as (width, height); .item() implies batch size 1.
                org_qry_imsize = (batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item())
                logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
            # argmax over classes: a pixel predicted as class 1 adds one vote.
            logit_mask_agg += logit_mask.argmax(dim=1).clone()
        if nshot == 1:
            return logit_mask_agg

        # Normalize votes by each sample's maximum (floored at 1 so an all-zero
        # vote map does not divide by zero), then threshold at 0.5.
        bsz = logit_mask_agg.size(0)
        max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0].clamp(min=1).view(bsz, 1, 1)
        pred_mask = logit_mask_agg.float() / max_vote
        return (pred_mask >= 0.5).float()

    def compute_objective(self, logit_mask, gt_mask):
        """Return the pixel-wise cross-entropy between 2-class logits and labels.

        ``logit_mask`` is flattened to ``(B, 2, N)`` and ``gt_mask`` to ``(B, N)``
        integer class labels before applying ``nn.CrossEntropyLoss``.
        """
        bsz = logit_mask.size(0)
        logit_mask = logit_mask.view(bsz, 2, -1)
        gt_mask = gt_mask.view(bsz, -1).long()
        return self.cross_entropy_loss(logit_mask, gt_mask)

    def train_mode(self):
        """Switch to training mode while keeping the backbone in eval mode."""
        self.train()
        self.backbone.eval()  # to prevent BN from learning data statistics with exponential averaging
|