wufan committed (verified)
Commit ad8cb30 · Parent: dab2a52

Upload 3 files


update CDM:
1. support Chinese formulas
2. improve processing speed
3. update match and check code
4. fix '\n' bug

Files changed (3)
  1. latex2bbox_color.py +215 -0
  2. latex_processor.py +536 -0
  3. visual_matcher.py +191 -0
latex2bbox_color.py ADDED
@@ -0,0 +1,215 @@
+ import os
+ import re
+ import json
+ import shutil
+ import logging
+ import subprocess
+ import numpy as np
+
+ from threading import Timer
+ from PIL import Image, ImageDraw
+ from modules.latex_processor import (
+     normalize_latex,
+     token_add_color_RGB,
+     clean_latex
+ )
+ from modules.tokenize_latex.tokenize_latex import tokenize_latex
+
+
+ tabular_template = r"""
+ \documentclass[12pt]{article}
+ \usepackage[landscape]{geometry}
+ \usepackage{geometry}
+ \geometry{a<PaperSize>paper,scale=0.98}
+ \pagestyle{empty}
+ \usepackage{booktabs}
+ \usepackage{multirow}
+ \usepackage{amssymb}
+ \usepackage{upgreek}
+ \usepackage{amsmath}
+ \usepackage{xcolor}
+ \begin{document}
+ \makeatletter
+ \renewcommand*{\@textcolor}[3]{%%
+ \protect\leavevmode
+ \begingroup
+ \color#1{#2}#3%%
+ \endgroup
+ }
+ \makeatother
+ \begin{displaymath}
+ %s
+ \end{displaymath}
+ \end{document}
+ """
+
+ formular_template = r"""
+ \documentclass[12pt]{article}
+ \usepackage[landscape]{geometry}
+ \usepackage{geometry}
+ \geometry{a<PaperSize>paper,scale=0.98}
+ \pagestyle{empty}
+ \usepackage{booktabs}
+ \usepackage{amsmath}
+ \usepackage{upgreek}
+ \usepackage{amssymb}
+ \usepackage{xcolor}
+ \begin{document}
+ \makeatletter
+ \renewcommand*{\@textcolor}[3]{%%
+ \protect\leavevmode
+ \begingroup
+ \color#1{#2}#3%%
+ \endgroup
+ }
+ \makeatother
+ \begin{displaymath}
+ %s
+ \end{displaymath}
+ \end{document}
+ """
+
+
+ def run_cmd(cmd, timeout_sec=30):
+     # run a shell command, killing it if it runs longer than timeout_sec
+     proc = subprocess.Popen(cmd, shell=True)
+     kill_proc = lambda p: p.kill()
+     timer = Timer(timeout_sec, kill_proc, [proc])
+     try:
+         timer.start()
+         stdout, stderr = proc.communicate()
+     finally:
+         timer.cancel()
+
+ def convert_pdf2img(pdf_filename, png_filename):
+     # rasterize the pdf with ImageMagick
+     cmd = "magick -density 200 -quality 100 %s %s" % (pdf_filename, png_filename)
+     os.system(cmd)
+
+ def crop_image(image_path, pad=8):
+     # crop the rendered page to its non-white content, plus a small margin
+     img = Image.open(image_path).convert("L")
+     img_data = np.asarray(img, dtype=np.uint8)
+     nnz_inds = np.where(img_data != 255)
+     if len(nnz_inds[0]) == 0:
+         y_min = 0
+         y_max = 10
+         x_min = 0
+         x_max = 10
+     else:
+         y_min = np.min(nnz_inds[0])
+         y_max = np.max(nnz_inds[0])
+         x_min = np.min(nnz_inds[1])
+         x_max = np.max(nnz_inds[1])
+
+     img = Image.open(image_path).convert("RGB").crop((x_min - pad, y_min - pad, x_max + pad, y_max + pad))
+     img.save(image_path)
+
+ def extract_bbox_from_color_image(image_path, color_list):
+     # recover one bbox per color by locating every pixel with that exact RGB value
+     img = Image.open(image_path).convert("RGB")
+     W, H = img.size
+     pixels = list(img.getdata())
+
+     bbox_list = []
+     for target_color in color_list:
+         target_pixels = [i for i, pixel in enumerate(pixels) if pixel == target_color]
+         x_list = []
+         y_list = []
+         for idx in target_pixels:
+             x_list.append(idx % W)
+             y_list.append(idx // W)
+         try:
+             y_min, y_max, x_min, x_max = min(y_list), max(y_list), min(x_list), max(x_list)
+             bbox_list.append([x_min - 1, y_min - 1, x_max + 1, y_max + 1])
+         except ValueError:
+             # this color was never rendered, so it has no bbox
+             bbox_list.append([])
+             continue
+
+     # binarize the image so the colored tokens become black again
+     img = img.convert("L")
+     img_bw = img.point(lambda x: 255 if x == 255 else 0, '1')
+     img_bw.convert("RGB").save(image_path)
+     return bbox_list
+
+
+ def latex2bbox_color(input_arg):
+     latex, basename, output_path, temp_dir, total_color_list = input_arg
+     template = tabular_template if "tabular" in latex else formular_template
+     output_bbox_path = os.path.join(output_path, 'bbox', basename + '.jsonl')
+     output_vis_path = os.path.join(output_path, 'vis', basename + '.png')
+     output_base_path = os.path.join(output_path, 'vis', basename + '_base.png')
+
+     if os.path.exists(output_bbox_path) and os.path.exists(output_vis_path) and os.path.exists(output_base_path):
+         return
+
+     try:
+         ret, new_latex = tokenize_latex(latex, middle_file=os.path.join(temp_dir, basename + '.txt'))
+         if not (ret and new_latex):
+             logging.info(f"ERROR, Tokenize latex failed: {basename}.")
+             new_latex = latex
+         latex = normalize_latex(new_latex)
+         token_list = []
+         l_split = latex.strip().split(' ')
+         color_list = total_color_list[0:len(l_split)]
+         idx = 0
+         while idx < len(l_split):
+             l_split, idx, token_list = token_add_color_RGB(l_split, idx, token_list)
+
+         rgb_latex = " ".join(l_split)
+         for idx, color in enumerate(color_list):
+             R, G, B = color
+             rgb_latex = rgb_latex.replace(f"<color_{idx}>", f"{R},{G},{B}")
+
+         # longer formulas need a larger paper size so the rendering fits on one page
+         if len(token_list) > 1300:
+             paper_size = 3
+         elif len(token_list) > 600:
+             paper_size = 4
+         else:
+             paper_size = 5
+         final_latex = template.replace("<PaperSize>", str(paper_size)) % rgb_latex
+
+     except Exception as e:
+         logging.info(f"ERROR, Preprocess latex failed: {basename}; {e}.")
+         return
+
+     pre_name = output_path.replace('/', '_').replace('.', '_') + '_' + basename
+     tex_filename = os.path.join(temp_dir, pre_name + '.tex')
+     log_filename = os.path.join(temp_dir, pre_name + '.log')
+     aux_filename = os.path.join(temp_dir, pre_name + '.aux')
+
+     with open(tex_filename, "w") as w:
+         print(final_latex, file=w)
+     run_cmd(f"pdflatex -interaction=nonstopmode -output-directory={temp_dir} {tex_filename} >/dev/null")
+     try:
+         os.remove(tex_filename)
+         os.remove(log_filename)
+         os.remove(aux_filename)
+     except OSError:
+         pass
+     pdf_filename = tex_filename[:-4] + '.pdf'
+     if not os.path.exists(pdf_filename):
+         logging.info(f"ERROR, Compile pdf failed: {pdf_filename}")
+         return
+     convert_pdf2img(pdf_filename, output_base_path)
+     os.remove(pdf_filename)
+
+     crop_image(output_base_path)
+     bbox_list = extract_bbox_from_color_image(output_base_path, color_list)
+     vis = Image.open(output_base_path)
+     draw = ImageDraw.Draw(vis)
+
+     with open(output_bbox_path, 'w') as f:
+         for token, box in zip(token_list, bbox_list):
+             item = {
+                 "bbox": box,
+                 "token": token
+             }
+             f.write(json.dumps(item) + '\n')
+
+             if not box:
+                 continue
+             x_min, y_min, x_max, y_max = box
+             draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=(0, 250, 0), width=1)
+             draw.text((x_min, y_min), token, (250, 0, 0))
+
+     vis.save(output_vis_path)
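
latex2bbox_color consumes a 5-tuple of (latex, basename, output_path, temp_dir, total_color_list) and writes one bbox jsonl plus a base image and a visualization per formula. A minimal driver sketch, assuming the file is importable as modules.latex2bbox_color and that pdflatex and ImageMagick are on PATH; the paths, formula, and toy palette below are illustrative (each token needs its own distinct, non-white RGB value):

import os
from multiprocessing import Pool
from modules.latex2bbox_color import latex2bbox_color

output_path, temp_dir = "output", "temp"  # hypothetical locations
for sub in ["bbox", "vis"]:
    os.makedirs(os.path.join(output_path, sub), exist_ok=True)
os.makedirs(temp_dir, exist_ok=True)

# toy palette: 125 distinct non-white colors, enough for short formulas
total_color_list = [(r, g, b)
                    for r in range(1, 255, 51)
                    for g in range(1, 255, 51)
                    for b in range(1, 255, 51)]

formulas = {"sample_0": r"\frac { a } { b } + \hat x"}
args = [(latex, name, output_path, temp_dir, total_color_list)
        for name, latex in formulas.items()]
with Pool(4) as pool:
    pool.map(latex2bbox_color, args)
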
latex_processor.py ADDED
@@ -0,0 +1,536 @@
+ import os
+ import re
+ import json
+ import shutil
+ import logging
+ import numpy as np
+ from PIL import Image
+
+
+ SKIP_PATTERNS = [r'\{', r'\}', r'[\[\]]', r'\\begin\{.*?\}', r'\\end\{.*?\}', r'\^', r'\_', r'\\.*rule.*', r'\\.*line.*', r'\[[\-.0-9]+[epm][xtm]\]']
+ SKIP_Tokens = ['\\', '\\\\', '\\index', '\\a', '&', '$', '\\multirow', '\\def', '\\raggedright', '\\url', '\\cr', '\\ensuremath', '\\left', '\\right',
+                '\\mathchoice', '\\scriptstyle', '\\displaystyle', '\\qquad', '\\quad', '\\,', '\\!', '~', '\\boldmath']
+ PHANTOM_Tokens = ['\\fontfamily', '\\vphantom', '\\phantom', '\\rowcolor', '\\ref']
+ TWO_Tail_Tokens = ['\\frac', '\\binom']
+ AB_Tail_Tokens = ['\\xrightarrow', '\\xleftarrow', '\\sqrt']  # special tokens of the form \xxx [] {}
+ TWO_Tail_Invisb_Tokens = ['\\overset', '\\underset', '\\stackrel']
+ ONE_Tail_Tokens = ['\\widetilde', '\\overline', '\\hat', '\\widehat', '\\tilde', '\\Tilde', '\\dot', '\\bar', '\\vec', '\\underline', '\\underbrace', '\\check',
+                    '\\breve', '\\Bar', '\\Vec', '\\mathring', '\\ddot']
+ ONE_Tail_Invisb_Tokens = ['\\boldsymbol', '\\pmb', '\\textbf', '\\mathrm', '\\mathbf', '\\mathbb', '\\mathcal', '\\textmd', '\\texttt', '\\textnormal',
+                           '\\text', '\\textit', '\\textup', '\\mathop', '\\mathbin', '\\smash', '\\operatorname', '\\textrm', '\\mathfrak', '\\emph',
+                           '\\textsf', '\\textsc']
+
+
+ def flatten_multiline(latex):
+     # flatten a multi-line formula (an outer array environment) into a single line
+     brace_map = {
+         "\\left(": "\\right)",
+         "\\left[": "\\right]",
+         "\\left{": "\\right}",
+     }
+     l_split = latex.split(' ')
+     if l_split[0] == "\\begin{array}":
+         if l_split[-1] == "\\end{array}":
+             l_split = l_split[2:-1]
+         else:
+             l_split = l_split[2:]
+
+     idx = 0
+     while idx < len(l_split):
+         token = l_split[idx]
+         if token.startswith("\\left") and token in brace_map.keys():
+             end_idx = find_matching_brace(l_split, idx, brace=[token, brace_map[token]])
+             if end_idx != -1:
+                 idx = end_idx
+         elif token in ["\\\\", "~", "\\qquad"]:
+             l_split = l_split[0:idx] + l_split[idx+1:]
+             idx -= 1
+         idx += 1
+     latex = ' '.join(l_split)
+     return "$ " + latex + " $"
+
+
+ def clean_latex(text):
+     # remove the spaces between non-backslash characters (GPT-assisted regex;
+     # passed initial tests, but not guaranteed to be completely bug-free)
+     cleaned_text = re.sub(r'(?<=[^\\])\s+(?=[^\\])', '', text)
+     # restore the spaces that must not be removed
+     for item in ["\\hline", "\\midrule", "\\times", "\\bf", "\\footnotesize", "\\cr", '\\log']:
+         cleaned_text = cleaned_text.replace(item, item + " ")
+     cleaned_text = cleaned_text.replace(" \\mathcolor{black}", "\\mathcolor{black}")
+     return cleaned_text
+
+ def remove_trailing_latex(formula):
+     pattern = r'(\\(hspace\*?\{[^{}]*?\}|vspace\*?\{[^{}]*?\}|smallskip|medskip|quad|qquad|bigskip|[;,])|\~|\.)*$'
+     # replace the matched trailing spacing commands with an empty string
+     cleaned_formula = re.sub(pattern, '', formula, count=1)
+     return cleaned_formula
+
+ def find_matching_brace(sequence, start_index, brace=['{', '}']):
+     # find the index of the brace matching the one at start_index
+     left_brace, right_brace = brace
+     depth = 0
+     for i, char in enumerate(sequence[start_index:], start=start_index):
+         if char == left_brace:
+             depth += 1
+         elif char == right_brace:
+             depth -= 1
+             if depth == 0:
+                 return i
+     if depth > 0:
+         raise ValueError("found no matching brace in sequence")
+     return -1
+
+ def normalize_latex(l, rm_trail=False):
+     if "tabular" in l:
+         latex_type = "tabular"
+     else:
+         latex_type = "formula"
+
+     if rm_trail:
+         l = remove_trailing_latex(l)
+     l = l.strip().replace(r'\pmatrix', r'\mypmatrix').replace(r'\matrix', r'\mymatrix')
+
+     # \raggedright and \arraybackslash control alignment and are difficult to handle, so remove them
+     for item in ['\\raggedright', '\\arraybackslash']:
+         l = l.replace(item, "")
+
+     for item in ['\\lowercase', '\\uppercase']:
+         l = l.replace(item, "")
+
+     # \hspace { 1 . 5 cm }: in a formula, normalize it to \hspace{1.5cm}; in a table, remove it
+     pattern = r'\\[hv]space { [.0-9a-z ]+ }'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     if latex_type == "tabular":
+         new_token = ["" for item in old_token]
+     else:
+         new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # treat \begin{tabular} {...} as one token
+     # TODO: \begin{array} occurs in tables too, so this should run for both formulas and tables
+     if latex_type == "tabular":
+         l = l.replace("\\begin {tabular}", "\\begin{tabular}")
+         l = l.replace("\\end {tabular}", "\\end{tabular}")
+         l = l.replace("\\begin {array}", "\\begin{array}")
+         l = l.replace("\\end {array}", "\\end{array}")
+         l_split = l.split(' ')
+         idx = 0
+         while idx < len(l_split):
+             token = l_split[idx]
+             if token == "\\begin{tabular}":
+                 sub_idx = idx + 1
+                 end_idx = find_matching_brace(l_split, sub_idx)
+                 new_token = "".join(l_split[idx: end_idx+1])
+                 l_split = l_split[0:idx] + [new_token] + l_split[end_idx+1:]
+                 break
+             idx += 1
+         l = ' '.join(l_split)
+
+     # complex formats such as \cmidrule ( l { 3 p t } r { 3 p t } ) { 1 - 1 } are hard
+     # to handle with re.match, so use brace matching instead
+     l_split = l.split(' ')
+     idx = 0
+     while idx < len(l_split):
+         token = l_split[idx]
+         if token in ["\\cmidrule", "\\cline"]:
+             sub_idx = idx + 1
+             if l_split[sub_idx] == "(":
+                 mid_end = find_matching_brace(l_split, sub_idx, brace=['(', ')'])
+                 end_idx = find_matching_brace(l_split, mid_end+1)
+             else:
+                 end_idx = find_matching_brace(l_split, sub_idx)
+             new_token = "".join(l_split[idx: end_idx+1])
+             l_split = l_split[0:idx] + [new_token] + l_split[end_idx+1:]
+         idx += 1
+     l = ' '.join(l_split)
+
+     pattern = r'\\begin{array} { [lrc ]+ }'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace("\\begin{array} ", "<s>").replace(" ", "").replace("<s>", "\\begin{array} ") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # tokens such as \not= should be one token
+     pattern = r'\\not [<>+=\-]'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # split tokens such as \dots \exp \sinh into parts, so that bbox matching is easier
+     l = " " + l + " "
+     l = l.replace(" \\ldots ", " . . . ")
+     l = l.replace(" \\cdots ", " . . . ")
+     l = l.replace(" \\dots ", " . . . ")
+     l = l.replace(" \\dotsb ", " . . . ")
+     l = l.replace(" \\log ", " \\mathrm { l o g } ")
+     l = l.replace(" \\exp ", " \\mathrm { e x p } ")
+     l = l.replace(" \\sin ", " \\mathrm { s i n } ")
+     l = l.replace(" \\cos ", " \\mathrm { c o s } ")
+     l = l.replace(" \\tan ", " \\mathrm { t a n } ")
+     l = l.replace(" \\tanh ", " \\mathrm { t a n h } ")
+     l = l.replace(" \\cosh ", " \\mathrm { c o s h } ")
+     l = l.replace(" \\sinh ", " \\mathrm { s i n h } ")
+
+     # ** tokens such as \big( should be one token
+     pattern = r'\\[Bb]ig[g]?[glrm]? [(){}|\[\]] '
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft + " ")
+
+     pattern = r'\\[Bb]ig[g]?[glrm]? \\.*? '
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft + " ")
+
+     # \operatorname * breaks when it meets \mathcolor, and the * is not needed anyway,
+     # so simply remove it for now
+     pattern = r'\\operatorname \*'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = ["\\operatorname" for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # \lefteqn makes letters overlap, which is harmful for rendering, so simply remove it
+     l = l.replace("\\lefteqn", "")
+
+     # \footnote cannot be treated as a ONE_Tail_Invisb_Token (those add color as
+     # \mathrm {\color{x}}, while \footnote would need \color{\footnote{x}}),
+     # so simply change it to "^"
+     l = l.replace("\\footnote ", "^ ")
+
+     # \' cannot be rendered on its own (the output looks different), so merge it with
+     # the following token, e.g. \' e -> \'e; if { follows \', keep them separate
+     pattern = r'\\\' [^{] '
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft + " ")
+
+     # [ -1.5ex ] [ 1.5pt ] [ 3 mm ] are layout adjustments that render nothing;
+     # combine each of them into one token
+     if latex_type == "tabular":
+         pattern = r'\[ [\-.0-9 ]+[exptcm ]+ \]'
+         old_token = re.findall(pattern, l, re.DOTALL)
+         new_token = [item.replace(" ", "") for item in old_token]
+         for bef, aft in zip(old_token, new_token):
+             l = l.replace(bef, aft)
+
+     # ** \parbox { 3cm } {} should be combined into one token
+     pattern = r'\\parbox {[^{]+}'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # ** \raisebox{<lift>}[<height>][<depth>] {} should be combined into one token, e.g. \raisebox{-1.5ex}[0pt]
+     pattern = r'\\raisebox {[^{]+} [\[\]0-9 exptcm]+{'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft[0:-1] + " {")
+
+     # ** \char should be combined into one token
+     pattern = r'{ \\char[0-9\' ]+}'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, "{ " + aft[1:-1] + " }")
+
+     # ** \not xx should be combined into one token
+     pattern = r'\\not [\\=\<\>][^ ]+ '
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft + " ")
+
+     # ** \specialrule{1pt}{2pt}{2pt} draws special lines and should be combined into one token
+     pattern = r'\\specialrule {[ .0-9a-z]+} {[ .0-9a-z]+} {[ .0-9a-z]+}'
+     old_token = re.findall(pattern, l, re.DOTALL)
+     new_token = [item.replace(" ", "") for item in old_token]
+     for bef, aft in zip(old_token, new_token):
+         l = l.replace(bef, aft)
+
+     # ** to make adding color easier, strip any existing color commands; two forms
+     # occur for now: \color[rgb]{0, 1, 0} and \color{red}
+     pattern = r'\\colorbox[ \[\]RGBrgb]+{ [A-Za-z 0-9,!]+ } |\\color[ \[\]RGBrgb]+{ [A-Za-z 0-9,!]+ } |\\textcolor[ \[\]RGBrgb]+{ [A-Za-z 0-9,!]+ } |\\cellcolor[ \[\]RGBrgb]+{ [A-Za-z 0-9,!]+ } '
+     old_token = re.findall(pattern, l, re.DOTALL)
+     for bef in old_token:
+         l = l.replace(bef, "")
+
+     # ** fill in the missing braces [] and {} according to the token type
+     l_split = l.split(' ')
+     idx = 0
+     while idx < len(l_split):
+         token = l_split[idx]
+         if token in ONE_Tail_Tokens + ONE_Tail_Invisb_Tokens:
+             # ** normalize tokens such as \hat: fill in the missing {}, e.g. \hat \lambda -> \hat {\lambda}
+             sub_idx = idx + 1
+             while sub_idx < len(l_split) and l_split[sub_idx] in ONE_Tail_Tokens + ONE_Tail_Invisb_Tokens:
+                 sub_idx += 1
+             new_split = l_split[0:idx]
+             for ii in range(idx, sub_idx):
+                 new_split = new_split + [l_split[ii], "{"]
+             if l_split[sub_idx] != "{":
+                 new_split = new_split + [l_split[sub_idx]] + ["}"] * (sub_idx - idx)
+                 l_split = new_split + l_split[sub_idx+1:]
+             else:
+                 end_idx = find_matching_brace(l_split, sub_idx)
+                 new_split = new_split + l_split[sub_idx+1:end_idx] + ["}"] * (sub_idx - idx)
+                 l_split = new_split + l_split[end_idx+1:]
+         elif token in AB_Tail_Tokens:
+             # ** normalize special tokens such as \sqrt: fill in the missing [] {} of \sqrt [] {}
+             # (the [] is optional), e.g. \sqrt A B -> \sqrt {A} B and \sqrt [A] B -> \sqrt [A] {B}
+             if l_split[idx + 1] != "[" and l_split[idx + 1] != "{":
+                 l_split = l_split[0:idx+1] + ["{"] + [l_split[idx+1]] + ["}"] + l_split[idx+2:]
+             else:
+                 if l_split[idx + 1] == "[":
+                     end1 = find_matching_brace(l_split, idx+1, brace=['[', ']'])
+                 else:
+                     end1 = idx
+                 if l_split[end1 + 1] != "{":
+                     l_split = l_split[0:end1+1] + ["{"] + [l_split[end1+1]] + ["}"] + l_split[end1+2:]
+         elif token in TWO_Tail_Tokens + TWO_Tail_Invisb_Tokens:
+             # ** normalize special tokens such as \frac: add the missing braces of \frac {A} {B},
+             # e.g. \frac {\lambda} 2 -> \frac {\lambda} {2}
+             if l_split[idx + 1] != "{":
+                 l_split = l_split[0:idx+1] + ["{"] + [l_split[idx+1]] + ["}"] + l_split[idx+2:]
+             end1 = find_matching_brace(l_split, idx+1)
+             if l_split[end1 + 1] != "{":
+                 l_split = l_split[0:end1+1] + ["{"] + [l_split[end1+1]] + ["}"] + l_split[end1+2:]
+
+         idx += 1
+     l = ' '.join(l_split)
+
+     return l
+
+ def token_add_color(l_split, idx, render_dict):
+     token = l_split[idx]
+     if token in PHANTOM_Tokens:
+         # ** special tokens that need no rendering: skip them
+         if l_split[idx + 1] == '{':
+             brace_end = find_matching_brace(l_split, idx + 1)
+         else:
+             brace_end = idx + 1
+         next_idx = brace_end + 1
+     elif token in TWO_Tail_Tokens:
+         # ** tokens such as \frac A B; the token itself needs rendering too
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         den_start = num_end + 1
+         den_end = find_matching_brace(l_split, den_start)
+         l_split_copy = l_split[:idx] + [r'\mathcolor{black}{' + token + '{'] + \
+                        [r'\mathcolor{gray}{'] + l_split[num_start + 1:num_end] + \
+                        ['}'] + [r'}{'] + [r'\mathcolor{gray}{'] + l_split[den_start + 1:den_end] + \
+                        ['}'] + ['}'] + ['}'] + l_split[den_end + 1:]
+
+         l_new = ' '.join(l_split_copy)
+         l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+         render_dict[str(idx)] = l_new, token
+         next_idx = idx + 1
+     elif token in ONE_Tail_Tokens:
+         # ** tokens such as \hat A; the token itself needs rendering too
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         l_split_copy = l_split[:idx] + [r'\mathcolor{black}{'] + l_split[idx: num_start+1] + \
+                        [r'\mathcolor{gray}{'] + l_split[num_start+1: num_end] + \
+                        ['}'] + l_split[num_end: num_end+1] + ['}'] + l_split[num_end+1:]
+         l_new = ' '.join(l_split_copy)
+         l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+         render_dict[str(idx)] = l_new, token
+         next_idx = idx + 1
+     elif token in ONE_Tail_Invisb_Tokens:
+         # ** tokens such as \text A B; the token itself needs no rendering
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         sub_idx = num_start + 1
+         if num_end - num_start == 2:
+             l_split_copy = l_split.copy()
+             l_split_copy[sub_idx] = r'{\mathcolor{black}{' + l_split_copy[sub_idx] + '}}'
+             l_new = ' '.join(l_split_copy)
+             l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+             render_dict[str(idx)] = l_new, l_split[sub_idx]
+             next_idx = num_end
+         else:
+             while sub_idx < num_end:
+                 l_split, sub_idx, render_dict = token_add_color(l_split, sub_idx, render_dict)
+             next_idx = num_end + 1
+     elif token in AB_Tail_Tokens:
+         # ** special tokens like \xrightarrow can be either \xrightarrow [] {} or
+         # \xrightarrow {}, and the two forms need different handling
+         if l_split[idx+1] == '{':
+             num_start = idx + 1
+             num_end = find_matching_brace(l_split, num_start)
+             l_split_copy = l_split[:idx] + [r'\mathcolor{black}{'] + l_split[idx: idx+2] \
+                 + [r'\mathcolor{gray}{'] + l_split[num_start+1: num_end] + ['}}'] + l_split[num_end:]
+             l_new = ' '.join(l_split_copy)
+             l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+             render_dict[str(idx)] = l_new, token
+             sub_idx = num_start + 1
+             while sub_idx < num_end:
+                 l_split, sub_idx, render_dict = token_add_color(l_split, sub_idx, render_dict)
+             next_idx = num_end + 1
+         elif l_split[idx+1] == '[':
+             num_start = idx + 1
+             num_end = find_matching_brace(l_split, num_start, brace=['[', ']'])
+             den_start = num_end + 1
+             den_end = find_matching_brace(l_split, den_start)
+             l_split_copy = l_split[:idx] + [r'{\mathcolor{black}{'] + l_split[idx: idx+2] \
+                 + [r'\mathcolor{gray}{'] + l_split[idx+2: num_end] + ['}'] + l_split[num_end:den_start+1] \
+                 + [r'\mathcolor{gray}{'] + l_split[den_start+1: den_end] + ['}'] + l_split[den_end: den_end+1] \
+                 + ['}}'] + l_split[den_end+1:]
+             l_new = ' '.join(l_split_copy)
+             l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+             render_dict[str(idx)] = l_new, token
+             sub_idx = num_start + 1
+             while sub_idx < num_end:
+                 l_split, sub_idx, render_dict = token_add_color(l_split, sub_idx, render_dict)
+             sub_idx = den_start + 1
+             while sub_idx < den_end:
+                 l_split, sub_idx, render_dict = token_add_color(l_split, sub_idx, render_dict)
+             next_idx = den_end + 1
+     elif token in ["\\multicolumn", "\\multirow"]:
+         # ** tokens with three {}, such as \multicolumn {} {} {}; only the text in the
+         # third {} needs rendering
+         first_start = idx + 1
+         first_end = find_matching_brace(l_split, first_start)
+         second_start = first_end + 1
+         second_end = find_matching_brace(l_split, second_start)
+         third_start = second_end + 1
+         third_end = find_matching_brace(l_split, third_start)
+
+         sub_idx = third_start + 1
+         while sub_idx < third_end:
+             l_split, sub_idx, render_dict = token_add_color(l_split, sub_idx, render_dict)
+         next_idx = third_end + 1
+     elif token in SKIP_Tokens + TWO_Tail_Invisb_Tokens or any(re.match(pattern, token) for pattern in SKIP_PATTERNS):
+         # ** tokens that need no rendering: skip them
+         # special case: [] can stand alone or belong to \sqrt[]{}
+         if (token == "[" and l_split[idx-1] != "\\sqrt") or (token == "]" and idx >= 3 and l_split[idx-3] != "\\sqrt"):
+             l_split_copy = l_split.copy()
+             l_split_copy[idx] = r'\mathcolor{black}{ ' + l_split_copy[idx] + ' }'
+             l_new = ' '.join(l_split_copy)
+             l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+             render_dict[str(idx)] = l_new, token
+             next_idx = idx + 1
+         else:
+             next_idx = idx + 1
+     else:
+         # ** normal token
+         l_split_copy = l_split.copy()
+         # adding color can shift the layout slightly; experiments showed that
+         # \mathcolor{black}{ A } renders better than \mathcolor{black}{A}
+         l_split_copy[idx] = r'\mathcolor{black}{ ' + l_split_copy[idx] + ' }'
+
+         l_new = ' '.join(l_split_copy)
+         l_new = r'\mathcolor{gray}{ ' + l_new + ' }'
+         render_dict[str(idx)] = l_new, token
+         next_idx = idx + 1
+
+     return l_split, next_idx, render_dict
+
+
+ def token_add_color_RGB(l_split, idx, token_list, brace_color=False):
+     r"""Render latex with \mathcolor[RGB]{r,g,b} wrapped around each visible token."""
+     token = l_split[idx]
+     if not token:
+         next_idx = idx + 1
+     elif token in PHANTOM_Tokens:
+         # ** special tokens that need no rendering: skip them
+         if l_split[idx + 1] == '{':
+             brace_end = find_matching_brace(l_split, idx + 1)
+         else:
+             brace_end = idx + 1
+         next_idx = brace_end + 1
+     elif token in TWO_Tail_Tokens:
+         # ** tokens such as \frac A B; the token itself needs rendering too
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         den_start = num_end + 1
+         den_end = find_matching_brace(l_split, den_start)
+         color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+         l_split = l_split[:idx] + [color_token + token] + l_split[idx+1: den_end+1] + ["}"] + l_split[den_end+1:]
+         token_list.append(token)
+         next_idx = idx + 1
+     elif token in ONE_Tail_Tokens:
+         # ** tokens such as \hat A; the token itself needs rendering too
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+         if token != "\\underbrace" and num_end+1 < len(l_split) and l_split[num_end+1] == "_":
+             l_split = l_split[:idx] + ["{" + color_token + token] + l_split[idx+1: num_end+1] + ["}}"] + l_split[num_end+1:]
+         else:
+             l_split = l_split[:idx] + [color_token + token] + l_split[idx+1: num_end+1] + ["}"] + l_split[num_end+1:]
+         token_list.append(token)
+         next_idx = idx + 1
+     elif token in ONE_Tail_Invisb_Tokens:
+         # ** tokens such as \text A B; the token itself needs no rendering
+         num_start = idx + 1
+         num_end = find_matching_brace(l_split, num_start)
+         sub_idx = num_start + 1
+         if num_end - num_start == 2:
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             token_list.append(l_split[num_start+1])
+             l_split = l_split[:num_start+1] + [color_token + l_split[num_start+1] + "}"] + l_split[num_end:]
+         else:
+             while sub_idx < num_end:
+                 l_split, sub_idx, token_list = token_add_color_RGB(l_split, sub_idx, token_list)
+         next_idx = num_end + 1
+     elif token in AB_Tail_Tokens:
+         # ** special tokens like \xrightarrow can be either \xrightarrow [] {} or
+         # \xrightarrow {}, and the two forms need different handling
+         if l_split[idx+1] == '{':
+             num_start = idx + 1
+             num_end = find_matching_brace(l_split, num_start)
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             l_split = l_split[:idx] + [color_token + token] + l_split[idx+1: num_end+1] + ["}"] + l_split[num_end+1:]
+             token_list.append(token)
+             sub_idx = num_start + 1
+             while sub_idx < num_end:
+                 l_split, sub_idx, token_list = token_add_color_RGB(l_split, sub_idx, token_list)
+             next_idx = num_end + 1
+         elif l_split[idx+1] == '[':
+             num_start = idx + 1
+             num_end = find_matching_brace(l_split, num_start, brace=['[', ']'])
+             den_start = num_end + 1
+             den_end = find_matching_brace(l_split, den_start)
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             l_split = l_split[:idx] + [color_token + token] + l_split[idx+1: den_end+1] + ["}"] + l_split[den_end+1:]
+             token_list.append(token)
+             sub_idx = num_start + 1
+             while sub_idx < num_end:
+                 l_split, sub_idx, token_list = token_add_color_RGB(l_split, sub_idx, token_list, brace_color=True)
+             sub_idx = den_start + 1
+             while sub_idx < den_end:
+                 l_split, sub_idx, token_list = token_add_color_RGB(l_split, sub_idx, token_list)
+             next_idx = den_end + 1
+     elif token in ["\\multicolumn", "\\multirow"]:
+         # ** tokens with three {}, such as \multicolumn {} {} {}; only the text in the
+         # third {} needs rendering
+         first_start = idx + 1
+         first_end = find_matching_brace(l_split, first_start)
+         second_start = first_end + 1
+         second_end = find_matching_brace(l_split, second_start)
+         third_start = second_end + 1
+         third_end = find_matching_brace(l_split, third_start)
+
+         sub_idx = third_start + 1
+         while sub_idx < third_end:
+             l_split, sub_idx, token_list = token_add_color_RGB(l_split, sub_idx, token_list)
+         next_idx = third_end + 1
+     elif token in SKIP_Tokens + TWO_Tail_Invisb_Tokens or any(re.match(pattern, token) for pattern in SKIP_PATTERNS):
+         # ** tokens that need no rendering: skip them
+         # special case: [] can stand alone or belong to \sqrt[]{}
+         if (token == "[" and l_split[idx-1] != "\\sqrt") or (token == "]" and idx >= 3 and l_split[idx-3] != "\\sqrt"):
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             l_split = l_split[:idx] + [color_token + l_split[idx] + "}"] + l_split[idx+1:]
+             token_list.append(token)
+             next_idx = idx + 1
+         else:
+             next_idx = idx + 1
+     else:
+         # ** normal token
+         if brace_color or (idx > 1 and l_split[idx-1] == "_"):
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             l_split = l_split[:idx] + ["{" + color_token + l_split[idx] + "}}"] + l_split[idx+1:]
+             token_list.append(token)
+             next_idx = idx + 1
+         else:
+             color_token = "\\mathcolor[RGB]{<color_<idx>>}{".replace("<idx>", str(len(token_list)))
+             l_split = l_split[:idx] + [color_token + l_split[idx] + "}"] + l_split[idx+1:]
+             token_list.append(token)
+             next_idx = idx + 1
+     return l_split, next_idx, token_list
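
To make the coloring pipeline concrete, here is a small sketch of normalize_latex followed by token_add_color_RGB (the formula is illustrative and already space-tokenized, as tokenize_latex would produce):

from modules.latex_processor import normalize_latex, token_add_color_RGB

latex = r"\frac { a } { b } + \hat x"
l_split = normalize_latex(latex).strip().split(' ')

token_list = []
idx = 0
while idx < len(l_split):
    l_split, idx, token_list = token_add_color_RGB(l_split, idx, token_list)

# token_list now holds the visible tokens, e.g. ['\\frac', 'a', 'b', '+', '\\hat', 'x'],
# and ' '.join(l_split) carries <color_i> placeholders to be substituted with
# per-token RGB values before rendering.
print(token_list)
print(' '.join(l_split))
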
visual_matcher.py ADDED
@@ -0,0 +1,191 @@
+ import numpy as np
+ from PIL import Image
+ from scipy.spatial.distance import cdist
+ from scipy.optimize import linear_sum_assignment
+
+
+ class SimpleAffineTransform:
+     """
+     Simple affine transform with translation and scale only.
+     """
+     def __init__(self, translation=(0, 0), scale=1.0):
+         self.translation = np.array(translation)
+         self.scale = scale
+
+     def estimate(self, src, dst):
+         # translation: difference of the point-set centers
+         src_center = np.mean(src, axis=0)
+         dst_center = np.mean(dst, axis=0)
+         self.translation = dst_center - src_center
+
+         # scale: ratio of the mean distances to the centers
+         src_dists = np.linalg.norm(src - src_center, axis=1)
+         dst_dists = np.linalg.norm(dst - dst_center, axis=1)
+         self.scale = np.mean(dst_dists) / (np.mean(src_dists) + 1e-10)
+
+     def inverse(self):
+         return SimpleAffineTransform(-self.translation, 1.0 / self.scale)
+
+     def __call__(self, coords):
+         return self.scale * (coords - np.mean(coords, axis=0)) + np.mean(coords, axis=0) + self.translation
+
+     def residuals(self, src, dst):
+         return np.sqrt(np.sum((self(src) - dst) ** 2, axis=1))
+
+
+ def norm_coords(x, left, right):
+     # clamp x into [left, right]
+     if x < left:
+         return left
+     if x > right:
+         return right
+     return x
+
+ def norm_same_token(token):
+     # map visually identical tokens onto one canonical form
+     special_map = {
+         "\\cdot": ".",
+         "\\mid": "|",
+         "\\to": "\\rightarrow",
+         "\\top": "T",
+         "\\Tilde": "\\tilde",
+         "\\cdots": "\\dots",
+         "\\prime": "'",
+         "\\ast": "*",
+         "\\left<": "\\langle",
+         "\\right>": "\\rangle"
+     }
+     if token in special_map.keys():
+         token = special_map[token]
+     if token.startswith('\\left') or token.startswith('\\right'):
+         token = token.replace("\\left", "").replace("\\right", "")
+     if token.startswith('\\big') or token.startswith('\\Big'):
+         if "\\" in token[4:]:
+             token = "\\" + token[4:].split("\\")[-1]
+         else:
+             token = token[-1]
+
+     if token in ['\\leq', '\\geq']:
+         return token[0:-1]
+     if token in ['\\lVert', '\\rVert', '\\Vert']:
+         return '\\|'
+     if token in ['\\lvert', '\\rvert', '\\vert']:
+         return '|'
+     if token.endswith("rightarrow"):
+         return "\\rightarrow"
+     if token.endswith("leftarrow"):
+         return "\\leftarrow"
+     if token.startswith('\\wide'):
+         return token.replace("wide", "")
+     if token.startswith('\\var'):
+         return token.replace("\\var", "")
+     return token
+
+
+ class HungarianMatcher:
+     def __init__(
+         self,
+         cost_token: float = 1,
+         cost_position: float = 0.05,
+         cost_order: float = 0.15,
+     ):
+         self.cost_token = cost_token
+         self.cost_position = cost_position
+         self.cost_order = cost_order
+         self.cost = {}
+
+     def calculate_token_cost_old(self, box_gt, box_pred):
+         # O(N*M) reference implementation of the token cost
+         token_cost = np.ones((len(box_gt), len(box_pred)))
+         for i in range(token_cost.shape[0]):
+             box1 = box_gt[i]
+             for j in range(token_cost.shape[1]):
+                 box2 = box_pred[j]
+                 if box1['token'] == box2['token']:
+                     token_cost[i, j] = 0
+                 elif norm_same_token(box1['token']) == norm_same_token(box2['token']):
+                     token_cost[i, j] = 0.05
+         return np.array(token_cost)
+
+     def calculate_token_cost(self, box_gt, box_pred):
+         # vectorized token cost: 0 for an exact token match, 0.05 for a match after
+         # normalization, 1 otherwise
+         token2id = {}
+         for data in box_gt + box_pred:
+             if data['token'] not in token2id:
+                 token2id[data['token']] = len(token2id)
+         num_classes = len(token2id)
+
+         token2id_norm = {}
+         for data in box_gt + box_pred:
+             if norm_same_token(data['token']) not in token2id_norm:
+                 token2id_norm[norm_same_token(data['token'])] = len(token2id_norm)
+         num_classes_norm = len(token2id_norm)
+
+         gt_token_array = []
+         norm_gt_token_array = []
+         for data in box_gt:
+             gt_token_array.append(token2id[data['token']])
+             norm_gt_token_array.append(token2id_norm[norm_same_token(data['token'])])
+
+         pred_token_logits = []
+         norm_pred_token_logits = []
+         for data in box_pred:
+             logits = [0] * num_classes
+             logits[token2id[data['token']]] = 1
+             pred_token_logits.append(logits)
+
+             logits_norm = [0] * num_classes_norm
+             logits_norm[token2id_norm[norm_same_token(data['token'])]] = 1
+             norm_pred_token_logits.append(logits_norm)
+
+         gt_token_array = np.array(gt_token_array)
+         pred_token_logits = np.array(pred_token_logits)
+
+         norm_gt_token_array = np.array(norm_gt_token_array)
+         norm_pred_token_logits = np.array(norm_pred_token_logits)
+
+         token_cost = 1.0 - pred_token_logits[:, gt_token_array]
+         norm_token_cost = 1.0 - norm_pred_token_logits[:, norm_gt_token_array]
+
+         token_cost[np.logical_and(token_cost == 1, norm_token_cost == 0)] = 0.05
+         return token_cost.T
+
+     def box2array(self, box_list, size):
+         # normalize pixel boxes by the image size
+         W, H = size
+         box_array = []
+         for box in box_list:
+             x_min, y_min, x_max, y_max = box['bbox']
+             box_array.append([x_min/W, y_min/H, x_max/W, y_max/H])
+         return np.array(box_array)
+
+     def order2array(self, box_list):
+         # encode each box's reading-order position as a value in [0, 1)
+         order_array = []
+         for idx, box in enumerate(box_list):
+             order_array.append([idx / len(box_list)])
+         return np.array(order_array)
+
+     def calculate_l1_cost(self, gt_array, pred_array):
+         scale = gt_array.shape[-1]
+         l1_cost = cdist(gt_array, pred_array, 'minkowski', p=1)
+         return l1_cost / scale
+
+     def __call__(self, box_gt, box_pred, gt_size, pred_size):
+         gt_box_array = self.box2array(box_gt, gt_size)
+         pred_box_array = self.box2array(box_pred, pred_size)
+         gt_order_array = self.order2array(box_gt)
+         pred_order_array = self.order2array(box_pred)
+
+         token_cost = self.calculate_token_cost(box_gt, box_pred)
+         position_cost = self.calculate_l1_cost(gt_box_array, pred_box_array)
+         order_cost = self.calculate_l1_cost(gt_order_array, pred_order_array)
+
+         self.cost["token"] = token_cost
+         self.cost["position"] = position_cost
+         self.cost["order"] = order_cost
+
+         cost = self.cost_token * token_cost + self.cost_position * position_cost + self.cost_order * order_cost
+         cost[np.isnan(cost) | np.isinf(cost)] = 100
+         indexes = linear_sum_assignment(cost)
+         matched_idxes = []
+         for a, b in zip(*indexes):
+             matched_idxes.append((a, b))
+
+         return matched_idxes
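
Finally, a sketch of driving HungarianMatcher, assuming the file is importable as modules.visual_matcher, with toy boxes in the {"bbox": [x_min, y_min, x_max, y_max], "token": ...} format that latex2bbox_color writes (coordinates and image sizes are illustrative):

from modules.visual_matcher import HungarianMatcher

box_gt = [{"bbox": [10, 10, 30, 40], "token": "a"},
          {"bbox": [35, 10, 50, 40], "token": "+"},
          {"bbox": [55, 10, 75, 40], "token": "b"}]
box_pred = [{"bbox": [12, 12, 33, 42], "token": "a"},
            {"bbox": [58, 11, 78, 41], "token": "b"},
            {"bbox": [36, 12, 52, 41], "token": "+"}]

matcher = HungarianMatcher()
# gt_size and pred_size are the (W, H) of the images the boxes came from
matched = matcher(box_gt, box_pred, gt_size=(100, 50), pred_size=(100, 50))
for gt_idx, pred_idx in matched:
    print(box_gt[gt_idx]["token"], "->", box_pred[pred_idx]["token"])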