import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia

from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue

class RDD_helper(nn.Module):
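    """Inference wrapper around an RDD model that bundles feature extraction
    with sparse (dual-softmax), dense, and LightGlue matching."""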
    def __init__(self, RDD):
        super().__init__()
        self.matcher = DualSoftmaxMatcher(inv_temperature=20, thr=0.01)
        self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01)
        self.RDD = RDD
        self.lg_matcher = None  # LightGlue matcher is built lazily in match_lg

    @torch.inference_mode()
    def match(self, img0, img1, thr=0.01, resize=None, top_k=4096):
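        """Sparse matching with the dual-softmax matcher.

        Returns (mkpts0, mkpts1, conf) as numpy arrays, with keypoints
        mapped back to the original image coordinates.
        """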
        # Re-configure the detector only when the requested top_k changes.
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize)
        img1, scale1 = self.parse_input(img1, resize)

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # Keep the mutually most confident matches above the threshold.
        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original (pre-resize) resolution.
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096):
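        """Sparse matching with a LightGlue head on RDD features.

        The LightGlue matcher is instantiated on the first call and cached
        for subsequent calls.
        """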
        if self.lg_matcher is None:
            lg_conf = {
                "name": "lightglue",        # just for interfacing
                "input_dim": 256,           # input descriptor dimension (autoselected from weights)
                "descriptor_dim": 256,
                "add_scale_ori": False,
                "n_layers": 9,
                "num_heads": 4,
                "flash": True,              # enable FlashAttention if available
                "mp": False,                # enable mixed precision
                "filter_threshold": 0.01,   # match threshold
                "depth_confidence": -1,     # depth confidence threshold
                "width_confidence": -1,     # width confidence threshold
                "weights": './weights/RDD_lg-v2.pth',  # path to the weights
            }
            self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device)

        # Re-configure the detector only when the requested top_k changes.
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        size0 = torch.tensor(img0.shape[-2:])[None]
        size1 = torch.tensor(img1.shape[-2:])[None]

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # Batch the per-image outputs the way LightGlue expects them.
        image0_data = {
            'keypoints': out0['keypoints'][None],
            'descriptors': out0['descriptors'][None],
            'image_size': size0,
        }
        image1_data = {
            'keypoints': out1['keypoints'][None],
            'descriptors': out1['descriptors'][None],
            'image_size': size1,
        }

        # torch.inference_mode() already disables gradient tracking, so no
        # extra torch.no_grad() context is needed here.
        pred = {'image0': image0_data, 'image1': image1_data}
        pred.update(self.lg_matcher({**pred}))

        kpts0 = pred['image0']['keypoints'][0]
        kpts1 = pred['image1']['keypoints'][0]
        matches = pred['matches'][0]
        mkpts0 = kpts0[matches[..., 0]]
        mkpts1 = kpts1[matches[..., 1]]
        conf = pred['scores'][0]

        # Drop matches below the confidence threshold.
        valid_mask = conf > thr
        mkpts0 = mkpts0[valid_mask]
        mkpts1 = mkpts1[valid_mask]
        conf = conf[valid_mask]

        # Map keypoints back to the original (pre-resize) resolution.
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_dense(self, img0, img1, thr=0.01, resize=None):
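        """Dense matching over the outputs of RDD.extract_dense; the
        matcher's err_thr is tied to the model stride."""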
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_dense(img0)[0]
        out1 = self.RDD.extract_dense(img1)[0]

        # Keep dense matches above the confidence threshold.
        mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride)

        # Map keypoints back to the original (pre-resize) resolution.
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01):
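        """Match features produced by RDD.extract_3rd_party (default
        'aliked') with the dual-softmax matcher."""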
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_3rd_party(img0, model=model)[0]
        out1 = self.RDD.extract_3rd_party(img1, model=model)[0]

        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original (pre-resize) resolution.
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    def parse_input(self, x, resize=None):
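        """Convert an image to a float BCHW tensor in [0, 1], optionally
        resize it, and return the tensor together with the (x, y) scale
        factors from original to network resolution."""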
        # Accept both HWC numpy arrays and CHW tensors; add a batch dim.
        if len(x.shape) == 3:
            x = x[None, ...]

        # Convert HWC uint8 numpy images to float BCHW tensors in [0, 1].
        if isinstance(x, np.ndarray):
            x = torch.tensor(x).permute(0, 3, 1, 2) / 255

        h, w = x.shape[-2:]
        size = h, w

        if resize is not None:
            size = self.get_new_image_size(h, w, resize)
            x = kornia.geometry.transform.resize(
                x,
                size,
                side='long',
                antialias=True,
                align_corners=None,
                interpolation='bilinear',
            )

        # Scale factors (x, y) from the original to the resized image.
        scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device)
        return x, scale

    def get_new_image_size(self, h, w, resize=1600):
        # Scale the width to `resize`, keep the aspect ratio, and snap both
        # sides down to a multiple of 32.
        aspect_ratio = w / h
        size = int(resize / aspect_ratio), resize
        size = list(map(lambda x: int(x // 32 * 32), size))
        return size
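
# A minimal usage sketch (hypothetical: assumes an RDD model instance `rdd`
# built elsewhere in this repo, and OpenCV for image loading):
#
#   import cv2
#   helper = RDD_helper(rdd)
#   img0 = cv2.imread('image0.jpg')  # HxWx3 uint8 array
#   img1 = cv2.imread('image1.jpg')
#   mkpts0, mkpts1, conf = helper.match(img0, img1, thr=0.01, resize=1024)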