Spaces:
No application file
No application file
| import math | |
| import os | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..utils import get_tuple_transform_ops | |
| from einops import rearrange | |
| from ..utils.local_correlation import local_correlation | |
class ConvRefiner(nn.Module):
    """Convolutional head that refines a warp at a single resolution.

    The network consumes the query features, the support features sampled at
    the current warp, and optionally an embedding of the current displacement
    and a local correlation volume. It emits ``out_dim`` channels, which
    ``forward`` splits into certainty channels and the last two
    (displacement) channels.
    """

    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_support_fm=False,
    ):
        """Build the refiner.

        Args:
            in_dim: Channel count of the concatenated input fed to ``block1``.
            hidden_dim: Channel count of every hidden block.
            out_dim: Output channels (certainty channels + 2 displacement).
            dw: Use depthwise (grouped) convolutions in the blocks.
            kernel_size: Spatial kernel size of the first conv of each block.
            hidden_blocks: Number of hidden blocks after ``block1``.
            displacement_emb: If truthy, embed ``flow - identity_grid`` with a
                1x1 convolution and append it to the input.
            displacement_emb_dim: Output channels of that embedding.
            local_corr_radius: If truthy (and displacement embedding is on),
                append a local correlation volume of this radius.
            corr_in_other: Compute the local correlation around the predicted
                coordinate in ``y`` instead of against the warped features.
            no_support_fm: Replace the warped support features with zeros.
        """
        super().__init__()
        block_kwargs = dict(dw=dw, kernel_size=kernel_size)
        self.block1 = self.create_block(in_dim, hidden_dim, **block_kwargs)
        stacked = [
            self.create_block(hidden_dim, hidden_dim, **block_kwargs)
            for _ in range(hidden_blocks)
        ]
        self.hidden_blocks = nn.Sequential(*stacked)
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        self.has_displacement_emb = bool(displacement_emb)
        if self.has_displacement_emb:
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_support_fm = no_support_fm

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
    ):
        """Return a ``conv(k x k) -> BatchNorm -> ReLU -> conv(1 x 1)`` block.

        With ``dw=True`` the first convolution uses ``in_dim`` groups, which
        requires ``out_dim`` to be a multiple of ``in_dim``.
        """
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "outdim must be divisible by indim for depthwise"
            groups = in_dim
        else:
            groups = 1
        layers = [
            nn.Conv2d(
                in_dim,
                out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=groups,
            ),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_dim, out_dim, 1, 1, 0),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, y, flow):
        """Predict a certainty update and a displacement for the given warp.

        Args:
            x: Query feature map, unpacked as ``(b, c, h, w)``.
            y: Support feature map; it is sampled at ``flow``.
            flow: Warp of shape ``(b, 2, h, w)``. It is permuted to
                ``(b, h, w, 2)`` and used directly as a ``grid_sample`` grid,
                i.e. normalised coordinates with ``align_corners=False``.

        Returns:
            ``(certainty, displacement)``: all but the last two output
            channels, and the last two output channels, respectively.
        """
        batch, _, height, width = x.shape
        # The warped support features are treated as constants.
        with torch.no_grad():
            warped = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)

        if not self.has_displacement_emb:
            if self.no_support_fm:
                warped = torch.zeros_like(x)
            stacked = torch.cat((x, warped), dim=1)
        else:
            # Identity grid at pixel centres, channel 0 = x, channel 1 = y.
            rows, cols = torch.meshgrid(
                (
                    torch.linspace(
                        -1 + 1 / height, 1 - 1 / height, height, device=x.device
                    ),
                    torch.linspace(
                        -1 + 1 / width, 1 - 1 / width, width, device=x.device
                    ),
                )
            )
            identity = torch.stack((cols, rows))[None].expand(
                batch, 2, height, width
            )
            emb_in_displacement = self.disp_emb(flow - identity)
            if self.local_corr_radius:
                # TODO: should corr have gradient?
                if self.corr_in_other:
                    # Take a k x k neighbourhood around the predicted
                    # coordinate in the other image.
                    local_corr = local_correlation(
                        x, y, local_radius=self.local_corr_radius, flow=flow
                    )
                else:
                    # Correlate against the warped features instead. This is a
                    # different operation, especially for large viewpoint
                    # changes.
                    local_corr = local_correlation(
                        x,
                        warped,
                        local_radius=self.local_corr_radius,
                    )
                if self.no_support_fm:
                    warped = torch.zeros_like(x)
                stacked = torch.cat(
                    (x, warped, emb_in_displacement, local_corr), dim=1
                )
            else:
                stacked = torch.cat((x, warped, emb_in_displacement), dim=1)

        out = self.out_conv(self.hidden_blocks(self.block1(stacked)))
        return out[:, :-2], out[:, -2:]
