import numpy as np
import torch
import torch.nn as nn
import kornia

from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue

class RDD_helper(nn.Module):
    """Convenience wrapper around an RDD model that pairs feature
    extraction with sparse, dense, and LightGlue-based matching."""

    def __init__(self, RDD):
        super().__init__()
        self.matcher = DualSoftmaxMatcher(inv_temperature=20, thr=0.01)
        self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01)
        self.RDD = RDD
        self.lg_matcher = None  # built lazily on the first match_lg() call
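
    # For reference, dual-softmax matching (a sketch of the standard
    # formulation, not necessarily this repo's exact implementation):
    # given L2-normalized descriptors D0 (N x C) and D1 (M x C),
    #   S = D0 @ D1.T                              # similarity matrix
    #   P = (S * inv_temperature).softmax(dim=1) \
    #     * (S * inv_temperature).softmax(dim=0)   # dual-softmax scores
    # and a pair (i, j) is kept if it is a mutual nearest neighbor with
    # P[i, j] above `thr`.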
        
    @torch.inference_mode()
    def match(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        # Re-configure the detector if the requested number of keypoints changed
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize)
        img1, scale1 = self.parse_input(img1, resize)

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # Keep the confident matches via dual-softmax mutual matching
        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original image resolution
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
    
    @torch.inference_mode()
    def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        # Build the LightGlue matcher lazily on first use
        if self.lg_matcher is None:
            lg_conf = {
                "name": "lightglue",  # just for interfacing
                "input_dim": 256,  # input descriptor dimension (autoselected from weights)
                "descriptor_dim": 256,
                "add_scale_ori": False,
                "n_layers": 9,
                "num_heads": 4,
                "flash": True,  # enable FlashAttention if available
                "mp": False,  # enable mixed precision
                "filter_threshold": 0.01,  # match threshold
                "depth_confidence": -1,  # depth confidence threshold
                "width_confidence": -1,  # width confidence threshold
                "weights": './weights/RDD_lg-v2.pth',  # path to the weights
            }
            self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device)

        # Re-configure the detector if the requested number of keypoints changed
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        size0 = torch.tensor(img0.shape[-2:])[None]
        size1 = torch.tensor(img1.shape[-2:])[None]

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # Assemble LightGlue inputs (keypoints/descriptors gain a batch dimension)
        image0_data = {
            'keypoints': out0['keypoints'][None],
            'descriptors': out0['descriptors'][None],
            'image_size': size0,
        }

        image1_data = {
            'keypoints': out1['keypoints'][None],
            'descriptors': out1['descriptors'][None],
            'image_size': size1,
        }

        # inference_mode() already disables autograd, so no extra no_grad() is needed
        pred = {'image0': image0_data, 'image1': image1_data}
        pred.update(self.lg_matcher({**pred}))

        kpts0 = pred['image0']['keypoints'][0]
        kpts1 = pred['image1']['keypoints'][0]

        # 'matches' holds index pairs into the two keypoint sets
        matches = pred['matches'][0]
        mkpts0 = kpts0[matches[..., 0]]
        mkpts1 = kpts1[matches[..., 1]]
        conf = pred['scores'][0]

        # Drop matches below the confidence threshold
        valid_mask = conf > thr
        mkpts0 = mkpts0[valid_mask]
        mkpts1 = mkpts1[valid_mask]
        conf = conf[valid_mask]

        # Map keypoints back to the original image resolution
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
    
    @torch.inference_mode()
    def match_dense(self, img0, img1, thr=0.01, resize=None):
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_dense(img0)[0]
        out1 = self.RDD.extract_dense(img1)[0]

        # Dense matching; err_thr is set to the model's feature stride
        mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride)

        # Map keypoints back to the original image resolution
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
        
    @torch.inference_mode()
    def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01):
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        # Extract features with a third-party detector (e.g. ALIKED)
        out0 = self.RDD.extract_3rd_party(img0, model=model)[0]
        out1 = self.RDD.extract_3rd_party(img1, model=model)[0]

        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        # Map keypoints back to the original image resolution
        mkpts0 = mkpts0 / scale0
        mkpts1 = mkpts1 / scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
    
    def parse_input(self, x, resize=None):
        # Accepts an HxWxC uint8 numpy image or a CxHxW (or BxCxHxW) tensor
        if len(x.shape) == 3:
            x = x[None, ...]

        if isinstance(x, np.ndarray):
            x = torch.tensor(x).permute(0, 3, 1, 2) / 255

        h, w = x.shape[-2:]
        size = h, w

        if resize is not None:
            size = self.get_new_image_size(h, w, resize)
            x = kornia.geometry.transform.resize(
                x,
                size,
                side='long',
                antialias=True,
                align_corners=None,
                interpolation='bilinear',
            )

        # Per-axis scale from original to resized coordinates (x, then y)
        scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device)

        return x, scale
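
    # Worked example (a sketch, assuming a 512x640 RGB uint8 input and
    # resize=1600): parse_input() returns a 1x3x1280x1600 float tensor in
    # [0, 1] together with scale = (1600/640, 1280/512) = (2.5, 2.5);
    # dividing matched keypoints by this scale maps them back to the
    # original 640x512 pixel coordinates.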
    
    def get_new_image_size(self, h, w, resize=1600):
        # Scale the longer side to `resize`, preserving the aspect ratio
        if w >= h:
            size = int(resize * h / w), resize
        else:
            size = resize, int(resize * w / h)

        size = list(map(lambda d: int(d // 32 * 32), size))  # snap both sides down to multiples of 32

        return size
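
# Example usage (a minimal sketch; the builder import and weight path below
# are assumptions about the surrounding repo, not defined in this file):
#
#   import cv2
#   from RDD.RDD import build as build_rdd   # hypothetical builder
#
#   rdd = build_rdd(weights='./weights/RDD-v2.pth')  # hypothetical weights path
#   helper = RDD_helper(rdd)
#
#   img0 = cv2.imread('im0.png')[..., ::-1]  # BGR -> RGB, HxWx3 uint8
#   img1 = cv2.imread('im1.png')[..., ::-1]
#   mkpts0, mkpts1, conf = helper.match(img0, img1, thr=0.01, resize=1024)
#   # mkpts0/mkpts1: Nx2 arrays of matched pixel coordinates in the original
#   # image resolutions; conf: N per-match confidences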