Spaces:
Running
Running
""" | |
"LiftFeat: 3D Geometry-Aware Local Feature Matching" | |
This module defines InterpolateSparse2d, which interpolates LiftFeat's rough descriptors at sparse 2D positions.
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class InterpolateSparse2d(nn.Module):
    """Efficiently interpolate a tensor at given sparse 2D positions.

    Args:
        mode: interpolation mode forwarded to ``F.grid_sample``
            (defaults to ``'bicubic'``).
        align_corners: ``align_corners`` flag forwarded to
            ``F.grid_sample`` (defaults to ``False``).
    """

    def __init__(self, mode='bicubic', align_corners=False):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners

    def normgrid(self, x, H, W):
        """Normalize coords to [-1, 1].

        The last dimension of ``x`` holds (x, y) pairs: x is divided by
        ``W - 1`` and y by ``H - 1``, so 0 maps to -1 and ``W - 1`` /
        ``H - 1`` map to +1. The result has the same shape as ``x``.
        """
        # NOTE(review): dividing by (W - 1, H - 1) matches grid_sample's
        # align_corners=True convention; with align_corners=False (the
        # default here) -1/+1 denote the outer pixel edges, so samples are
        # slightly shifted. Left unchanged to preserve existing default
        # behaviour -- confirm against how the model was trained.
        extent = torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype)
        return 2. * (x / extent) - 1.

    def forward(self, x, pos, H, W):
        """Sample ``x`` at the sparse positions ``pos``.

        Input
            x: [B, C, H, W] feature tensor
            pos: [B, N, 2] tensor of positions
            H, W: int, original resolution of input 2d positions -- used in
                normalization [-1,1]

        Returns
            [B, N, C] sampled channels at 2d positions
        """
        # [B, N, 2] -> [B, N, 1, 2]: grid_sample expects a 4-D grid, so the
        # N positions are laid out as an N x 1 "image" of sample locations.
        grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
        # Honour the align_corners value given to the constructor (it was
        # previously stored but ignored in favour of a hard-coded False).
        sampled = F.grid_sample(
            x, grid, mode=self.mode, align_corners=self.align_corners
        )
        # [B, C, N, 1] -> [B, N, 1, C] -> [B, N, C]
        return sampled.permute(0, 2, 3, 1).squeeze(-2)