class CosKernel(nn.Module):  # similar to softmax kernel
    """Exponentiated cosine-similarity kernel ``exp((cos(x, y) - 1) / T)``.

    Args:
        T: Temperature. Stored as-is when fixed, or as a learnable parameter
            when ``learn_temperature`` is true.
        learn_temperature: Make the temperature an ``nn.Parameter``.
    """

    def __init__(self, T, learn_temperature=False):
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            # Only floating-point tensors can require gradients, so an integer
            # temperature (e.g. ``T=1``) must be converted before wrapping it.
            self.T = nn.Parameter(torch.tensor(float(T)))
        else:
            self.T = T

    def __call__(self, x, y, eps=1e-6):
        """Evaluate the kernel between two batched point sets.

        Args:
            x: Tensor indexed as ``(b, n, d)`` by the einsum below.
            y: Tensor indexed as ``(b, m, d)``.
            eps: Added to the norm product to avoid division by zero.

        Returns:
            Tensor of shape ``(b, n, m)``; entries approach 1 for vectors that
            point in the same direction.
        """
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            # Keep the learned temperature strictly positive.
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        K = ((c - 1.0) / T).exp()
        return K
class CAB(nn.Module):
    """Channel-attention fusion of two feature maps.

    A per-channel gate is computed from the globally pooled concatenation of
    both inputs and is applied to the second input before it is added to the
    first.
    """

    def __init__(self, in_channels, out_channels):
        """Create the gating layers.

        Args:
            in_channels: Channels of the concatenated pair (input of the
                first 1x1 convolution).
            out_channels: Channels of the gate; it multiplies the second
                input, so it has to be broadcastable against it.
        """
        super(CAB, self).__init__()
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **pointwise)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, **pointwise)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Fuse ``x = (high, low)`` as ``high + gate * low``."""
        high, low = x  # high, low (old, new)
        gate = self.global_pooling(torch.cat([high, low], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(gate))))
        return gate * low + high
class RRB(nn.Module):
    """Residual refinement block.

    A 1x1 convolution projects the input to ``out_channels``; a
    ``conv -> BatchNorm -> ReLU -> conv`` branch is then added to that
    projection and the sum is passed through a ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        """Create the block.

        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels of the output feature map.
            kernel_size: Kernel size of the two spatial convolutions; padding
                is ``kernel_size // 2``.
        """
        super(RRB, self).__init__()
        spatial = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.conv2 = nn.Conv2d(out_channels, out_channels, **spatial)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, **spatial)

    def forward(self, x):
        """Return ``relu(proj(x) + branch(proj(x)))``."""
        x = self.conv1(x)
        res = self.conv3(self.relu(self.bn(self.conv2(x))))
        return self.relu(x + res)
