import torch
import torch.nn as nn
import torch.nn.functional as F

from lanet_utils import image_grid

class ConvBlock(nn.Module):
    """Two stacked 3x3 conv + BN + ReLU layers at constant spatial resolution."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class DilationConv3x3(nn.Module):
    """3x3 dilated conv (dilation=2, padding=2) + BN; preserves spatial size."""

    def __init__(self, in_channels, out_channels):
        super(DilationConv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
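
# A minimal sanity-check sketch (not part of the original module): with a 3x3
# kernel, dilation=2 and padding=2 cancel out, so DilationConv3x3 preserves
# spatial resolution and only changes the channel count.
def _check_dilation_shape():
    x = torch.randn(2, 64, 60, 80)
    y = DilationConv3x3(64, 256)(x)
    assert y.shape == (2, 256, 60, 80)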

class InterestPointModule(nn.Module):
    """Keypoint detector: a VGG-style encoder (three 2x2 max-pools, so the
    heads run at 1/8 resolution) with score, location, and descriptor heads."""

    def __init__(self, is_test=False):
        super(InterestPointModule, self).__init__()
        self.is_test = is_test

        self.conv1 = ConvBlock(3, 32)
        self.conv2 = ConvBlock(32, 64)
        self.conv3 = ConvBlock(64, 128)
        self.conv4 = ConvBlock(128, 256)
        self.maxpool2x2 = nn.MaxPool2d(2, 2)

        # score head
        self.score_conv = nn.Conv2d(
            256, 256, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.score_norm = nn.BatchNorm2d(256)
        self.score_out = nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1)
        self.softmax = nn.Softmax(dim=1)

        # location head
        self.loc_conv = nn.Conv2d(
            256, 256, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.loc_norm = nn.BatchNorm2d(256)
        self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)

        # descriptor branches (dilated 3x3 projections of the 1/2- and
        # 1/4-scale encoder features)
        self.des_conv2 = DilationConv3x3(64, 256)
        self.des_conv3 = DilationConv3x3(128, 256)

        # shift head: per-cell shift ratio that scales the location offset
        self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        B, _, H, W = x.shape

        x = self.conv1(x)
        x = self.maxpool2x2(x)
        x2 = self.conv2(x)
        x = self.maxpool2x2(x2)
        x3 = self.conv3(x)
        x = self.maxpool2x2(x3)
        x = self.conv4(x)

        B, _, Hc, Wc = x.shape

        # score head: channels 0-1 are softmaxed scale-awareness weights,
        # channel 2 is the keypoint score.
        score_x = self.score_out(self.relu(self.score_norm(self.score_conv(x))))
        aware = self.softmax(score_x[:, 0:2, :, :])
        score = score_x[:, 2, :, :].unsqueeze(1).sigmoid()

        # Zero out scores on the one-cell border of the coarse grid.
        border_mask = torch.ones(B, Hc, Wc, device=score.device)
        border_mask[:, 0] = 0
        border_mask[:, Hc - 1] = 0
        border_mask[:, :, 0] = 0
        border_mask[:, :, Wc - 1] = 0
        border_mask = border_mask.unsqueeze(1)
        score = score * border_mask

        # location head: each cell predicts an offset (tanh in [-1, 1]) and a
        # shift ratio in [0, 2] around its grid-cell center.
        coord_x = self.relu(self.loc_norm(self.loc_conv(x)))
        coord_cell = self.loc_out(coord_x).tanh()
        shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0

        step = ((H / Hc) - 1) / 2.0
        center_base = (
            image_grid(
                B,
                Hc,
                Wc,
                dtype=coord_cell.dtype,
                device=coord_cell.device,
                ones=False,
                normalized=False,
            ).mul(H / Hc)
            + step
        )

        coord_un = center_base.add(coord_cell.mul(shift_ratio * step))
        coord = coord_un.clone()
        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1)
        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1)

        # descriptor block: dilated projections of the 1/2- and 1/4-scale
        # features, plus the scale-awareness weights.
        desc_block = []
        desc_block.append(self.des_conv2(x2))
        desc_block.append(self.des_conv3(x3))
        desc_block.append(aware)

        if self.is_test:
            # Normalize predicted pixel coordinates to [-1, 1] for grid_sample.
            coord_norm = coord[:, :2].clone()
            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.0)) - 1.0
            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.0)) - 1.0
            coord_norm = coord_norm.permute(0, 2, 3, 1)

            # align_corners=True matches the (W - 1)/2-based normalization above.
            desc2 = F.grid_sample(desc_block[0], coord_norm, align_corners=True)
            desc3 = F.grid_sample(desc_block[1], coord_norm, align_corners=True)

            # Fuse the two scales with the per-cell awareness weights; keeping
            # the channel dim (0:1, 1:2) lets the multiply broadcast for any
            # batch size.
            aware = desc_block[2]
            desc = desc2 * aware[:, 0:1, :, :] + desc3 * aware[:, 1:2, :, :]
            # Divide by norm to L2-normalize each descriptor.
            desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1))

            return score, coord, desc

        return score, coord, desc_block
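
# Minimal usage sketch for the detector (input size is an assumption, not from
# the original code): three 2x2 max-pools give a 1/8-resolution grid, so a
# 240x320 input yields 30x40 cells.
def _demo_interest_points():
    net = InterestPointModule(is_test=True).eval()
    with torch.no_grad():
        score, coord, desc = net(torch.randn(1, 3, 240, 320))
    print(score.shape)  # (1, 1, 30, 40) keypoint scores
    print(coord.shape)  # (1, 2, 30, 40) sub-pixel (x, y) per cell
    print(desc.shape)   # (1, 256, 30, 40) fused, L2-normalized descriptors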

class CorrespondenceModule(nn.Module):
    """Dense matcher: dual-softmax over the descriptor similarity matrix."""

    def __init__(self, match_type="dual_softmax"):
        super(CorrespondenceModule, self).__init__()
        self.match_type = match_type

        if self.match_type == "dual_softmax":
            self.temperature = 0.1
        else:
            raise NotImplementedError()

    def forward(self, source_desc, target_desc):
        b, c, h, w = source_desc.size()

        # L2-normalize, then flatten the spatial grid to h * w descriptors.
        source_desc = source_desc.div(
            torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)
        ).view(b, -1, h * w)
        target_desc = target_desc.div(
            torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)
        ).view(b, -1, h * w)

        if self.match_type == "dual_softmax":
            # Temperature-scaled cosine similarity; softmax over both axes
            # keeps only mutually consistent matches.
            sim_mat = (
                torch.einsum("bcm, bcn -> bmn", source_desc, target_desc)
                / self.temperature
            )
            confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2)
        else:
            raise NotImplementedError()

        return confidence_matrix
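
# Minimal end-to-end smoke test (a sketch under assumed input sizes; it relies
# on image_grid from lanet_utils being importable): detect on two images, then
# match their descriptor maps with the dual-softmax module.
if __name__ == "__main__":
    detector = InterestPointModule(is_test=True).eval()
    matcher = CorrespondenceModule()
    with torch.no_grad():
        score_a, coord_a, desc_a = detector(torch.randn(1, 3, 240, 320))
        score_b, coord_b, desc_b = detector(torch.randn(1, 3, 240, 320))
        confidence = matcher(desc_a, desc_b)
    print(confidence.shape)  # (1, 1200, 1200): one row/column per grid cell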