fisherman611 committed (verified)
Commit 324b9ef · Parent(s): 89ae6ce

Upload 3 files

models/can/can_dataloader.py ADDED
@@ -0,0 +1,529 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2  # ToTensorV2 lives in the albumentations.pytorch submodule
from PIL import Image
import pandas as pd
import cv2
import numpy as np
from collections import Counter

import json

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants
INPUT_HEIGHT = CAN_CONFIG["input_height"]
INPUT_WIDTH = CAN_CONFIG["input_width"]
BASE_DIR = CAN_CONFIG["base_dir"]
BATCH_SIZE = CAN_CONFIG["batch_size"]
NUM_WORKERS = CAN_CONFIG["num_workers"]


def is_effectively_binary(img, threshold_percentage=0.9):
    dark_pixels = np.sum(img < 20)
    bright_pixels = np.sum(img > 235)
    total_pixels = img.size

    return (dark_pixels + bright_pixels) / total_pixels > threshold_percentage


def before_padding(image):
    # Apply Canny edge detector to find text edges
    edges = cv2.Canny(image, 50, 150)

    # Apply dilation to connect nearby edges
    kernel = np.ones((7, 13), np.uint8)
    dilated = cv2.dilate(edges, kernel, iterations=8)

    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        dilated, connectivity=8
    )

    # Optimize crop rectangle using F1 score
    # Sort components by number of white pixels (excluding background which is label 0)
    sorted_components = sorted(
        range(1, num_labels), key=lambda i: stats[i, cv2.CC_STAT_AREA], reverse=True
    )

    # Initialize with empty crop
    best_f1 = 0
    best_crop = (0, 0, image.shape[1], image.shape[0])
    total_white_pixels = np.sum(dilated > 0)

    current_mask = np.zeros_like(dilated)
    x_min, y_min = image.shape[1], image.shape[0]
    x_max, y_max = 0, 0

    for component_idx in sorted_components:
        # Add this component to our mask
        component_mask = labels == component_idx
        current_mask = np.logical_or(current_mask, component_mask)

        # Update bounding box
        comp_y, comp_x = np.where(component_mask)
        if len(comp_x) > 0 and len(comp_y) > 0:
            x_min = min(x_min, np.min(comp_x))
            y_min = min(y_min, np.min(comp_y))
            x_max = max(x_max, np.max(comp_x))
            y_max = max(y_max, np.max(comp_y))

        # Calculate the current crop
        width = x_max - x_min + 1
        height = y_max - y_min + 1
        crop_area = width * height

        crop_mask = np.zeros_like(dilated)
        crop_mask[y_min : y_max + 1, x_min : x_max + 1] = 1
        white_in_crop = np.sum(np.logical_and(dilated > 0, crop_mask > 0))

        # Calculate F1 score
        precision = white_in_crop / crop_area
        recall = white_in_crop / total_white_pixels
        f1 = 2 * precision * recall / (precision + recall)

        if f1 > best_f1:
            best_f1 = f1
            best_crop = (x_min, y_min, x_max, y_max)

    # Apply the best crop to the original image
    x_min, y_min, x_max, y_max = best_crop
    cropped_image = image[y_min : y_max + 1, x_min : x_max + 1]

    # Apply Gaussian adaptive thresholding
    if is_effectively_binary(cropped_image):
        _, thresh = cv2.threshold(cropped_image, 127, 255, cv2.THRESH_BINARY)
    else:
        thresh = cv2.adaptiveThreshold(
            cropped_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )

    # Ensure background is black
    white = np.sum(thresh == 255)
    black = np.sum(thresh == 0)
    if white > black:
        thresh = 255 - thresh

    # Clean up noise using median filter
    denoised = cv2.medianBlur(thresh, 3)
    for _ in range(3):
        denoised = cv2.medianBlur(denoised, 3)

    # Add padding
    result = cv2.copyMakeBorder(denoised, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=0)

    return result, best_crop


def process_img(filename, convert_to_rgb=False):
    """
    Load, binarize, ensure black background, resize, and apply padding

    Args:
        filename: Path to the image file
        convert_to_rgb: Whether to convert to RGB

    Returns:
        Processed image and crop information
    """
    image = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Could not read image file: {filename}")

    bin_img, best_crop = before_padding(image)
    h, w = bin_img.shape
    new_w = int((INPUT_HEIGHT / h) * w)

    if new_w > INPUT_WIDTH:
        resized_img = cv2.resize(
            bin_img, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_AREA
        )
    else:
        resized_img = cv2.resize(
            bin_img, (new_w, INPUT_HEIGHT), interpolation=cv2.INTER_AREA
        )
        padded_img = np.zeros((INPUT_HEIGHT, INPUT_WIDTH), dtype=np.uint8)  # Black background
        x_offset = (INPUT_WIDTH - new_w) // 2
        padded_img[:, x_offset : x_offset + new_w] = resized_img
        resized_img = padded_img

    # Convert to BGR/RGB only if necessary
    if convert_to_rgb:
        resized_img = cv2.cvtColor(resized_img, cv2.COLOR_GRAY2BGR)

    return resized_img, best_crop


class HMERDatasetForCAN(Dataset):
    """
    Dataset integrated with the CAN model for HMER
    """

    def __init__(self, data_folder, label_file, vocab, transform=None, max_length=150):
        """
        Initialize the dataset

        data_folder: Directory containing images
        label_file: TSV file with two columns (filename, label), no header
        vocab: Vocabulary object for tokenization
        transform: Image transformations
        max_length: Maximum length of the token sequence
        """
        self.data_folder = data_folder
        self.max_length = max_length
        self.vocab = vocab

        # Read the label file
        df = pd.read_csv(label_file, sep="\t", header=None, names=["filename", "label"])

        # Check image file format
        if os.path.exists(data_folder):
            img_files = os.listdir(data_folder)
            if img_files:
                # Get the extension of the first file
                extension = os.path.splitext(img_files[0])[1]
                # Add extension to filenames if not present
                df["filename"] = df["filename"].apply(
                    lambda x: x if os.path.splitext(x)[1] else x + extension
                )

        self.annotations = dict(zip(df["filename"], df["label"]))
        self.image_paths = list(self.annotations.keys())

        # Default transformation
        if transform is None:
            transform = A.Compose(
                [
                    A.Normalize(
                        mean=[0.0], std=[1.0]
                    ),  # Normalize for single channel (grayscale)
                    ToTensorV2(),
                ]
            )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Get image path and LaTeX expression
        image_path = self.image_paths[idx]
        latex = self.annotations[image_path]

        # Process image
        file_path = os.path.join(self.data_folder, image_path)
        processed_img, _ = process_img(
            file_path, convert_to_rgb=False
        )  # Keep image as grayscale

        # Convert to [C, H, W] format and normalize
        if self.transform:
            # Ensure image has the correct format for albumentations
            processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
            image = self.transform(image=processed_img)["image"]
        else:
            # If no transform, manually convert to tensor
            image = torch.from_numpy(processed_img).float() / 255.0
            image = image.unsqueeze(0)  # Add grayscale channel: [1, H, W]

        # Tokenize LaTeX expression
        tokens = self.vocab.tokenize(latex)

        # Add start and end tokens
        tokens = [self.vocab.start_token] + tokens + [self.vocab.end_token]

        # Truncate if exceeding max length
        if len(tokens) > self.max_length:
            tokens = tokens[: self.max_length]

        # Create counting vector for CAN
        count_vector = self.create_count_vector(tokens)

        # Store actual caption length
        caption_length = torch.LongTensor([len(tokens)])

        # Pad to max length
        if len(tokens) < self.max_length:
            tokens = tokens + [self.vocab.pad_token] * (self.max_length - len(tokens))

        # Convert to tensor
        caption = torch.LongTensor(tokens)

        return image, caption, caption_length, count_vector

    def create_count_vector(self, tokens):
        """
        Create counting vector for the CAN model

        Args:
            tokens: List of token IDs

        Returns:
            Tensor counting the occurrence of each symbol
        """
        # Count occurrences of each token
        counter = Counter(tokens)

        # Create counting vector with size equal to vocabulary size
        count_vector = torch.zeros(len(self.vocab))

        # Fill counting vector with counts
        for token_id, count in counter.items():
            if 0 <= token_id < len(count_vector):
                count_vector[token_id] = count

        return count_vector


class Vocabulary:
    """
    Advanced Vocabulary class for tokenization
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

        # Add special tokens
        self.add_word("<pad>")  # Padding token
        self.add_word("<start>")  # Start token
        self.add_word("<end>")  # End token
        self.add_word("<unk>")  # Unknown token

        self.pad_token = self.word2idx["<pad>"]
        self.start_token = self.word2idx["<start>"]
        self.end_token = self.word2idx["<end>"]
        self.unk_token = self.word2idx["<unk>"]

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def tokenize(self, latex):
        """
        Tokenize LaTeX string into indices. Assumes tokens are space-separated.
        """
        tokens = []

        for char in latex.split():
            if char in self.word2idx:
                tokens.append(self.word2idx[char])
            else:
                tokens.append(self.unk_token)

        return tokens

    def build_vocab(self, label_file):
        """
        Build vocabulary from label file
        """
        try:
            df = pd.read_csv(
                label_file, sep="\t", header=None, names=["filename", "label"]
            )
            all_labels_text = " ".join(df["label"].astype(str).tolist())
            tokens = sorted(set(all_labels_text.split()))
            for char in tokens:
                self.add_word(char)
        except Exception as e:
            print(f"Error building vocabulary from {label_file}: {e}")

    def save_vocab(self, path):
        """
        Save vocabulary to file
        """
        data = {"word2idx": self.word2idx, "idx2word": self.idx2word, "idx": self.idx}
        torch.save(data, path)

    def load_vocab(self, path):
        """
        Load vocabulary from file
        """
        data = torch.load(path)
        self.word2idx = data["word2idx"]
        self.idx2word = data["idx2word"]
        self.idx = data["idx"]

        # Update special tokens
        self.pad_token = self.word2idx["<pad>"]
        self.start_token = self.word2idx["<start>"]
        self.end_token = self.word2idx["<end>"]
        self.unk_token = self.word2idx["<unk>"]


def build_unified_vocabulary(base_dir="data/CROHME"):
    """
    Build a unified vocabulary from all caption.txt files

    Args:
        base_dir: Root directory containing CROHME data

    Returns:
        Constructed Vocabulary object
    """
    vocab = Vocabulary()
    # Get all subdirectories
    subdirs = [
        d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    ]

    for subdir in subdirs:
        caption_path = os.path.join(base_dir, subdir, "caption.txt")
        if os.path.exists(caption_path):
            vocab.build_vocab(caption_path)
            print(f"Built vocabulary from {caption_path}")

    print(f"Final vocabulary size: {len(vocab)}")
    return vocab


def create_dataloaders_for_can(base_dir="data/CROHME", batch_size=32, num_workers=4):
    """
    Create dataloaders for training the CAN model

    Args:
        base_dir: Root directory containing CROHME data
        batch_size: Batch size
        num_workers: Number of workers for DataLoader

    Returns:
        train_loader, val_loader, test_loader, vocab
    """
    # Build unified vocabulary
    vocab = build_unified_vocabulary(base_dir)

    # Save vocabulary for later use
    os.makedirs("models", exist_ok=True)
    vocab.save_vocab("models/hmer_vocab.pth")

    # Create transform for grayscale data
    transform = A.Compose(
        [
            A.Normalize(
                mean=[0.0], std=[1.0]
            ),  # Normalize for single channel (grayscale)
            ToTensorV2(),
        ]
    )

    # Create datasets
    train_datasets = []

    # Use 'train' and possibly add other datasets to training set
    train_dirs = ["train", "2014"]  # Add other directories if desired
    for train_dir in train_dirs:
        data_folder = os.path.join(base_dir, train_dir, "img")
        label_file = os.path.join(base_dir, train_dir, "caption.txt")

        if os.path.exists(data_folder) and os.path.exists(label_file):
            train_datasets.append(
                HMERDatasetForCAN(
                    data_folder=data_folder,
                    label_file=label_file,
                    vocab=vocab,
                    transform=transform,
                )
            )

    # Combine training datasets
    if train_datasets:
        train_dataset = ConcatDataset(train_datasets)
    else:
        raise ValueError("No training datasets found")

    # Validation dataset
    val_data_folder = os.path.join(base_dir, "val", "img")
    val_label_file = os.path.join(base_dir, "val", "caption.txt")

    if not os.path.exists(val_data_folder) or not os.path.exists(val_label_file):
        # Use '2016' as validation set if 'val' is not available
        val_data_folder = os.path.join(base_dir, "2016", "img")
        val_label_file = os.path.join(base_dir, "2016", "caption.txt")

    val_dataset = HMERDatasetForCAN(
        data_folder=val_data_folder,
        label_file=val_label_file,
        vocab=vocab,
        transform=transform,
    )

    # Test dataset
    test_data_folder = os.path.join(base_dir, "test", "img")
    test_label_file = os.path.join(base_dir, "test", "caption.txt")

    if not os.path.exists(test_data_folder) or not os.path.exists(test_label_file):
        # Use '2019' as test set if 'test' is not available
        test_data_folder = os.path.join(base_dir, "2019", "img")
        test_label_file = os.path.join(base_dir, "2019", "caption.txt")

    test_dataset = HMERDatasetForCAN(
        data_folder=test_data_folder,
        label_file=test_label_file,
        vocab=vocab,
        transform=transform,
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader, vocab


# Use functionality integrated with the CAN model
def main():
    # Create dataloader for the CAN model
    train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
        base_dir=BASE_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
    )

    # Print information
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # Check dataloader output
    for images, captions, lengths, count_vectors in train_loader:
        print(f"Image batch shape: {images.shape}")
        print(f"Caption batch shape: {captions.shape}")
        print(f"Lengths batch shape: {lengths.shape}")
        print(f"Count vectors batch shape: {count_vectors.shape}")
        break


if __name__ == "__main__":
    main()
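For reference, a minimal usage sketch of this module (a sketch only: it assumes the default "data/CROHME" layout used above and the "models/hmer_vocab.pth" file that create_dataloaders_for_can writes; the batch size and worker count are illustrative):

from models.can.can_dataloader import Vocabulary, create_dataloaders_for_can

# Build the loaders and vocabulary in one call (this also saves models/hmer_vocab.pth)
train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
    base_dir="data/CROHME", batch_size=16, num_workers=2
)

# The saved vocabulary can be reloaded later without rebuilding it from caption.txt
vocab = Vocabulary()
vocab.load_vocab("models/hmer_vocab.pth")
token_ids = vocab.tokenize("\\frac { a } { b }")  # LaTeX tokens are space-separated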
models/can/can_eval.py ADDED
@@ -0,0 +1,423 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import pandas as pd
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import json
import torch.nn.functional as F

from models.can.can import CAN, create_can_model
from models.can.can_dataloader import Vocabulary, process_img, INPUT_HEIGHT, INPUT_WIDTH

torch.serialization.add_safe_globals([Vocabulary])

os.environ['QT_QPA_PLATFORM'] = 'offscreen'

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants here
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODE = CAN_CONFIG["mode"]  # 'single' or 'evaluate'
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
PRETRAINED_BACKBONE = CAN_CONFIG["pretrained_backbone"] == 1
CHECKPOINT_PATH = (f'checkpoints/{BACKBONE_TYPE}_can_best.pth'
                   if not PRETRAINED_BACKBONE
                   else f'checkpoints/p_{BACKBONE_TYPE}_can_best.pth')
IMAGE_PATH = f'{CAN_CONFIG["test_folder"]}/{CAN_CONFIG["relative_image_path"]}'
VISUALIZE = CAN_CONFIG["visualize"] == 1
TEST_FOLDER = CAN_CONFIG["test_folder"]
LABEL_FILE = CAN_CONFIG["label_file"]
CLASSIFIER = CAN_CONFIG["classifier"]  # choose between 'frac', 'sum_or_lim', 'long_expr', and 'all'


def filter_formula(formula_tokens, mode):
    if mode == "frac":
        return "\\frac" in formula_tokens
    elif mode == "sum_or_lim":
        return "\\sum" in formula_tokens or "\\limit" in formula_tokens
    elif mode == "long_expr":
        return len(formula_tokens) >= 30
    elif mode == 'short_expr':
        return len(formula_tokens) <= 10
    return True


def levenshtein_distance(lst1, lst2):
    """
    Calculate Levenshtein distance between two lists
    """
    m = len(lst1)
    n = len(lst2)

    prev_row = [j for j in range(n + 1)]
    curr_row = prev_row.copy()  # so an empty lst1 still yields a distance of n
    for i in range(1, m + 1):
        curr_row[0] = i

        for j in range(1, n + 1):
            if lst1[i - 1] == lst2[j - 1]:
                curr_row[j] = prev_row[j - 1]
            else:
                curr_row[j] = 1 + min(
                    curr_row[j - 1],  # insertion
                    prev_row[j],      # deletion
                    prev_row[j - 1]   # substitution
                )

        prev_row = curr_row.copy()
    return curr_row[n]


def load_checkpoint(checkpoint_path, device, pretrained_backbone=True, backbone='densenet'):
    """
    Load checkpoint and return model and vocabulary
    """
    checkpoint = torch.load(checkpoint_path,
                            map_location=device,
                            weights_only=False)

    vocab = checkpoint.get('vocab')
    if vocab is None:
        # Try to load vocab from a separate file if not in checkpoint
        vocab_path = os.path.join(os.path.dirname(checkpoint_path),
                                  'hmer_vocab.pth')
        if os.path.exists(vocab_path):
            vocab_data = torch.load(vocab_path)
            vocab = Vocabulary()
            vocab.word2idx = vocab_data['word2idx']
            vocab.idx2word = vocab_data['idx2word']
            vocab.idx = vocab_data['idx']
            # Update special tokens
            vocab.pad_token = vocab.word2idx['<pad>']
            vocab.start_token = vocab.word2idx['<start>']
            vocab.end_token = vocab.word2idx['<end>']
            vocab.unk_token = vocab.word2idx['<unk>']
        else:
            raise ValueError(
                f"Vocabulary not found in checkpoint and {vocab_path} does not exist"
            )

    # Initialize model with parameters from checkpoint
    hidden_size = checkpoint.get('hidden_size', 256)
    embedding_dim = checkpoint.get('embedding_dim', 256)
    use_coverage = checkpoint.get('use_coverage', True)

    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=pretrained_backbone,
                             backbone_type=backbone).to(device)

    model.load_state_dict(checkpoint['model'])
    print(f"Loaded model from checkpoint {checkpoint_path}")

    return model, vocab


def recognize_single_image(model,
                           image_path,
                           vocab,
                           device,
                           max_length=150,
                           visualize_attention=False):
    """
    Recognize handwritten mathematical expression from a single image using the CAN model
    """
    # Prepare image transform for grayscale images
    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        ToTensorV2()
    ])

    # Load and transform image
    processed_img, best_crop = process_img(image_path, convert_to_rgb=False)

    # Ensure image has the correct format for albumentations
    processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
    image_tensor = transform(
        image=processed_img)['image'].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        # Generate LaTeX using beam search
        predictions, attention_weights = model.recognize(
            image_tensor,
            max_length=max_length,
            start_token=vocab.start_token,
            end_token=vocab.end_token,
            beam_width=5  # Use beam search with width 5
        )

    # Convert indices to LaTeX tokens
    latex_tokens = []
    for idx in predictions:
        if idx == vocab.end_token:
            break
        if idx != vocab.start_token:  # Skip start token
            latex_tokens.append(vocab.idx2word[idx])

    # Join tokens to get complete LaTeX
    latex = ' '.join(latex_tokens)

    # Visualize attention if requested
    if visualize_attention and attention_weights is not None:
        visualize_attention_maps(processed_img, attention_weights,
                                 latex_tokens, best_crop)

    return latex


def visualize_attention_maps(orig_image,
                             attention_weights,
                             latex_tokens,
                             best_crop,
                             max_cols=4):
    """
    Visualize attention maps over the image for CAN model
    """
    # Create PIL image from numpy array (needed for crop() and .size below)
    orig_image = Image.fromarray(orig_image.squeeze(-1))
    orig_image = orig_image.crop(best_crop)
    orig_w, orig_h = orig_image.size
    ratio = INPUT_HEIGHT / INPUT_WIDTH

    num_tokens = len(latex_tokens)
    num_cols = min(max_cols, num_tokens)
    num_rows = int(np.ceil(num_tokens / num_cols))

    fig, axes = plt.subplots(num_rows,
                             num_cols,
                             figsize=(num_cols * 3, int(num_rows * 6 * orig_h / orig_w)))
    axes = np.array(axes).reshape(-1)

    for i, (token, attn) in enumerate(zip(latex_tokens, attention_weights)):
        ax = axes[i]

        attn = attn[0:1].squeeze(0)
        attn_len = attn.shape[0]
        attn_w = int(np.sqrt(attn_len / ratio))
        attn_h = int(np.sqrt(attn_len * ratio))

        # resize to (orig_h, interpolated_w)
        attn = attn.view(1, 1, attn_h, attn_w)
        interp_w = int(orig_h / ratio)

        attn = F.interpolate(attn, size=(orig_h, interp_w), mode='bilinear', align_corners=False)
        attn = attn.squeeze().cpu().numpy()

        # fix aspect ratio mismatch
        if interp_w > orig_w:
            # center crop width
            start = (interp_w - orig_w) // 2
            attn = attn[:, start:start + orig_w]
        elif interp_w < orig_w:
            # stretch to fit width
            attn = cv2.resize(attn, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)

        ax.imshow(orig_image)
        ax.imshow(attn, cmap='jet', alpha=0.4)
        ax.set_title(f'{token}', fontsize=10 * 8 * orig_h / orig_w)
        ax.axis('off')

    for j in range(i + 1, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.savefig('attention_maps_can.png', bbox_inches='tight', dpi=150)
    plt.close()


def evaluate_model(model,
                   test_folder,
                   label_file,
                   vocab,
                   device,
                   max_length=150,
                   batch_size=32):
    """
    Evaluate CAN model on test set
    """
    df = pd.read_csv(label_file,
                     sep='\t',
                     header=None,
                     names=['filename', 'label'])

    # Check image file format
    if os.path.exists(test_folder):
        img_files = os.listdir(test_folder)
        if img_files:
            # Get the extension of the first file
            extension = os.path.splitext(img_files[0])[1]
            # Add extension to filenames if not present
            df['filename'] = df['filename'].apply(
                lambda x: x if os.path.splitext(x)[1] else x + extension)

    annotations = dict(zip(df['filename'], df['label']))

    model.eval()

    correct = 0
    err1 = 0
    err2 = 0
    err3 = 0
    total = 0

    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        ToTensorV2()
    ])

    results = {}

    for image_path, gt_latex in tqdm(annotations.items(), desc="Evaluating"):
        gt_latex: str = gt_latex
        if not filter_formula(gt_latex.split(), CLASSIFIER):
            continue
        file_path = os.path.join(test_folder, image_path)

        try:
            processed_img, _ = process_img(file_path, convert_to_rgb=False)

            # Ensure image has the correct format for albumentations
            processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
            image_tensor = transform(
                image=processed_img)['image'].unsqueeze(0).to(device)

            with torch.no_grad():
                predictions, _ = model.recognize(
                    image_tensor,
                    max_length=max_length,
                    start_token=vocab.start_token,
                    end_token=vocab.end_token,
                    beam_width=5  # Use beam search
                )

            # Convert indices to LaTeX tokens
            pred_latex_tokens = []
            for idx in predictions:
                if idx == vocab.end_token:
                    break
                if idx != vocab.start_token:  # Skip start token
                    pred_latex_tokens.append(vocab.idx2word[idx])

            pred_latex = ' '.join(pred_latex_tokens)

            gt_latex_tokens = gt_latex.split()
            edit_distance = levenshtein_distance(pred_latex_tokens,
                                                 gt_latex_tokens)

            if edit_distance == 0:
                correct += 1
            elif edit_distance == 1:
                err1 += 1
            elif edit_distance == 2:
                err2 += 1
            elif edit_distance == 3:
                err3 += 1

            total += 1

            # Save result
            results[image_path] = {
                'ground_truth': gt_latex,
                'prediction': pred_latex,
                'edit_distance': edit_distance
            }
        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    # Calculate accuracy metrics
    exprate = round(correct / total, 4) if total > 0 else 0
    exprate_leq1 = round((correct + err1) / total, 4) if total > 0 else 0
    exprate_leq2 = round(
        (correct + err1 + err2) / total, 4) if total > 0 else 0
    exprate_leq3 = round(
        (correct + err1 + err2 + err3) / total, 4) if total > 0 else 0

    print(f"Exact match rate: {exprate:.4f}")
    print(f"Edit distance ≤ 1: {exprate_leq1:.4f}")
    print(f"Edit distance ≤ 2: {exprate_leq2:.4f}")
    print(f"Edit distance ≤ 3: {exprate_leq3:.4f}")

    # Save results to file
    with open('evaluation_results_can.json', 'w', encoding='utf-8') as f:
        json.dump(
            {
                'accuracy': {
                    'exprate': exprate,
                    'exprate_leq1': exprate_leq1,
                    'exprate_leq2': exprate_leq2,
                    'exprate_leq3': exprate_leq3
                },
                'results': results
            },
            f,
            indent=4)

    return {
        'exprate': exprate,
        'exprate_leq1': exprate_leq1,
        'exprate_leq2': exprate_leq2,
        'exprate_leq3': exprate_leq3
    }, results


def main(mode):
    device = DEVICE
    print(f'Using device: {device}')

    checkpoint_path = CHECKPOINT_PATH
    backbone = BACKBONE_TYPE
    pretrained_backbone = PRETRAINED_BACKBONE

    # For single mode
    image_path = IMAGE_PATH
    visualize = VISUALIZE

    # For evaluation mode
    test_folder = TEST_FOLDER
    label_file = LABEL_FILE

    # Load model and vocabulary
    model, vocab = load_checkpoint(checkpoint_path, device,
                                   pretrained_backbone=pretrained_backbone,
                                   backbone=backbone)

    if mode == 'single':
        if image_path is None:
            raise ValueError('Image path is required for single mode')

        latex = recognize_single_image(model,
                                       image_path,
                                       vocab,
                                       device,
                                       visualize_attention=visualize)
        print(f'Recognized LaTeX: {latex}')

    elif mode == 'evaluate':
        if test_folder is None or label_file is None:
            raise ValueError(
                'Test folder and annotation file are required for evaluate mode'
            )

        metrics, results = evaluate_model(model, test_folder, label_file,
                                          vocab, device)
        print(f"##### Score of {CLASSIFIER} expression type: #####")
        print(f'Evaluation metrics: {metrics}')


if __name__ == '__main__':
    # Ensure Vocabulary is safe for serialization
    torch.serialization.add_safe_globals([Vocabulary])

    # Run the main function
    main(MODE)
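A small illustrative check of the two helpers defined at the top of this file (assuming the repository root is on the Python path so models.can.can_eval is importable; the token strings are made-up examples):

from models.can.can_eval import levenshtein_distance, filter_formula

pred = "x ^ { 2 } + 1".split()
gt = "x ^ { 2 } + y".split()

print(levenshtein_distance(pred, gt))    # 1 -> counted in the "edit distance <= 1" bucket
print(filter_formula(gt, "frac"))        # False: no \frac token in the expression
print(filter_formula(gt, "short_expr"))  # True: 7 tokens, i.e. <= 10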
models/can/can_trainer.py ADDED
@@ -0,0 +1,336 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import time
import wandb
from datetime import datetime
from tqdm.auto import tqdm

from models.can.can import CAN, create_can_model
from models.can.can_dataloader import create_dataloaders_for_can, Vocabulary

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import random

import json

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants
BASE_DIR = CAN_CONFIG["base_dir"]
SEED = CAN_CONFIG["seed"]
CHECKPOINT_DIR = CAN_CONFIG["checkpoint_dir"]
PRETRAINED_BACKBONE = CAN_CONFIG["pretrained_backbone"] == 1
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
CHECKPOINT_NAME = (f'{BACKBONE_TYPE}_can_best.pth'
                   if not PRETRAINED_BACKBONE
                   else f'p_{BACKBONE_TYPE}_can_best.pth')
BATCH_SIZE = CAN_CONFIG["batch_size"]

HIDDEN_SIZE = CAN_CONFIG["hidden_size"]
EMBEDDING_DIM = CAN_CONFIG["embedding_dim"]
USE_COVERAGE = CAN_CONFIG["use_coverage"] == 1
LAMBDA_COUNT = CAN_CONFIG["lambda_count"]

LR = CAN_CONFIG["lr"]
EPOCHS = CAN_CONFIG["epochs"]
GRAD_CLIP = CAN_CONFIG["grad_clip"]
PRINT_FREQ = CAN_CONFIG["print_freq"]

T = CAN_CONFIG["t"]
T_MULT = CAN_CONFIG["t_mult"]

PROJECT_NAME = (f'final-hmer-can-{BACKBONE_TYPE}-pretrained'
                if PRETRAINED_BACKBONE
                else f'final-hmer-can-{BACKBONE_TYPE}')
NUM_WORKERS = cfg["can"]["num_workers"]
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class RandomMorphology(A.ImageOnlyTransform):

    def __init__(self, p=0.5, kernel_size=3):
        super(RandomMorphology, self).__init__(p=p)  # pass p by keyword so it is not taken as always_apply
        self.kernel_size = kernel_size

    def apply(self, img, **params):
        op = random.choice(['erode', 'dilate'])
        kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
        if op == 'erode':
            return cv2.erode(img, kernel, iterations=1)
        else:
            return cv2.dilate(img, kernel, iterations=1)


# Custom transforms for CAN model (grayscale images)
train_transforms = A.Compose([
    A.Rotate(limit=5, p=0.25, border_mode=cv2.BORDER_REPLICATE),
    A.ElasticTransform(alpha=100,
                       sigma=7,
                       p=0.5,
                       interpolation=cv2.INTER_CUBIC),
    RandomMorphology(p=0.5, kernel_size=2),
    A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
    ToTensorV2()
])


def train_epoch(model,
                train_loader,
                optimizer,
                device,
                grad_clip=5.0,
                lambda_count=0.01,
                print_freq=10):
    """
    Train the model for one epoch
    """
    model.train()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_count_loss = 0.0
    batch_count = 0

    for i, (images, captions, caption_lengths,
            count_targets) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        batch_count += 1
        images = images.to(device)
        captions = captions.to(device)
        count_targets = count_targets.to(device)

        # Forward pass
        outputs, count_vectors = model(images,
                                       captions,
                                       teacher_forcing_ratio=0.5)

        # Calculate loss
        loss, cls_loss, counting_loss = model.calculate_loss(
            outputs=outputs,
            targets=captions,
            count_vectors=count_vectors,
            count_targets=count_targets,
            lambda_count=lambda_count)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # Update weights
        optimizer.step()

        # Track losses
        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_count_loss += counting_loss.item()

        # Print progress
        if i % print_freq == 0 and i > 0:
            print(
                f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}, '
                f'Cls Loss: {cls_loss.item():.4f}, Count Loss: {counting_loss.item():.4f}'
            )

    return total_loss / batch_count, total_cls_loss / batch_count, total_count_loss / batch_count


def validate(model, val_loader, device, lambda_count=0.01):
    """
    Validate the model
    """
    model.eval()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_count_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for i, (images, captions, caption_lengths,
                count_targets) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            batch_count += 1
            images = images.to(device)
            captions = captions.to(device)
            count_targets = count_targets.to(device)

            # Forward pass
            outputs, count_vectors = model(
                images, captions,
                teacher_forcing_ratio=0.0)  # No teacher forcing in validation

            # Calculate loss
            loss, cls_loss, counting_loss = model.calculate_loss(
                outputs=outputs,
                targets=captions,
                count_vectors=count_vectors,
                count_targets=count_targets,
                lambda_count=lambda_count)

            # Track losses
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_count_loss += counting_loss.item()

    return total_loss / batch_count, total_cls_loss / batch_count, total_count_loss / batch_count


def main():
    # Configuration
    dataset_dir = BASE_DIR
    seed = SEED
    checkpoints_dir = CHECKPOINT_DIR
    checkpoint_name = CHECKPOINT_NAME
    batch_size = BATCH_SIZE

    # Model parameters
    hidden_size = HIDDEN_SIZE
    embedding_dim = EMBEDDING_DIM
    use_coverage = USE_COVERAGE
    lambda_count = LAMBDA_COUNT

    # Training parameters
    lr = LR
    epochs = EPOCHS
    grad_clip = GRAD_CLIP
    print_freq = PRINT_FREQ

    # Scheduler parameters
    T_0 = T
    T_mult = T_MULT

    # Set random seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Create checkpoint directory
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Set device
    device = DEVICE
    print(f'Using device: {device}')

    # Create dataloaders
    train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
        base_dir=dataset_dir, batch_size=batch_size, num_workers=NUM_WORKERS)

    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")
    print(f"Vocabulary size: {len(vocab)}")

    # Create model
    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=PRETRAINED_BACKBONE,
                             backbone_type=BACKBONE_TYPE).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=T_0,
                                                               T_mult=T_mult)

    # Initialize wandb
    run_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    wandb.init(project=PROJECT_NAME,
               name=run_name,
               config={
                   'seed': seed,
                   'batch_size': batch_size,
                   'hidden_size': hidden_size,
                   'embedding_dim': embedding_dim,
                   'use_coverage': use_coverage,
                   'lambda_count': lambda_count,
                   'lr': lr,
                   'epochs': epochs,
                   'grad_clip': grad_clip,
                   'T_0': T_0,
                   'T_mult': T_mult
               })

    # Training loop
    best_val_loss = float('inf')

    for epoch in tqdm(range(epochs)):
        curr_lr = scheduler.get_last_lr()[0]
        print(f'Epoch {epoch+1:03}/{epochs:03}')
        t1 = time.time()

        # Train
        train_loss, train_cls_loss, train_count_loss = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            device=device,
            grad_clip=grad_clip,
            lambda_count=lambda_count,
            print_freq=print_freq)

        # Validate
        val_loss, val_cls_loss, val_count_loss = validate(
            model=model,
            val_loader=val_loader,
            device=device,
            lambda_count=lambda_count)

        # Update learning rate
        scheduler.step()
        t2 = time.time()

        # Print stats
        print(
            f'Train - Total Loss: {train_loss:.4f}, Class Loss: {train_cls_loss:.4f}, Count Loss: {train_count_loss:.4f}'
        )
        print(
            f'Val - Total Loss: {val_loss:.4f}, Class Loss: {val_cls_loss:.4f}, Count Loss: {val_count_loss:.4f}'
        )
        print(f'Time: {t2 - t1:.2f}s, Learning Rate: {curr_lr:.6f}')

        # Log metrics to wandb
        wandb.log({
            'train_loss': train_loss,
            'train_cls_loss': train_cls_loss,
            'train_count_loss': train_count_loss,
            'val_loss': val_loss,
            'val_cls_loss': val_cls_loss,
            'val_count_loss': val_count_loss,
            'learning_rate': curr_lr,
            'epoch': epoch
        })

        # Save checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'val_loss': val_loss,
                'vocab': vocab
            }
            torch.save(checkpoint, os.path.join(checkpoints_dir,
                                                checkpoint_name))
            print('Model saved!')

    print('Training completed!')


if __name__ == "__main__":
    main()
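All three scripts read their settings from the "can" section of config.json. The sketch below lists the keys they access; the values are illustrative placeholders (not the repository's actual configuration) and simply show one way to generate a compatible file:

import json

# Placeholder values only; every key below is read by the scripts above.
cfg = {
    "can": {
        "input_height": 128, "input_width": 1024,
        "base_dir": "data/CROHME", "batch_size": 32, "num_workers": 4,
        "seed": 42, "checkpoint_dir": "checkpoints",
        "backbone_type": "densenet", "pretrained_backbone": 1,
        "hidden_size": 256, "embedding_dim": 256, "use_coverage": 1,
        "lambda_count": 0.01, "lr": 0.0001, "epochs": 100, "grad_clip": 5.0,
        "print_freq": 10, "t": 10, "t_mult": 2,
        "mode": "evaluate", "visualize": 0, "classifier": "all",
        "test_folder": "data/CROHME/test/img",
        "label_file": "data/CROHME/test/caption.txt",
        "relative_image_path": "example.png"
    }
}

with open("config.json", "w") as f:
    json.dump(cfg, f, indent=4)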