class DFN(nn.Module):
    """Per-scale embedding decoder built from externally supplied modules.

    All per-scale modules are indexed with ``str(scale)``, so the ``*_dict``
    arguments are expected to be mappings (e.g. ``nn.ModuleDict``) keyed by
    the scale as a string.
    """

    def __init__(
        self,
        internal_dim,
        feat_input_modules,
        pred_input_modules,
        rrb_d_dict,
        cab_dict,
        rrb_u_dict,
        use_global_context=False,
        global_dim=None,
        terminal_module=None,
        upsample_mode="bilinear",
        align_corners=False,
    ):
        """Store the per-scale modules.

        Args:
            internal_dim: Channel count of the running context.
            feat_input_modules: Per-scale projection applied to the features.
            pred_input_modules: Stored, but not used by ``forward``.
            rrb_d_dict: Per-scale block applied to ``cat(feats, embeddings)``.
            cab_dict: Per-scale fusion module called with
                ``[context, embeddings]``.
            rrb_u_dict: Per-scale block applied to the fused context.
            use_global_context: Create the global-context layers; requires
                ``global_dim``.
            global_dim: Input channels of the global-context projection.
            terminal_module: Per-scale prediction heads. ``None`` selects a
                single identity head shared by every scale.
            upsample_mode: Stored for callers.
            align_corners: Stored for callers.
        """
        super().__init__()
        if use_global_context:
            assert (
                global_dim is not None
            ), "Global dim must be provided when using global context"
        self.align_corners = align_corners
        self.internal_dim = internal_dim
        self.feat_input_modules = feat_input_modules
        self.pred_input_modules = pred_input_modules
        self.rrb_d = rrb_d_dict
        self.cab = cab_dict
        self.rrb_u = rrb_u_dict
        self.use_global_context = use_global_context
        if use_global_context:
            self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
            self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.terminal_module = (
            terminal_module if terminal_module is not None else nn.Identity()
        )
        self.upsample_mode = upsample_mode
        # The default identity head is not a mapping and has no ``keys()``;
        # in that case the handled scales are taken from ``rrb_d_dict``, which
        # ``forward`` indexes by scale anyway.
        scale_source = (
            self.terminal_module
            if hasattr(self.terminal_module, "keys")
            else rrb_d_dict
        )
        self._scales = [int(key) for key in scale_source.keys()]

    def scales(self):
        """Return a copy of the scales (as ints) this decoder handles."""
        return self._scales.copy()

    def forward(self, embeddings, feats, context, key):
        """Decode one scale.

        Args:
            embeddings: Embedding map concatenated (channel-wise) after the
                projected features.
            feats: Feature map for this scale.
            context: Running context from the previous scale.
            key: Scale identifier; converted with ``str`` for the look-ups.

        Returns:
            ``(pred_coord, pred_certainty, context)`` where the coordinates
            are the last two channels of the head output and the certainty is
            everything before them.
        """
        scale = str(key)
        feats = self.feat_input_modules[scale](feats)
        embeddings = torch.cat([feats, embeddings], dim=1)
        embeddings = self.rrb_d[scale](embeddings)
        context = self.cab[scale]([context, embeddings])
        context = self.rrb_u[scale](context)
        head = self.terminal_module
        if hasattr(head, "keys"):
            head = head[scale]
        preds = head(context)
        pred_coord = preds[:, -2:]
        pred_certainty = preds[:, :-2]
        return pred_coord, pred_certainty, context
