|
r""" Hypercorrelation Squeeze Network """ |
|
from functools import reduce |
|
from operator import add |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision.models import resnet |
|
from torchvision.models import vgg |
|
|
|
from fewshot_data.model.base.feature import extract_feat_vgg, extract_feat_res |
|
from fewshot_data.model.base.correlation import Correlation |
|
from fewshot_data.model.learner import HPNLearner |
|
|
|
|
|
class HypercorrSqueezeNetwork(nn.Module):
    """Few-shot segmentation model: frozen backbone features -> multilayer
    correlation -> ``HPNLearner`` producing 2-channel mask logits."""

    def __init__(self, backbone, use_original_imgsize):
        """Build the backbone, its feature-layer bookkeeping and the learner.

        Args:
            backbone: one of ``'vgg16'``, ``'resnet50'``, ``'resnet101'``.
            use_original_imgsize: if True, ``forward`` returns the learner's
                logits without resizing and ``predict_mask_nshot`` resizes
                them to ``batch['org_query_imsize']``; otherwise ``forward``
                upsamples them to the support image's spatial size.

        Raises:
            ValueError: if ``backbone`` is not a supported name.
        """
        super().__init__()

        self.backbone_type = backbone
        self.use_original_imgsize = use_original_imgsize

        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
        # in favour of `weights=`; kept as-is for compatibility with older
        # torchvision releases.
        if backbone == 'vgg16':
            self.backbone = vgg.vgg16(pretrained=True)
            self.feat_ids = [17, 19, 21, 24, 26, 28, 30]
            self.extract_feats = extract_feat_vgg
            nbottlenecks = [2, 2, 3, 3, 3, 1]
        elif backbone == 'resnet50':
            self.backbone = resnet.resnet50(pretrained=True)
            self.feat_ids = list(range(4, 17))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 6, 3]
        elif backbone == 'resnet101':
            self.backbone = resnet.resnet101(pretrained=True)
            self.feat_ids = list(range(4, 34))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 23, 3]
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError('Unavailable backbone: %s' % backbone)

        # Position of each bottleneck within its stage,
        # e.g. [3, 4] -> [0, 1, 2, 0, 1, 2, 3].
        self.bottleneck_ids = [idx for count in nbottlenecks for idx in range(count)]
        # 1-based stage id of each bottleneck,
        # e.g. [3, 4] -> [1, 1, 1, 2, 2, 2, 2].
        self.lids = [stage + 1 for stage, count in enumerate(nbottlenecks) for _ in range(count)]
        # Running totals of per-stage bottleneck counts taken from the deepest
        # stage backwards, first three only (resnet50 -> [3, 9, 13]).
        # Presumably used by Correlation to group layers by stage — confirm.
        self.stack_ids = torch.tensor(self.lids).bincount().flip(0).cumsum(dim=0)[:3]

        self.backbone.eval()
        self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])))
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, query_img, support_img, support_mask):
        """Return mask logits for ``query_img`` given one support shot.

        ``support_mask`` must be 3-D (it is unsqueezed at dim 1 before 2-D
        interpolation in ``mask_feature``); images are presumably
        (B, C, H, W) — confirm against the data loader.
        """
        # Backbone features are not trained; skip autograd for them.
        with torch.no_grad():
            query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.mask_feature(support_feats, support_mask)
            corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)

        logit_mask = self.hpn_learner(corr)
        if not self.use_original_imgsize:
            # NOTE(review): resizes to the *support* image size, which equals
            # the query size only when both inputs share a resolution.
            logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True)

        return logit_mask

    def mask_feature(self, features, support_mask):
        """Multiply every feature map by ``support_mask`` resized to its size.

        The ``features`` list is modified in place and also returned.
        """
        mask_4d = support_mask.unsqueeze(1).float()  # loop-invariant
        for idx, feature in enumerate(features):
            mask = F.interpolate(mask_4d, feature.size()[2:], mode='bilinear', align_corners=True)
            features[idx] = feature * mask
        return features

    def predict_mask_nshot(self, batch, nshot):
        """Predict a query mask by voting over ``nshot`` support shots.

        Each shot contributes its per-pixel argmax label (channel 1 counts as
        a foreground vote). With ``nshot == 1`` the summed (long) label map is
        returned unchanged. Otherwise votes are divided by each sample's
        maximum vote count (at least 1) and pixels reaching 0.5 become 1.0.

        Raises:
            ValueError: if ``nshot`` is smaller than 1.
        """
        if nshot < 1:
            raise ValueError('nshot must be >= 1, got %s' % nshot)

        org_qry_imsize = None
        if self.use_original_imgsize:
            # `.item()` only works for a single-element entry, so this path
            # presumably assumes batch size 1. Index 1 is used as the output
            # height and index 0 as the width.
            org_qry_imsize = (batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item())

        logit_mask_agg = 0
        for s_idx in range(nshot):
            logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx])
            if org_qry_imsize is not None:
                logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
            logit_mask_agg = logit_mask_agg + logit_mask.argmax(dim=1)

        if nshot == 1:
            return logit_mask_agg

        bsz = logit_mask_agg.size(0)
        # Per-sample maximum vote count; clamped so an all-zero vote map
        # divides by 1 instead of 0.
        max_vote = logit_mask_agg.view(bsz, -1).max(dim=1).values.clamp(min=1).view(bsz, 1, 1)
        pred_mask = logit_mask_agg.float() / max_vote
        return (pred_mask >= 0.5).float()

    def compute_objective(self, logit_mask, gt_mask):
        """Cross-entropy between 2-channel logits and the ground-truth mask."""
        bsz = logit_mask.size(0)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        logit_mask = logit_mask.reshape(bsz, 2, -1)
        gt_mask = gt_mask.reshape(bsz, -1).long()
        return self.cross_entropy_loss(logit_mask, gt_mask)

    def train_mode(self):
        """Put the module in training mode while keeping the backbone in eval mode."""
        self.train()
        self.backbone.eval()
|
|