from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import kornia

class RDD_helper(nn.Module):
    def __init__(self, RDD):
        super().__init__()
        self.matcher = DualSoftmaxMatcher(inv_temperature=20, thr=0.01)
        self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01)
        self.RDD = RDD
        self.lg_matcher = None  # LightGlue matcher is built lazily on first use

    def match(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        # Re-configure the detector if the caller asks for a different keypoint budget.
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize)
        img1, scale1 = self.parse_input(img1, resize)

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # Dual-softmax matching over the sparse descriptors, keeping matches above thr.
        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original (pre-resize) image coordinates.
        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
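
    # Usage sketch (hypothetical; assumes `rdd` is an already-constructed RDD
    # model and img0/img1 are HxWx3 RGB numpy images):
    #   helper = RDD_helper(rdd)
    #   mkpts0, mkpts1, conf = helper.match(img0, img1, thr=0.01, resize=1024)
    #   # mkpts0/mkpts1 are Nx2 arrays in original-image pixel coordinates.
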
    def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        # Lazily construct the LightGlue matcher the first time it is needed.
        if self.lg_matcher is None:
            lg_conf = {
                "name": "lightglue",  # just for interfacing
                "input_dim": 256,  # input descriptor dimension (autoselected from weights)
                "descriptor_dim": 256,
                "add_scale_ori": False,
                "n_layers": 9,
                "num_heads": 4,
                "flash": True,  # enable FlashAttention if available
                "mp": False,  # enable mixed precision
                "filter_threshold": 0.01,  # match threshold
                "depth_confidence": -1,  # depth confidence threshold
                "width_confidence": -1,  # width confidence threshold
                "weights": './weights/RDD_lg-v2.pth',  # path to the weights
            }
            self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device)

        # Re-configure the detector if the caller asks for a different keypoint budget.
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)
        size0 = torch.tensor(img0.shape[-2:])[None]
        size1 = torch.tensor(img1.shape[-2:])[None]

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # LightGlue expects batched keypoints/descriptors plus the image sizes.
        image0_data = {
            'keypoints': out0['keypoints'][None],
            'descriptors': out0['descriptors'][None],
            'image_size': size0,
        }
        image1_data = {
            'keypoints': out1['keypoints'][None],
            'descriptors': out1['descriptors'][None],
            'image_size': size1,
        }

        pred = {}
        with torch.no_grad():
            pred.update({'image0': image0_data, 'image1': image1_data})
            pred.update(self.lg_matcher({**pred}))

        # Gather the matched keypoint pairs and filter by confidence.
        kpts0 = pred['image0']['keypoints'][0]
        kpts1 = pred['image1']['keypoints'][0]
        matches = pred['matches'][0]
        mkpts0 = kpts0[matches[..., 0]]
        mkpts1 = kpts1[matches[..., 1]]
        conf = pred['scores'][0]

        valid_mask = conf > thr
        mkpts0 = mkpts0[valid_mask]
        mkpts1 = mkpts1[valid_mask]
        conf = conf[valid_mask]

        # Map keypoints back to the original (pre-resize) image coordinates.
        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
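
    # Usage sketch (hypothetical; requires the LightGlue weights referenced in
    # lg_conf, './weights/RDD_lg-v2.pth', relative to the working directory):
    #   helper = RDD_helper(rdd)
    #   mkpts0, mkpts1, conf = helper.match_lg(img0, img1, thr=0.01, resize=1024)
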
    def match_dense(self, img0, img1, thr=0.01, resize=None):
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_dense(img0)[0]
        out1 = self.RDD.extract_dense(img1)[0]

        # Dense matching; the error threshold is tied to the model's feature stride.
        mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride)

        # Map keypoints back to the original (pre-resize) image coordinates.
        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
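
    # Usage sketch (hypothetical): dense matching takes no top_k argument, since
    # it matches the full dense feature maps rather than a sparse keypoint set.
    #   mkpts0, mkpts1, conf = helper.match_dense(img0, img1, thr=0.01, resize=1024)
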
    def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01):
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        # Extract with a third-party feature model (e.g. ALIKED), then match as usual.
        out0 = self.RDD.extract_3rd_party(img0, model=model)[0]
        out1 = self.RDD.extract_3rd_party(img1, model=model)[0]

        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original (pre-resize) image coordinates.
        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
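
    # Usage sketch (hypothetical; the set of supported `model` names is defined
    # by RDD.extract_3rd_party, 'aliked' being the default shown above):
    #   mkpts0, mkpts1, conf = helper.match_3rd_party(img0, img1, model='aliked')
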
    def parse_input(self, x, resize=None):
        # Accept an HxWxC numpy image or a CxHxW tensor; add a batch dimension.
        if len(x.shape) == 3:
            x = x[None, ...]
        if isinstance(x, np.ndarray):
            # NHWC uint8 -> NCHW float in [0, 1]
            x = torch.tensor(x).permute(0, 3, 1, 2) / 255

        h, w = x.shape[-2:]
        if resize is not None:
            size = self.get_new_image_size(h, w, resize)
            x = kornia.geometry.transform.resize(
                x,
                size,
                side='long',
                antialias=True,
                align_corners=None,
                interpolation='bilinear',
            )

        # Per-axis (x, y) factors mapping original coordinates to resized ones.
        scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device)
        return x, scale
    def get_new_image_size(self, h, w, resize=1600):
        # The new width is pinned to `resize`; the height follows from the aspect ratio.
        aspect_ratio = w / h
        size = int(resize / aspect_ratio), resize
        size = list(map(lambda x: int(x // 32 * 32), size))  # snap both sides down to a multiple of 32
        return size
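

# Minimal end-to-end sketch (illustrative, not part of the original API). The
# model construction is hypothetical: build the RDD model however the
# surrounding repo does, and pass HxWx3 RGB numpy images:
#
#   helper = RDD_helper(rdd_model)
#   mkpts0, mkpts1, conf = helper.match(img0, img1, thr=0.01, resize=1024)
#
# The block below is a self-contained sanity check of the resize arithmetic
# alone; it needs no model (run via `python -m <package>.<this_module>` because
# of the relative import at the top of the file).
if __name__ == "__main__":
    helper = RDD_helper.__new__(RDD_helper)  # bypass __init__; only the size math is exercised
    # 480x640 with resize=1024: aspect ratio 4/3, so (1024 / (4/3), 1024) = (768, 1024),
    # and both sides are already multiples of 32.
    assert helper.get_new_image_size(480, 640, resize=1024) == [768, 1024]
    print('get_new_image_size OK')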