class GP(nn.Module):
    """Kernel regression of positional encodings from support to query.

    With kernel matrices ``K`` between flattened feature maps, ``forward``
    computes the posterior mean ``K_xy (K_yy + sigma I)^-1 f`` of an embedded
    coordinate grid ``f`` and, unless ``no_cov`` is set, a local window of the
    posterior covariance ``K_xx - K_xy (K_yy + sigma I)^-1 K_yx``.
    """

    # Above this many support locations the regularised kernel matrix is
    # inverted one sample at a time instead of as a single batched call.
    max_batched_inverse_size = 2000

    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        """Configure the regressor.

        Args:
            kernel: Factory called as ``kernel(T=..., learn_temperature=...)``;
                the result is called as ``K(a, b)`` on ``(b, n, d)`` tensors.
            T: Temperature forwarded to ``kernel``.
            learn_temperature: Forwarded to ``kernel``.
            only_attention: Stored; not used by this class.
            gp_dim: Channel count of the positional embedding.
            basis: ``"fourier"`` or ``"linear"`` embedding of the coordinates.
            covar_size: Side length of the local covariance window
                (presumably meant to be odd; the reshape in ``get_local_cov``
                only matches for odd sizes).
            only_nearest_neighbour: Stored; not used by this class.
            sigma_noise: Diagonal regulariser added to ``K_yy``.
            no_cov: Return only the posterior mean.
            predict_features: Add the first ``gp_dim`` channels of ``y`` to
                the positional embedding.
        """
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        """Gather, for every location, its covariance with a K x K window.

        Args:
            cov: Tensor of shape ``(b, h, w, h, w)``.

        Returns:
            Tensor of shape ``(b, h, w, K**2)`` with ``K = covar_size``;
            window entries that fall outside the map are zero (from padding).
        """
        K = self.covar_size
        b, _, _, h, w = cov.shape
        hw = h * w
        radius = K // 2
        cov = F.pad(cov, 4 * (radius,))  # zero-pad the last two (window) dims
        delta = torch.stack(
            torch.meshgrid(
                torch.arange(-radius, radius + 1), torch.arange(-radius, radius + 1)
            ),
            dim=-1,
        )
        # Centre of every window, expressed in padded coordinates.
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(radius, h + radius), torch.arange(radius, w + radius)
            ),
            dim=-1,
        )
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        """Flatten ``(b, d, h, w)`` to ``(b, h * w, d)``."""
        return x.permute(0, 2, 3, 1).flatten(1, 2)

    def project_to_basis(self, x):
        """Embed a ``(b, 2, h, w)`` coordinate map with the configured basis.

        Raises:
            ValueError: If ``basis`` is neither ``"fourier"`` nor ``"linear"``.
        """
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No other bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        """Return the embedded pixel-centre coordinate grid for ``y``.

        The grid spans ``[-1 + 1/n, 1 - 1/n]`` along each axis; channel 0 is
        the horizontal coordinate and channel 1 the vertical one.
        """
        b, c, h, w = y.shape
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
            )
        )
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = coarse_coords.permute(0, 3, 1, 2)  # b h w d -> b d h w
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def forward(self, x, y, **kwargs):
        """Regress the embedded support coordinates onto the query grid.

        Args:
            x: Query features ``(b, c, h1, w1)``.
            y: Support features ``(b, c, h2, w2)``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(b, gp_dim, h1, w1)`` posterior mean, with ``covar_size**2``
            local covariance channels appended unless ``no_cov`` is set.
        """
        b, c, h1, w1 = x.shape
        f = self.get_pos_enc(y)
        if self.predict_features:
            f = f + y[:, : self.dim]  # Stupid way to predict features
        h2, w2 = f.shape[-2:]
        x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        K_yx = K_xy.permute(0, 2, 1)
        # Shape (1, N, N): broadcasts over the batch dimension.
        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
        # Due to https://github.com/pytorch/pytorch/issues/16963 annoying
        # warnings, remove batch if N large.
        if K_yy.shape[-1] > self.max_batched_inverse_size:
            # ``sigma_noise`` has a single batch entry, so it must be added
            # whole; slicing it per sample yields an empty tensor for every
            # sample but the first and silently drops those samples.
            K_yy_inv = torch.cat(
                [
                    torch.linalg.inv(K_yy[k : k + 1] + sigma_noise)
                    for k in range(K_yy.shape[0])
                ]
            )
        else:
            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        # b (h w) d -> b d h w
        mu_x = mu_x.reshape(b, h1, w1, -1).permute(0, 3, 1, 2)
        if self.no_cov:
            return mu_x
        K_xx = self.K(x, x)
        cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
        # b (h w) (r c) -> b h w r c
        cov_x = cov_x.reshape(b, h1, w1, h1, w1)
        local_cov_x = self.get_local_cov(cov_x)
        local_cov_x = local_cov_x.permute(0, 3, 1, 2)  # b h w K -> b K h w
        return torch.cat((mu_x, local_cov_x), dim=1)
