wufan commited on
Commit
63ece4d
·
verified ·
1 Parent(s): 27eaed6

Delete visual_matcher.py

Browse files
Files changed (1) hide show
  1. visual_matcher.py +0 -209
visual_matcher.py DELETED
@@ -1,209 +0,0 @@
1
- import time
2
- import numpy as np
3
- from PIL import Image
4
- from scipy.spatial.distance import cdist
5
- from scipy.optimize import linear_sum_assignment
6
-
7
-
8
- class SimpleAffineTransform:
9
- """
10
- simple affine transform, only translation and scale.
11
- """
12
- def __init__(self, translation=(0, 0), scale=1.0):
13
- self.translation = np.array(translation)
14
- self.scale = scale
15
-
16
- def estimate(self, src, dst):
17
- src_center = np.mean(src, axis=0)
18
- dst_center = np.mean(dst, axis=0)
19
- self.translation = dst_center - src_center
20
-
21
- src_dists = np.linalg.norm(src - src_center, axis=1)
22
- dst_dists = np.linalg.norm(dst - dst_center, axis=1)
23
- self.scale = np.mean(dst_dists) / (np.mean(src_dists) + 1e-10)
24
-
25
- def inverse(self):
26
- inverse_transform = AffineTransform(-self.translation, 1.0/self.scale)
27
- return inverse_transform
28
-
29
- def __call__(self, coords):
30
- return self.scale * (coords - np.mean(coords, axis=0)) + np.mean(coords, axis=0) + self.translation
31
-
32
- def residuals(self, src, dst):
33
- return np.sqrt(np.sum((self(src) - dst) ** 2, axis=1))
34
-
35
-
36
- def norm_coords(x, left, right):
37
- if x < left:
38
- return left
39
- if x > right:
40
- return right
41
- return x
42
-
43
- def norm_same_token(token):
44
- special_map = {
45
- "\\dot": ".",
46
- "\\Dot": ".",
47
- "\\cdot": ".",
48
- "\\cdotp": ".",
49
- "\\ldotp": ".",
50
- "\\mid": "|",
51
- "\\rightarrow": "\\to",
52
- "\\top": "T",
53
- "\\Tilde": "\\tilde",
54
- "\\prime": "'",
55
- "\\ast": "*",
56
- "\\left<": "\\langle",
57
- "\\right>": "\\rangle",
58
- "\\lbrace": "\{",
59
- "\\rbrace": "\}",
60
- "\\lbrack": "[",
61
- "\\rbrack": "]",
62
- "\\blackslash": "/",
63
- "\\slash": "/",
64
- "\\leq": "\\le",
65
- "\\geq": "\\ge",
66
- "\\neq": "\\ne",
67
- "\\Vert": "\\|",
68
- "\\lVert": "\\|",
69
- "\\rVert": "\\|",
70
- "\\vert": "|",
71
- "\\lvert": "|",
72
- "\\rvert": "|",
73
- "\\colon": ":",
74
- "\\Ddot": "\\ddot",
75
- "\\Bar": "\\bar",
76
- "\\Vec": "\\vec",
77
- "\\parallel": "\\|",
78
- "\\dag": "\\dagger",
79
- "\\ddag": "\\ddagger",
80
- "\\textlangle": "<",
81
- "\\textrangle": ">",
82
- "\\textgreater": ">",
83
- "\\textless": "<",
84
- "\\textbackslash": "n",
85
- "\\textunderscore": "_",
86
- "\\=": "_",
87
- "\\neg": "\\lnot",
88
- "\\neq": "\\not=",
89
- }
90
- if token.startswith('\\left') or token.startswith('\\right'):
91
- if "arrow" not in token and "<" not in token and ">" not in token and "harpoon" not in token:
92
- token = token.replace("\\left", "").replace("\\right", "")
93
- if token.startswith('\\big') or token.startswith('\\Big'):
94
- if "\\" in token[4:]:
95
- token = "\\"+token[4:].split("\\")[-1]
96
- else:
97
- token = token[-1]
98
- if token in special_map.keys():
99
- token = special_map[token]
100
- if token.startswith('\\wide'):
101
- return token.replace("wide", "")
102
- if token.startswith('\\var'):
103
- return token.replace("var", "")
104
- if token.startswith('\\string'):
105
- return token.replace("\\string", "")
106
- return token
107
-
108
-
109
- class HungarianMatcher:
110
- def __init__(
111
- self,
112
- cost_token: float = 1,
113
- cost_position: float = 0.05,
114
- cost_order: float = 0.15,
115
- ):
116
- self.cost_token = cost_token
117
- self.cost_position = cost_position
118
- self.cost_order = cost_order
119
- self.cost = {}
120
-
121
- def calculate_token_cost(self, box_gt, box_pred):
122
- token2id = {}
123
- for data in box_gt+box_pred:
124
- if data['token'] not in token2id:
125
- token2id[data['token']] = len(token2id)
126
- num_classes = len(token2id)
127
-
128
- token2id_norm = {}
129
- for data in box_gt+box_pred:
130
- if norm_same_token(data['token']) not in token2id_norm:
131
- token2id_norm[norm_same_token(data['token'])] = len(token2id_norm)
132
- num_classes_norm = len(token2id_norm)
133
-
134
- gt_token_array = []
135
- norm_gt_token_array = []
136
- for data in box_gt:
137
- gt_token_array.append(token2id[data['token']])
138
- norm_gt_token_array.append(token2id_norm[norm_same_token(data['token'])])
139
-
140
- pred_token_logits = []
141
- norm_pred_token_logits = []
142
- for data in box_pred:
143
- logits = [0] * num_classes
144
- logits[token2id[data['token']]] = 1
145
- pred_token_logits.append(logits)
146
-
147
- logits_norm = [0] * num_classes_norm
148
- logits_norm[token2id_norm[norm_same_token(data['token'])]] = 1
149
- norm_pred_token_logits.append(logits_norm)
150
-
151
- gt_token_array = np.array(gt_token_array)
152
- pred_token_logits = np.array(pred_token_logits)
153
-
154
- norm_gt_token_array = np.array(norm_gt_token_array)
155
- norm_pred_token_logits = np.array(norm_pred_token_logits)
156
-
157
- token_cost = 1.0 - pred_token_logits[:, gt_token_array]
158
- norm_token_cost = 1.0 - norm_pred_token_logits[:, norm_gt_token_array]
159
-
160
- token_cost[np.logical_and(token_cost==1, norm_token_cost==0)] = 0.005
161
- return token_cost.T
162
-
163
-
164
- def box2array(self, box_list, size):
165
- W, H = size
166
- box_array = []
167
- for box in box_list:
168
- x_min, y_min, x_max, y_max = box['bbox']
169
- box_array.append([x_min/W, y_min/H, x_max/W, y_max/H])
170
- return np.array(box_array)
171
-
172
- def order2array(self, box_list, max_token_lens=None):
173
- if not max_token_lens:
174
- max_token_lens = len(box_list)
175
- order_array = []
176
- for idx, box in enumerate(box_list):
177
- order_array.append([idx / max_token_lens])
178
- return np.array(order_array)
179
-
180
- def calculate_l1_cost(self, gt_array, pred_array):
181
- scale = gt_array.shape[-1]
182
- l1_cost = cdist(gt_array, pred_array, 'minkowski', p=1)
183
- return l1_cost / scale
184
-
185
- def __call__(self, box_gt, box_pred, gt_size, pred_size):
186
- aa = time.time()
187
- gt_box_array = self.box2array(box_gt, gt_size)
188
- pred_box_array = self.box2array(box_pred, pred_size)
189
-
190
- max_token_lens = max(len(box_gt), len(box_pred))
191
- gt_order_array = self.order2array(box_gt, max_token_lens)
192
- pred_order_array = self.order2array(box_pred, max_token_lens)
193
-
194
- token_cost = self.calculate_token_cost(box_gt, box_pred)
195
- position_cost = self.calculate_l1_cost(gt_box_array, pred_box_array)
196
- order_cost = self.calculate_l1_cost(gt_order_array, pred_order_array)
197
-
198
- self.cost["token"] = token_cost
199
- self.cost["position"] = position_cost
200
- self.cost["order"] = order_cost
201
-
202
- cost = self.cost_token * token_cost + self.cost_position * position_cost + self.cost_order * order_cost
203
- cost[np.isnan(cost) | np.isinf(cost)] = 100
204
- indexes = linear_sum_assignment(cost)
205
- matched_idxes = []
206
- for a, b in zip(*indexes):
207
- matched_idxes.append((a, b))
208
-
209
- return matched_idxes