File size: 14,600 Bytes
324b9ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import pandas as pd
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import json
import torch.nn.functional as F

from models.can.can import CAN, create_can_model
from models.can.can_dataloader import Vocabulary, process_img, INPUT_HEIGHT, INPUT_WIDTH

torch.serialization.add_safe_globals([Vocabulary])

os.environ['QT_QPA_PLATFORM'] = 'offscreen'

# NOTE: "config.json" is resolved relative to the current working directory,
# not relative to this script's directory.
with open("config.json", "r", encoding="utf-8") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants derived from the "can" section of config.json.
# Integer config flags are enabled only when set to exactly 1.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODE = CAN_CONFIG["mode"]  # 'single' or 'evaluate'
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
PRETRAINED_BACKBONE = CAN_CONFIG["pretrained_backbone"] == 1
# Checkpoints trained on top of a pretrained backbone carry a "p_" prefix.
CHECKPOINT_PATH = (f'checkpoints/p_{BACKBONE_TYPE}_can_best.pth' if PRETRAINED_BACKBONE
                   else f'checkpoints/{BACKBONE_TYPE}_can_best.pth')
IMAGE_PATH = f'{CAN_CONFIG["test_folder"]}/{CAN_CONFIG["relative_image_path"]}'
VISUALIZE = CAN_CONFIG["visualize"] == 1
TEST_FOLDER = CAN_CONFIG["test_folder"]
LABEL_FILE = CAN_CONFIG["label_file"]
# Expression subset to evaluate: 'frac', 'sum_or_lim', 'long_expr', 'short_expr';
# any other value (e.g. 'all') evaluates every expression.
CLASSIFIER = CAN_CONFIG["classifier"]


def filter_formula(formula_tokens, mode, *, long_min_len=30, short_max_len=10):
    """
    Decide whether a ground-truth formula belongs to the expression subset ``mode``.

    Args:
        formula_tokens: Sequence of LaTeX tokens (e.g. ``label.split()``).
        mode: ``'frac'`` (contains ``\\frac``), ``'sum_or_lim'`` (contains
            ``\\sum`` or ``\\lim``), ``'long_expr'`` (at least ``long_min_len``
            tokens) or ``'short_expr'`` (at most ``short_max_len`` tokens).
            Any other value (e.g. ``'all'``) accepts every formula.
        long_min_len: Minimum token count for ``'long_expr'``.
        short_max_len: Maximum token count for ``'short_expr'``.

    Returns:
        True if the formula should be evaluated under ``mode``.
    """
    if mode == "frac":
        return "\\frac" in formula_tokens
    if mode == "sum_or_lim":
        # The LaTeX limit command is \lim; the historical "\limit" spelling is
        # kept so any vocabulary that used it still matches.
        return any(tok in formula_tokens for tok in ("\\sum", "\\lim", "\\limit"))
    if mode == "long_expr":
        return len(formula_tokens) >= long_min_len
    if mode == "short_expr":
        return len(formula_tokens) <= short_max_len
    return True


def levenshtein_distance(lst1, lst2):
    """
    Return the Levenshtein (edit) distance between two sequences.

    Counts the minimum number of single-element insertions, deletions and
    substitutions that turn ``lst1`` into ``lst2``. Elements are compared with
    ``==``. Only two rows of the DP table are kept, so memory is O(len(lst2)).

    Args:
        lst1: First sequence (e.g. predicted LaTeX tokens).
        lst2: Second sequence (e.g. ground-truth LaTeX tokens).

    Returns:
        The non-negative edit distance; ``len(lst2)`` when ``lst1`` is empty and
        ``len(lst1)`` when ``lst2`` is empty.
    """
    n = len(lst2)
    # prev_row[j] = distance between the lst1 prefix processed so far and lst2[:j];
    # before any element of lst1 is consumed this is simply j insertions.
    prev_row = list(range(n + 1))
    for i, a in enumerate(lst1, start=1):
        curr_row = [i] + [0] * n
        for j, b in enumerate(lst2, start=1):
            if a == b:
                curr_row[j] = prev_row[j - 1]
            else:
                curr_row[j] = 1 + min(
                    curr_row[j - 1],  # insertion
                    prev_row[j],  # deletion
                    prev_row[j - 1]  # substitution
                )
        prev_row = curr_row
    # Reading prev_row (rather than a pre-allocated curr_row) makes an empty
    # lst1 correctly yield len(lst2) instead of 0.
    return prev_row[n]