class Encoder(nn.Module):
    """Feature-pyramid wrapper around a ResNet-style backbone.

    The wrapped network must expose ``conv1``, ``bn1``, ``relu``, ``maxpool``
    and ``layer1`` .. ``layer4``. Batch-norm layers are kept in eval mode even
    while the encoder is training.
    """

    def __init__(self, resnet):
        super().__init__()
        self.resnet = resnet

    def forward(self, x):
        """Return the backbone activations keyed 32, 16, 8, 4, 2 and 1.

        Key ``1`` is the input itself, ``2`` the stem output after ReLU, and
        ``4`` .. ``32`` the outputs of ``layer1`` .. ``layer4``. The keys
        presumably denote the stride of each map for a standard ResNet.
        """
        net = self.resnet
        x1 = net.relu(net.bn1(net.conv1(x)))
        x2 = net.layer1(net.maxpool(x1))
        x3 = net.layer2(x2)
        x4 = net.layer3(x3)
        x5 = net.layer4(x4)
        return {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x}

    def train(self, mode=True):
        """Switch mode like ``nn.Module.train`` but keep batch norm frozen.

        Returns:
            ``self``, matching the ``nn.Module`` contract so that
            ``encoder.train()`` and ``encoder.eval()`` can be chained.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self
class Decoder(nn.Module):
    """Coarse-to-fine warp decoder.

    At each scale the warp may be (re)estimated by the embedding decoder from
    the per-scale ``gps`` output and then refined by a convolutional refiner.
    The result of every scale is recorded, and the warp and certainty are
    bilinearly resized for the next (finer) scale.
    """

    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        transformers=None,
        detach=False,
        scales="all",
        pos_embeddings=None,
    ):
        """Store the per-scale components.

        Args:
            embedding_decoder: Object with ``scales()``, ``internal_dim`` and
                a call signature ``(embeddings, feats, context, scale)``.
            gps: Mapping from scale string to the coarse matching module.
            proj: Mapping from scale string to a feature projection.
            conv_refiner: Mapping from scale string to a refiner called as
                ``(f1, f2, flow)``.
            transformers: Accepted but not used.
            detach: Detach warp and certainty between scales.
            scales: ``"all"`` for ``32 .. 1``, otherwise the list of scale
                strings to process, coarsest first.
            pos_embeddings: Accepted but not used.
        """
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        self.detach = detach
        self.scales = ["32", "16", "8", "4", "2", "1"] if scales == "all" else scales

    def upsample_preds(self, flow, certainty, query, support):
        """Resize a warp to ``query``'s resolution and refine it at scale 1.

        Args:
            flow: Warp of shape ``(b, hs, ws, 2)``.
            certainty: Certainty map of shape ``(b, C, hs, ws)``.
            query: Full-resolution query features ``(b, c, h, w)``.
            support: Full-resolution support features.

        Returns:
            ``(flow, certainty)`` with ``flow`` of shape ``(b, h, w, 2)``.
        """
        _, _, h, w = query.shape
        resize = dict(size=(h, w), align_corners=False, mode="bilinear")
        certainty = F.interpolate(certainty, **resize)
        flow = F.interpolate(flow.permute(0, 3, 1, 2), **resize)
        delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
        flow_x = flow[:, 0] + delta_flow[:, 0] / (4 * w)
        flow_y = flow[:, 1] + delta_flow[:, 1] / (4 * h)
        flow = torch.stack((flow_x, flow_y), dim=1).permute(0, 2, 3, 1)
        return flow, certainty + delta_certainty

    def get_placeholder_flow(self, b, h, w, device):
        """Return the identity warp as a ``(b, 2, h, w)`` tensor.

        Coordinates are pixel centres in ``[-1 + 1/n, 1 - 1/n]``; channel 0
        is the horizontal coordinate and channel 1 the vertical one.
        """
        ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
        xs = torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device)
        grid_y, grid_x = torch.meshgrid((ys, xs))
        grid = torch.stack((grid_x, grid_y), dim=-1)[None].expand(b, h, w, 2)
        return grid.permute(0, 3, 1, 2)

    def forward(self, f1, f2, upsample=False, dense_flow=None, dense_certainty=None):
        """Decode a dense warp and certainty for every processed scale.

        Args:
            f1: Query feature pyramid keyed by integer scale; must contain
                key ``1``.
            f2: Support feature pyramid with the same keys.
            upsample: If true, start from the supplied ``dense_flow`` and
                ``dense_certainty`` and only process scales 8, 4, 2 and 1.
            dense_flow: Initial warp; only read when ``upsample`` is true.
            dense_certainty: Initial certainty; only read when ``upsample``
                is true.

        Returns:
            Dict mapping each processed integer scale to
            ``{"dense_flow": ..., "dense_certainty": ...}``.
        """
        coarse_scales = self.embedding_decoder.scales()
        all_scales = ["8", "4", "2", "1"] if upsample else self.scales
        sizes = {scale: feats.shape[-2:] for scale, feats in f1.items()}
        h, w = sizes[1]
        b = f1[1].shape[0]
        device = f1[1].device
        coarsest_scale = int(all_scales[0])
        old_stuff = torch.zeros(
            b,
            self.embedding_decoder.internal_dim,
            *sizes[coarsest_scale],
            device=f1[coarsest_scale].device
        )
        dense_corresps = {}
        if upsample:
            start = dict(
                size=sizes[coarsest_scale], align_corners=False, mode="bilinear"
            )
            dense_flow = F.interpolate(dense_flow, **start)
            dense_certainty = F.interpolate(dense_certainty, **start)
        else:
            dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
            dense_certainty = 0.0

        for new_scale in all_scales:
            ins = int(new_scale)
            f1_s, f2_s = f1[ins], f2[ins]
            if new_scale in self.proj:
                projector = self.proj[new_scale]
                f1_s, f2_s = projector(f1_s), projector(f2_s)
            if ins in coarse_scales:
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
                dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
                    new_stuff, f1_s, old_stuff, new_scale
                )
            if new_scale in self.conv_refiner:
                # The refiner predicts both a certainty update and a
                # displacement.
                delta_certainty, displacement = self.conv_refiner[new_scale](
                    f1_s, f2_s, dense_flow
                )
                flow_x = dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w)
                flow_y = dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h)
                dense_flow = torch.stack((flow_x, flow_y), dim=1)
                dense_certainty = dense_certainty + delta_certainty
            dense_corresps[ins] = {
                "dense_flow": dense_flow,
                "dense_certainty": dense_certainty,
            }
            if new_scale == "1":
                continue
            # Prepare the initial estimate for the next, finer scale.
            finer = dict(size=sizes[ins // 2], align_corners=False, mode="bilinear")
            dense_flow = F.interpolate(dense_flow, **finer)
            dense_certainty = F.interpolate(dense_certainty, **finer)
            if self.detach:
                dense_flow = dense_flow.detach()
                dense_certainty = dense_certainty.detach()
        return dense_corresps
class RegressionMatcher(nn.Module):
    """Dense matcher: an encoder feature pyramid followed by a warp decoder.

    ``match`` is the inference entry point. It returns a dense warp in
    normalised ``[-1, 1]`` coordinates together with a per-pixel certainty,
    from which ``sample`` can draw sparse matches.
    """

    def __init__(
        self,
        encoder,
        decoder,
        h=384,
        w=512,
        use_contrastive_loss=False,
        alpha=1,
        beta=0,
        sample_mode="threshold",
        upsample_preds=False,
        symmetric=False,
        name=None,
        use_soft_mutual_nearest_neighbours=False,
    ):
        """Store the sub-networks and inference settings.

        Args:
            encoder: Module mapping an image batch to a dict of feature maps
                keyed by integer scale.
            decoder: Module called as ``decoder(f_q, f_s, upsample=...,
                dense_flow=..., dense_certainty=...)``.
            h: Height images are resized to in ``match``.
            w: Width images are resized to in ``match``.
            use_contrastive_loss: In training mode, make ``forward`` also
                return the feature pyramids.
            alpha: Stored for callers.
            beta: Stored for callers.
            sample_mode: Sampling strategy; ``sample`` checks it for the
                substrings ``threshold``, ``pow``, ``naive`` and ``balanced``.
            upsample_preds: Run a second, higher-resolution pass in ``match``.
            symmetric: Match in both directions in ``match``.
            name: Stored for callers.
            use_soft_mutual_nearest_neighbours: Requires ``symmetric``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.use_contrastive_loss = use_contrastive_loss
        self.alpha = alpha
        self.beta = beta
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        self.symmetric = symmetric
        self.name = name
        self.sample_thresh = 0.05
        self.upsample_res = (864, 1152)
        if use_soft_mutual_nearest_neighbours:
            assert symmetric, "MNS requires symmetric inference"
        self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours

    def extract_backbone_features(self, batch, batched=True, upsample=True):
        """Encode ``batch["query"]`` and ``batch["support"]``.

        With ``batched=True`` both images are concatenated along the batch
        dimension and a single pyramid is returned; otherwise a
        ``(query_pyramid, support_pyramid)`` tuple is returned. ``upsample``
        is currently unused.
        """
        # TODO: only extract stride [1,2,4,8] for upsample = True
        x_q = batch["query"]
        x_s = batch["support"]
        if batched:
            X = torch.cat((x_q, x_s))
            feature_pyramid = self.encoder(X)
        else:
            feature_pyramid = self.encoder(x_q), self.encoder(x_s)
        return feature_pyramid

    def sample(
        self,
        dense_matches,
        dense_certainty,
        num=10000,
    ):
        """Draw up to ``num`` sparse matches, weighted by certainty.

        Args:
            dense_matches: Warp whose last dimension holds the 4 match
                coordinates; it is flattened to ``(-1, 4)``.
            dense_certainty: Matching certainty, flattened to ``(-1,)``.
            num: Maximum number of matches to return.

        Returns:
            ``(matches, certainty)`` for the sampled subset. In a
            ``balanced`` mode, four times as many candidates are drawn first
            and then re-sampled inversely to their estimated density.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            dense_certainty = dense_certainty.clone()
            # Everything above the threshold is sampled with equal weight.
            dense_certainty[dense_certainty > upper_thresh] = 1
        elif "pow" in self.sample_mode:
            dense_certainty = dense_certainty ** (1 / 3)
        elif "naive" in self.sample_mode:
            dense_certainty = torch.ones_like(dense_certainty)
        matches, certainty = (
            dense_matches.reshape(-1, 4),
            dense_certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty

        from ..utils.kde import kde

        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        # Basically should have at least 10 perfect neighbours, or around
        # 100 ok ones.
        p[density < 10] = 1e-7
        balanced_samples = torch.multinomial(
            p, num_samples=min(num, len(good_certainty)), replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False):
        """Match query to support.

        Args:
            batch: Dict with ``"query"`` and ``"support"`` image tensors and,
                for an upsampling pass, ``"corresps"`` holding the coarse
                ``dense_flow`` / ``dense_certainty`` to start from.
            batched: Encode both images in a single backbone call.
            upsample: Run the decoder in upsampling mode, initialised from
                ``batch["corresps"]`` when present. ``match`` passes this
                argument for its high-resolution pass.

        Returns:
            The decoder output; in training mode with
            ``use_contrastive_loss`` also the two feature pyramids.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        if batched:
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        if upsample:
            dense_corresps = self.decoder(
                f_q_pyramid,
                f_s_pyramid,
                upsample=True,
                **(batch["corresps"] if "corresps" in batch else {})
            )
        else:
            dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
        if self.training and self.use_contrastive_loss:
            return dense_corresps, (f_q_pyramid, f_s_pyramid)
        else:
            return dense_corresps

    def forward_symmetric(self, batch, upsample=False, batched=True):
        """Match in both directions with a single decoder call.

        The stacked ``[query; support]`` pyramid is decoded against its
        swapped counterpart ``[support; query]``, so the first half of the
        output batch is query-to-support and the second half
        support-to-query.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, upsample=upsample, batched=batched
        )
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
            for scale, f_scale in feature_pyramid.items()
        }
        dense_corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {})
        )
        return dense_corresps

    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
        """Convert normalised ``[-1, 1]`` matches to pixel coordinates.

        Args:
            matches: Tensor whose last dimension is ``(x_A, y_A, x_B, y_B)``.
            H_A, W_A: Height and width of image A in pixels.
            H_B, W_B: Height and width of image B in pixels.

        Returns:
            ``(kpts_A, kpts_B)``, each with a last dimension ``(x, y)``.
        """
        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
        kpts_A = torch.stack(
            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), dim=-1
        )
        kpts_B = torch.stack(
            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), dim=-1
        )
        return kpts_A, kpts_B

    def match(self, im1_path, im2_path, *args, batched=False, device=None):
        """Compute a dense warp and certainty between two images.

        Args:
            im1_path: Path to the first image, an already opened PIL image,
                or (with ``batched=True``) an image tensor ``(b, c, h, w)``.
            im2_path: Same for the second image.
            *args: Accepted and ignored.
            batched: Inputs are pre-processed tensor batches of equal size.
            device: Device for the inputs; defaults to CUDA when available.

        Returns:
            ``(warp, certainty)``. The last dimension of ``warp`` holds
            ``(x_A, y_A, x_B, y_B)`` in normalised coordinates. In symmetric
            mode the two directions are concatenated along the width. Without
            ``batched`` the leading batch dimension is removed.

        Note:
            The model is switched to eval mode as a side effect.
        """
        assert not (
            batched and self.upsample_preds
        ), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
        if isinstance(im1_path, (str, os.PathLike)):
            im1, im2 = Image.open(im1_path), Image.open(im2_path)
        else:  # assume it is a PIL Image
            im1, im2 = im1_path, im2_path
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        symmetric = self.symmetric
        self.train(False)
        with torch.no_grad():
            if not batched:
                b = 1
                # Get images in good format
                ws = self.w_resized
                hs = self.h_resized
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                batch = {
                    "query": query[None].to(device),
                    "support": support[None].to(device),
                }
            else:
                b, c, h, w = im1.shape
                b, c, h2, w2 = im2.shape
                assert w == w2 and h == h2, "For batched images we assume same size"
                batch = {"query": im1.to(device), "support": im2.to(device)}
                hs, ws = self.h_resized, self.w_resized
            finest_scale = 1
            # Run matcher
            if symmetric:
                dense_corresps = self.forward_symmetric(batch, batched=True)
            else:
                dense_corresps = self.forward(batch, batched=True)

            if self.upsample_preds:
                hs, ws = self.upsample_res
            # Keep only the negative part of the scale-16 certainty logits,
            # halved; it is subtracted from the fine logits below.
            low_res_certainty = F.interpolate(
                dense_corresps[16]["dense_certainty"],
                size=(hs, ws),
                align_corners=False,
                mode="bilinear",
            )
            cert_clamp = 0
            factor = 0.5
            low_res_certainty = (
                factor * low_res_certainty * (low_res_certainty < cert_clamp)
            )

            if self.upsample_preds:
                # Second pass at high resolution, initialised from the
                # finest warp of the first pass.
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                query, support = query[None].to(device), support[None].to(device)
                batch = {
                    "query": query,
                    "support": support,
                    "corresps": dense_corresps[finest_scale],
                }
                if symmetric:
                    dense_corresps = self.forward_symmetric(
                        batch, upsample=True, batched=True
                    )
                else:
                    dense_corresps = self.forward(batch, batched=True, upsample=True)

            query_to_support = dense_corresps[finest_scale]["dense_flow"]
            dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
            # Get certainty interpolation
            dense_certainty = dense_certainty - low_res_certainty
            query_to_support = query_to_support.permute(0, 2, 3, 1)
            # Create im1 meshgrid
            query_coords = torch.meshgrid(
                (
                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                )
            )
            query_coords = torch.stack((query_coords[1], query_coords[0]))
            query_coords = query_coords[None].expand(b, 2, hs, ws)
            dense_certainty = dense_certainty.sigmoid()  # logits -> probs
            query_coords = query_coords.permute(0, 2, 3, 1)
            if (query_to_support.abs() > 1).any():
                # Matches that land outside the other image get zero
                # certainty and are clamped to the border.
                wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
                dense_certainty[wrong[:, None]] = 0
                query_to_support = torch.clamp(query_to_support, -1, 1)
            if symmetric:
                support_coords = query_coords
                qts, stq = query_to_support.chunk(2)
                q_warp = torch.cat((query_coords, qts), dim=-1)
                s_warp = torch.cat((stq, support_coords), dim=-1)
                warp = torch.cat((q_warp, s_warp), dim=2)
                dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:, 0]
            else:
                warp = torch.cat((query_coords, query_to_support), dim=-1)
            if batched:
                return (warp, dense_certainty)
            else:
                return (
                    warp[0],
                    dense_certainty[0],
                )