def load_checkpoint(checkpoint_path, device, pretrained_backbone=True, backbone='densenet'):
    """
    Load a CAN checkpoint and return ``(model, vocab)``.

    The vocabulary is taken from the checkpoint's ``'vocab'`` entry or, if that
    is missing, rebuilt from ``hmer_vocab.pth`` located next to the checkpoint.
    Model hyper-parameters fall back to ``hidden_size=256``,
    ``embedding_dim=256`` and ``use_coverage=True`` when not stored.

    Args:
        checkpoint_path: Path to the checkpoint file; must contain a ``'model'``
            state dict.
        device: Device that loaded tensors are mapped to and the model is moved to.
        pretrained_backbone: Forwarded to ``create_can_model``.
        backbone: Backbone type forwarded to ``create_can_model``.

    Returns:
        ``(model, vocab)`` with the state dict loaded into ``model``.

    Raises:
        ValueError: if the checkpoint has no vocabulary and ``hmer_vocab.pth``
            does not exist.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path,
                            map_location=device,
                            weights_only=False)

    vocab = checkpoint.get('vocab')
    if vocab is None:
        # Try to load vocab from a separate file if not in checkpoint
        vocab_path = os.path.join(os.path.dirname(checkpoint_path),
                                  'hmer_vocab.pth')
        if not os.path.exists(vocab_path):
            raise ValueError(
                f"Vocabulary not found in checkpoint and {vocab_path} does not exist"
            )
        # Use the same map_location as the checkpoint so that a file saved on a
        # GPU machine also loads on a CPU-only one.
        vocab_data = torch.load(vocab_path, map_location=device)
        vocab = Vocabulary()
        vocab.word2idx = vocab_data['word2idx']
        vocab.idx2word = vocab_data['idx2word']
        vocab.idx = vocab_data['idx']
        # Re-derive special-token ids from the restored mapping
        vocab.pad_token = vocab.word2idx['<pad>']
        vocab.start_token = vocab.word2idx['<start>']
        vocab.end_token = vocab.word2idx['<end>']
        vocab.unk_token = vocab.word2idx['<unk>']

    # Initialize model with parameters from checkpoint (with defaults)
    hidden_size = checkpoint.get('hidden_size', 256)
    embedding_dim = checkpoint.get('embedding_dim', 256)
    use_coverage = checkpoint.get('use_coverage', True)

    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=pretrained_backbone,
                             backbone_type=backbone).to(device)

    model.load_state_dict(checkpoint['model'])
    print(f"Loaded model from checkpoint {checkpoint_path}")

    return model, vocab


def recognize_single_image(model,
                           image_path,
                           vocab,
                           device,
                           max_length=150,
                           visualize_attention=False):
    """
    Recognize a handwritten mathematical expression in one image with the CAN model.

    Args:
        model: CAN model exposing ``recognize(image_tensor, max_length=...,
            start_token=..., end_token=..., beam_width=...)``.
        image_path: Path to the input image.
        vocab: Vocabulary providing ``start_token``, ``end_token`` and ``idx2word``.
        device: Torch device the input tensor is moved to.
        max_length: Maximum decoding length passed to ``model.recognize``.
        visualize_attention: If True and the model returns attention weights,
            save per-token attention maps via ``visualize_attention_maps``.

    Returns:
        The predicted LaTeX string with tokens separated by single spaces.
    """
    # Prepare image transform for grayscale images
    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        A.pytorch.ToTensorV2()
    ])

    processed_img, best_crop = process_img(image_path, convert_to_rgb=False)

    # albumentations expects an explicit channel axis: [H, W] -> [H, W, 1]
    processed_img = np.expand_dims(processed_img, axis=-1)
    image_tensor = transform(
        image=processed_img)['image'].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        # Generate LaTeX using beam search
        predictions, attention_weights = model.recognize(
            image_tensor,
            max_length=max_length,
            start_token=vocab.start_token,
            end_token=vocab.end_token,
            beam_width=5  # Use beam search with width 5
        )

    # Convert indices to LaTeX tokens: stop at <end>, skip <start>
    latex_tokens = []
    for idx in predictions:
        if idx == vocab.end_token:
            break
        if idx != vocab.start_token:
            latex_tokens.append(vocab.idx2word[idx])

    latex = ' '.join(latex_tokens)

    if visualize_attention and attention_weights is not None:
        # visualize_attention_maps calls PIL's Image.crop(best_crop), so it needs
        # a PIL image; the preprocessed numpy array has no .crop() method.
        # NOTE(review): assumes best_crop is a box in the original image's
        # coordinates (as returned by process_img) -- confirm.
        with Image.open(image_path) as raw_img:
            display_img = raw_img.convert('RGB')
        visualize_attention_maps(display_img, attention_weights,
                                 latex_tokens, best_crop)

    return latex


def _attention_grid_shape(attn_len, ratio):
    """
    Return ``(height, width)`` with ``height * width == attn_len`` for reshaping a
    flattened attention vector, choosing the height closest to
    ``sqrt(attn_len * ratio)`` (the height an exact ``height / width == ratio``
    grid would have).
    """
    if attn_len <= 0:
        raise ValueError(f"attention vector must be non-empty, got length {attn_len}")
    expected_h = np.sqrt(attn_len * ratio)
    divisors = [h for h in range(1, attn_len + 1) if attn_len % h == 0]
    attn_h = min(divisors, key=lambda h: abs(h - expected_h))
    return attn_h, attn_len // attn_h


def visualize_attention_maps(orig_image,
                             attention_weights,
                             latex_tokens,
                             best_crop,
                             max_cols=4,
                             save_path='attention_maps_can.png'):
    """
    Save a grid of per-token attention overlays for the CAN model.

    Args:
        orig_image: PIL image (must support ``.crop`` and ``.size``) that is
            cropped with ``best_crop`` before display.
        attention_weights: Per-decoding-step attention tensors, aligned with
            ``latex_tokens``; only the first row of each step is drawn.
        latex_tokens: Predicted tokens used as subplot titles.
        best_crop: Crop box passed to ``orig_image.crop``.
        max_cols: Maximum number of subplot columns.
        save_path: Output image path.
    """
    if not latex_tokens:
        # Nothing to draw (and a zero-column grid would divide by zero below).
        return

    orig_image = orig_image.crop(best_crop)
    orig_w, orig_h = orig_image.size
    # Height/width ratio of the model input canvas; the flattened attention
    # vector is reshaped to a grid with approximately this ratio.
    ratio = INPUT_HEIGHT / INPUT_WIDTH

    num_tokens = len(latex_tokens)
    num_cols = min(max_cols, num_tokens)
    num_rows = int(np.ceil(num_tokens / num_cols))

    fig_h = max(1, int(num_rows * 6 * orig_h / orig_w))  # figsize must be positive
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, fig_h))
    axes = np.array(axes).reshape(-1)

    num_drawn = 0
    for ax, token, attn in zip(axes, latex_tokens, attention_weights):
        # NOTE(review): assumes each step's attention is [rows, H*W]; only the
        # first row is visualized -- confirm against model.recognize.
        attn = attn[0:1].squeeze(0)
        attn_h, attn_w = _attention_grid_shape(attn.shape[0], ratio)

        # Resize the full input-canvas attention to (orig_h, interp_w)
        attn = attn.reshape(1, 1, attn_h, attn_w)
        interp_w = int(orig_h / ratio)
        attn = F.interpolate(attn, size=(orig_h, interp_w), mode='bilinear', align_corners=False)
        attn = attn.squeeze().cpu().numpy()

        # Fix aspect-ratio mismatch between canvas and cropped image
        if interp_w > orig_w:
            # center crop width
            start = (interp_w - orig_w) // 2
            attn = attn[:, start:start + orig_w]
        elif interp_w < orig_w:
            # stretch to fit width
            attn = cv2.resize(attn, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)

        ax.imshow(orig_image)
        ax.imshow(attn, cmap='jet', alpha=0.4)
        ax.set_title(f'{token}', fontsize=10 * 8 * orig_h / orig_w)
        ax.axis('off')
        num_drawn += 1

    # Hide any unused grid cells
    for ax in axes[num_drawn:]:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close(fig)


def evaluate_model(model,
                   test_folder,
                   label_file,
                   vocab,
                   device,
                   max_length=150,
                   batch_size=32):
    """
    Evaluate the CAN model on a labelled test set.

    ``label_file`` is a header-less, tab-separated file of ``filename<TAB>latex``
    rows, where ``latex`` is a space-separated token sequence. Filenames without
    an extension get the most common file extension found in ``test_folder``.
    Only formulas accepted by ``filter_formula(tokens, CLASSIFIER)`` are scored.
    Images are recognized one at a time with beam search (width 5).

    Args:
        model: CAN model exposing ``recognize``.
        test_folder: Directory containing the test images.
        label_file: Path to the tab-separated label file.
        vocab: Vocabulary with ``start_token``, ``end_token`` and ``idx2word``.
        device: Torch device input tensors are moved to.
        max_length: Maximum decoding length passed to ``model.recognize``.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        ``(metrics, results)``: ``metrics`` maps ``'exprate'`` and
        ``'exprate_leq1'``..``'exprate_leq3'`` to the fraction of evaluated
        images whose token edit distance is 0 / at most 1..3 (0 when nothing
        was evaluated); ``results`` maps each evaluated filename to its
        ground truth, prediction and edit distance.

    Side effects:
        Writes ``evaluation_results_can.json`` to the current working directory.
        Images that raise during processing are reported and excluded from the
        metrics.
    """
    import csv
    from collections import Counter

    # Read labels verbatim: LaTeX may contain '"' (pandas' default quote char),
    # and values such as "001" or "NA" must not become ints or NaN.
    df = pd.read_csv(label_file,
                     sep='\t',
                     header=None,
                     names=['filename', 'label'],
                     dtype=str,
                     keep_default_na=False,
                     quoting=csv.QUOTE_NONE)

    # Infer the image extension for label rows that omit it. Use the most common
    # extension among files instead of whichever entry os.listdir returns first
    # (arbitrary order, possibly a stray non-image file).
    if os.path.isdir(test_folder):
        extension_counts = Counter()
        for name in os.listdir(test_folder):
            ext = os.path.splitext(name)[1]
            if ext and os.path.isfile(os.path.join(test_folder, name)):
                extension_counts[ext] += 1
        if extension_counts:
            extension = extension_counts.most_common(1)[0][0]
            df['filename'] = df['filename'].apply(
                lambda x: x if os.path.splitext(x)[1] else x + extension)

    annotations = dict(zip(df['filename'], df['label']))

    model.eval()

    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        A.pytorch.ToTensorV2()
    ])

    # counts[d] = number of evaluated images whose edit distance is exactly d
    counts = [0, 0, 0, 0]
    total = 0
    failed = 0
    results = {}

    for image_path, gt_latex in tqdm(annotations.items(), desc="Evaluating"):
        gt_latex_tokens = gt_latex.split()
        if not filter_formula(gt_latex_tokens, CLASSIFIER):
            continue
        file_path = os.path.join(test_folder, image_path)

        try:
            processed_img, _ = process_img(file_path, convert_to_rgb=False)

            # albumentations expects an explicit channel axis: [H, W] -> [H, W, 1]
            processed_img = np.expand_dims(processed_img, axis=-1)
            image_tensor = transform(
                image=processed_img)['image'].unsqueeze(0).to(device)

            with torch.no_grad():
                predictions, _ = model.recognize(
                    image_tensor,
                    max_length=max_length,
                    start_token=vocab.start_token,
                    end_token=vocab.end_token,
                    beam_width=5  # Use beam search
                )

            # Convert indices to LaTeX tokens: stop at <end>, skip <start>
            pred_latex_tokens = []
            for idx in predictions:
                if idx == vocab.end_token:
                    break
                if idx != vocab.start_token:
                    pred_latex_tokens.append(vocab.idx2word[idx])
        except Exception as e:
            # Best-effort evaluation: one bad image should not abort the run.
            failed += 1
            print(f"Error processing {image_path}: {e}")
            continue

        edit_distance = levenshtein_distance(pred_latex_tokens, gt_latex_tokens)
        if edit_distance < len(counts):
            counts[edit_distance] += 1
        total += 1

        results[image_path] = {
            'ground_truth': gt_latex,
            'prediction': ' '.join(pred_latex_tokens),
            'edit_distance': edit_distance
        }

    def _rate(max_distance):
        # Fraction of evaluated images with edit distance <= max_distance.
        return round(sum(counts[:max_distance + 1]) / total, 4) if total > 0 else 0

    metrics = {
        'exprate': _rate(0),
        'exprate_leq1': _rate(1),
        'exprate_leq2': _rate(2),
        'exprate_leq3': _rate(3)
    }

    if failed:
        print(f"Skipped {failed} image(s) that could not be processed")
    print(f"Exact match rate: {metrics['exprate']:.4f}")
    print(f"Edit distance ≤ 1: {metrics['exprate_leq1']:.4f}")
    print(f"Edit distance ≤ 2: {metrics['exprate_leq2']:.4f}")
    print(f"Edit distance ≤ 3: {metrics['exprate_leq3']:.4f}")

    # Save results to file
    with open('evaluation_results_can.json', 'w', encoding='utf-8') as f:
        json.dump({'accuracy': metrics, 'results': results}, f, indent=4)

    return dict(metrics), results


def main(mode):
    """
    Run CAN inference using the module-level configuration.

    Args:
        mode: ``'single'`` to recognize ``IMAGE_PATH`` (optionally saving
            attention maps when ``VISUALIZE`` is set), or ``'evaluate'`` to score
            the model on ``TEST_FOLDER`` / ``LABEL_FILE``.

    Raises:
        ValueError: if ``mode`` is unsupported or the inputs it needs are
            missing. Both are checked before the checkpoint is loaded.
    """
    if mode not in ('single', 'evaluate'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'evaluate'")
    if mode == 'single' and not IMAGE_PATH:
        raise ValueError('Image path is required for single mode')
    if mode == 'evaluate' and (not TEST_FOLDER or not LABEL_FILE):
        raise ValueError(
            'Test folder and annotation file are required for evaluate mode'
        )

    device = DEVICE
    print(f'Using device: {device}')

    model, vocab = load_checkpoint(CHECKPOINT_PATH, device,
                                   pretrained_backbone=PRETRAINED_BACKBONE,
                                   backbone=BACKBONE_TYPE)

    if mode == 'single':
        latex = recognize_single_image(model,
                                       IMAGE_PATH,
                                       vocab,
                                       device,
                                       visualize_attention=VISUALIZE)
        print(f'Recognized LaTeX: {latex}')
    else:
        metrics, _ = evaluate_model(model, TEST_FOLDER, LABEL_FILE, vocab, device)
        print(f"##### Score of {CLASSIFIER} expression type: #####")
        print(f'Evaluation metrics: {metrics}')


if __name__ == '__main__':
    # Vocabulary is already registered with torch's safe globals at module
    # import time, so the entry point only needs to dispatch on the mode.
    main(MODE)