Upload 3 files
- models/can/can_dataloader.py +529 -0
- models/can/can_eval.py +423 -0
- models/can/can_trainer.py +336 -0
models/can/can_dataloader.py
ADDED
@@ -0,0 +1,529 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import pandas as pd
import cv2
import numpy as np
from collections import Counter

import json

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants
INPUT_HEIGHT = CAN_CONFIG["input_height"]
INPUT_WIDTH = CAN_CONFIG["input_width"]
BASE_DIR = CAN_CONFIG["base_dir"]
BATCH_SIZE = CAN_CONFIG["batch_size"]
NUM_WORKERS = CAN_CONFIG["num_workers"]


def is_effectively_binary(img, threshold_percentage=0.9):
    dark_pixels = np.sum(img < 20)
    bright_pixels = np.sum(img > 235)
    total_pixels = img.size

    return (dark_pixels + bright_pixels) / total_pixels > threshold_percentage


def before_padding(image):
    # Apply Canny edge detector to find text edges
    edges = cv2.Canny(image, 50, 150)

    # Apply dilation to connect nearby edges
    kernel = np.ones((7, 13), np.uint8)
    dilated = cv2.dilate(edges, kernel, iterations=8)

    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        dilated, connectivity=8
    )

    # Optimize crop rectangle using F1 score
    # Sort components by number of white pixels (excluding background which is label 0)
    sorted_components = sorted(
        range(1, num_labels), key=lambda i: stats[i, cv2.CC_STAT_AREA], reverse=True
    )

    # Initialize with empty crop
    best_f1 = 0
    best_crop = (0, 0, image.shape[1], image.shape[0])
    total_white_pixels = np.sum(dilated > 0)

    current_mask = np.zeros_like(dilated)
    x_min, y_min = image.shape[1], image.shape[0]
    x_max, y_max = 0, 0

    for component_idx in sorted_components:
        # Add this component to our mask
        component_mask = labels == component_idx
        current_mask = np.logical_or(current_mask, component_mask)

        # Update bounding box
        comp_y, comp_x = np.where(component_mask)
        if len(comp_x) > 0 and len(comp_y) > 0:
            x_min = min(x_min, np.min(comp_x))
            y_min = min(y_min, np.min(comp_y))
            x_max = max(x_max, np.max(comp_x))
            y_max = max(y_max, np.max(comp_y))

        # Calculate the current crop
        width = x_max - x_min + 1
        height = y_max - y_min + 1
        crop_area = width * height

        crop_mask = np.zeros_like(dilated)
        crop_mask[y_min : y_max + 1, x_min : x_max + 1] = 1
        white_in_crop = np.sum(np.logical_and(dilated > 0, crop_mask > 0))

        # Calculate F1 score
        precision = white_in_crop / crop_area
        recall = white_in_crop / total_white_pixels
        f1 = 2 * precision * recall / (precision + recall)

        if f1 > best_f1:
            best_f1 = f1
            best_crop = (x_min, y_min, x_max, y_max)

    # Apply the best crop to the original image
    x_min, y_min, x_max, y_max = best_crop
    cropped_image = image[y_min : y_max + 1, x_min : x_max + 1]

    # Apply Gaussian adaptive thresholding
    if is_effectively_binary(cropped_image):
        _, thresh = cv2.threshold(cropped_image, 127, 255, cv2.THRESH_BINARY)
    else:
        thresh = cv2.adaptiveThreshold(
            cropped_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )

    # Ensure background is black
    white = np.sum(thresh == 255)
    black = np.sum(thresh == 0)
    if white > black:
        thresh = 255 - thresh

    # Clean up noise using median filter
    denoised = cv2.medianBlur(thresh, 3)
    for _ in range(3):
        denoised = cv2.medianBlur(denoised, 3)

    # Add padding
    result = cv2.copyMakeBorder(denoised, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=0)

    return result, best_crop


def process_img(filename, convert_to_rgb=False):
    """
    Load, binarize, ensure black background, resize, and apply padding

    Args:
        filename: Path to the image file
        convert_to_rgb: Whether to convert to RGB

    Returns:
        Processed image and crop information
    """
    image = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Could not read image file: {filename}")

    bin_img, best_crop = before_padding(image)
    h, w = bin_img.shape
    new_w = int((INPUT_HEIGHT / h) * w)

    if new_w > INPUT_WIDTH:
        resized_img = cv2.resize(
            bin_img, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_AREA
        )
    else:
        resized_img = cv2.resize(
            bin_img, (new_w, INPUT_HEIGHT), interpolation=cv2.INTER_AREA
        )
        # Center the resized image on a black background of the target width
        padded_img = np.zeros((INPUT_HEIGHT, INPUT_WIDTH), dtype=np.uint8)
        x_offset = (INPUT_WIDTH - new_w) // 2
        padded_img[:, x_offset : x_offset + new_w] = resized_img
        resized_img = padded_img

    # Convert to BGR/RGB only if necessary
    if convert_to_rgb:
        resized_img = cv2.cvtColor(resized_img, cv2.COLOR_GRAY2BGR)

    return resized_img, best_crop


class HMERDatasetForCAN(Dataset):
    """
    Dataset integrated with the CAN model for HMER
    """

    def __init__(self, data_folder, label_file, vocab, transform=None, max_length=150):
        """
        Initialize the dataset

        data_folder: Directory containing images
        label_file: TSV file with two columns (filename, label), no header
        vocab: Vocabulary object for tokenization
        transform: Image transformations
        max_length: Maximum length of the token sequence
        """
        self.data_folder = data_folder
        self.max_length = max_length
        self.vocab = vocab

        # Read the label file
        df = pd.read_csv(label_file, sep="\t", header=None, names=["filename", "label"])

        # Check image file format
        if os.path.exists(data_folder):
            img_files = os.listdir(data_folder)
            if img_files:
                # Get the extension of the first file
                extension = os.path.splitext(img_files[0])[1]
                # Add extension to filenames if not present
                df["filename"] = df["filename"].apply(
                    lambda x: x if os.path.splitext(x)[1] else x + extension
                )

        self.annotations = dict(zip(df["filename"], df["label"]))
        self.image_paths = list(self.annotations.keys())

        # Default transformation
        if transform is None:
            transform = A.Compose(
                [
                    A.Normalize(
                        mean=[0.0], std=[1.0]
                    ),  # Normalize for single channel (grayscale)
                    ToTensorV2(),
                ]
            )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Get image path and LaTeX expression
        image_path = self.image_paths[idx]
        latex = self.annotations[image_path]

        # Process image
        file_path = os.path.join(self.data_folder, image_path)
        processed_img, _ = process_img(
            file_path, convert_to_rgb=False
        )  # Keep image as grayscale

        # Convert to [C, H, W] format and normalize
        if self.transform:
            # Ensure image has the correct format for albumentations
            processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
            image = self.transform(image=processed_img)["image"]
        else:
            # If no transform, manually convert to tensor
            image = torch.from_numpy(processed_img).float() / 255.0
            image = image.unsqueeze(0)  # Add grayscale channel: [1, H, W]

        # Tokenize LaTeX expression
        tokens = self.vocab.tokenize(latex)

        # Add start and end tokens
        tokens = [self.vocab.start_token] + tokens + [self.vocab.end_token]

        # Truncate if exceeding max length
        if len(tokens) > self.max_length:
            tokens = tokens[: self.max_length]

        # Create counting vector for CAN
        count_vector = self.create_count_vector(tokens)

        # Store actual caption length
        caption_length = torch.LongTensor([len(tokens)])

        # Pad to max length
        if len(tokens) < self.max_length:
            tokens = tokens + [self.vocab.pad_token] * (self.max_length - len(tokens))

        # Convert to tensor
        caption = torch.LongTensor(tokens)

        return image, caption, caption_length, count_vector

    def create_count_vector(self, tokens):
        """
        Create counting vector for the CAN model

        Args:
            tokens: List of token IDs

        Returns:
            Tensor counting the occurrence of each symbol
        """
        # Count occurrences of each token
        counter = Counter(tokens)

        # Create counting vector with size equal to vocabulary size
        count_vector = torch.zeros(len(self.vocab))

        # Fill counting vector with counts
        for token_id, count in counter.items():
            if 0 <= token_id < len(count_vector):
                count_vector[token_id] = count

        return count_vector


class Vocabulary:
    """
    Vocabulary class for tokenization
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

        # Add special tokens
        self.add_word("<pad>")  # Padding token
        self.add_word("<start>")  # Start token
        self.add_word("<end>")  # End token
        self.add_word("<unk>")  # Unknown token

        self.pad_token = self.word2idx["<pad>"]
        self.start_token = self.word2idx["<start>"]
        self.end_token = self.word2idx["<end>"]
        self.unk_token = self.word2idx["<unk>"]

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def tokenize(self, latex):
        """
        Tokenize LaTeX string into indices. Assumes tokens are space-separated.
        """
        tokens = []

        for char in latex.split():
            if char in self.word2idx:
                tokens.append(self.word2idx[char])
            else:
                tokens.append(self.unk_token)

        return tokens

    def build_vocab(self, label_file):
        """
        Build vocabulary from label file
        """
        try:
            df = pd.read_csv(
                label_file, sep="\t", header=None, names=["filename", "label"]
            )
            all_labels_text = " ".join(df["label"].astype(str).tolist())
            tokens = sorted(set(all_labels_text.split()))
            for char in tokens:
                self.add_word(char)
        except Exception as e:
            print(f"Error building vocabulary from {label_file}: {e}")

    def save_vocab(self, path):
        """
        Save vocabulary to file
        """
        data = {"word2idx": self.word2idx, "idx2word": self.idx2word, "idx": self.idx}
        torch.save(data, path)

    def load_vocab(self, path):
        """
        Load vocabulary from file
        """
        data = torch.load(path)
        self.word2idx = data["word2idx"]
        self.idx2word = data["idx2word"]
        self.idx = data["idx"]

        # Update special tokens
        self.pad_token = self.word2idx["<pad>"]
        self.start_token = self.word2idx["<start>"]
        self.end_token = self.word2idx["<end>"]
        self.unk_token = self.word2idx["<unk>"]


def build_unified_vocabulary(base_dir="data/CROHME"):
    """
    Build a unified vocabulary from all caption.txt files

    Args:
        base_dir: Root directory containing CROHME data

    Returns:
        Constructed Vocabulary object
    """
    vocab = Vocabulary()
    # Get all subdirectories
    subdirs = [
        d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    ]

    for subdir in subdirs:
        caption_path = os.path.join(base_dir, subdir, "caption.txt")
        if os.path.exists(caption_path):
            vocab.build_vocab(caption_path)
            print(f"Built vocabulary from {caption_path}")

    print(f"Final vocabulary size: {len(vocab)}")
    return vocab


def create_dataloaders_for_can(base_dir="data/CROHME", batch_size=32, num_workers=4):
    """
    Create dataloaders for training the CAN model

    Args:
        base_dir: Root directory containing CROHME data
        batch_size: Batch size
        num_workers: Number of workers for DataLoader

    Returns:
        train_loader, val_loader, test_loader, vocab
    """
    # Build unified vocabulary
    vocab = build_unified_vocabulary(base_dir)

    # Save vocabulary for later use
    os.makedirs("models", exist_ok=True)
    vocab.save_vocab("models/hmer_vocab.pth")

    # Create transform for grayscale data
    transform = A.Compose(
        [
            A.Normalize(
                mean=[0.0], std=[1.0]
            ),  # Normalize for single channel (grayscale)
            ToTensorV2(),
        ]
    )

    # Create datasets
    train_datasets = []

    # Use 'train' and possibly add other datasets to training set
    train_dirs = ["train", "2014"]  # Add other directories if desired
    for train_dir in train_dirs:
        data_folder = os.path.join(base_dir, train_dir, "img")
        label_file = os.path.join(base_dir, train_dir, "caption.txt")

        if os.path.exists(data_folder) and os.path.exists(label_file):
            train_datasets.append(
                HMERDatasetForCAN(
                    data_folder=data_folder,
                    label_file=label_file,
                    vocab=vocab,
                    transform=transform,
                )
            )

    # Combine training datasets
    if train_datasets:
        train_dataset = ConcatDataset(train_datasets)
    else:
        raise ValueError("No training datasets found")

    # Validation dataset
    val_data_folder = os.path.join(base_dir, "val", "img")
    val_label_file = os.path.join(base_dir, "val", "caption.txt")

    if not os.path.exists(val_data_folder) or not os.path.exists(val_label_file):
        # Use '2016' as validation set if 'val' is not available
        val_data_folder = os.path.join(base_dir, "2016", "img")
        val_label_file = os.path.join(base_dir, "2016", "caption.txt")

    val_dataset = HMERDatasetForCAN(
        data_folder=val_data_folder,
        label_file=val_label_file,
        vocab=vocab,
        transform=transform,
    )

    # Test dataset
    test_data_folder = os.path.join(base_dir, "test", "img")
    test_label_file = os.path.join(base_dir, "test", "caption.txt")

    if not os.path.exists(test_data_folder) or not os.path.exists(test_label_file):
        # Use '2019' as test set if 'test' is not available
        test_data_folder = os.path.join(base_dir, "2019", "img")
        test_label_file = os.path.join(base_dir, "2019", "caption.txt")

    test_dataset = HMERDatasetForCAN(
        data_folder=test_data_folder,
        label_file=test_label_file,
        vocab=vocab,
        transform=transform,
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader, vocab


# Use functionality integrated with the CAN model
def main():
    # Create dataloaders for the CAN model
    train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
        base_dir=BASE_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
    )

    # Print information
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # Check dataloader output
    for images, captions, lengths, count_vectors in train_loader:
        print(f"Image batch shape: {images.shape}")
        print(f"Caption batch shape: {captions.shape}")
        print(f"Lengths batch shape: {lengths.shape}")
        print(f"Count vectors batch shape: {count_vectors.shape}")
        break


if __name__ == "__main__":
    main()
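For reference, `HMERDatasetForCAN` and `Vocabulary` above expect each `caption.txt` to be a tab-separated file whose label column is a space-separated LaTeX token sequence. The sketch below illustrates that contract; the file name `caption_demo.txt` and both sample formulas are invented, and running it assumes a valid `config.json` is present, since importing `can_dataloader` reads it at module load.

# Hypothetical caption.txt entries: "filename<TAB>space-separated LaTeX tokens"
from models.can.can_dataloader import Vocabulary

sample_captions = (
    "formula_001\t\\frac { x + 1 } { 2 }\n"
    "formula_002\tx ^ { 2 } + y ^ { 2 }\n"
)
with open("caption_demo.txt", "w") as f:
    f.write(sample_captions)

vocab = Vocabulary()
vocab.build_vocab("caption_demo.txt")            # adds every space-separated symbol
print(len(vocab))                                # 4 special tokens + unique symbols
print(vocab.tokenize("\\frac { x + 1 } { 2 }"))  # token ids; unseen symbols map to <unk>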
models/can/can_eval.py
ADDED
@@ -0,0 +1,423 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import pandas as pd
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import json
import torch.nn.functional as F

from models.can.can import CAN, create_can_model
from models.can.can_dataloader import Vocabulary, process_img, INPUT_HEIGHT, INPUT_WIDTH

torch.serialization.add_safe_globals([Vocabulary])

os.environ['QT_QPA_PLATFORM'] = 'offscreen'

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODE = CAN_CONFIG["mode"]  # 'single' or 'evaluate'
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
PRETRAINED_BACKBONE = True if CAN_CONFIG["pretrained_backbone"] == 1 else False
CHECKPOINT_PATH = f'checkpoints/{BACKBONE_TYPE}_can_best.pth' if PRETRAINED_BACKBONE == False else f'checkpoints/p_{BACKBONE_TYPE}_can_best.pth'
IMAGE_PATH = f'{CAN_CONFIG["test_folder"]}/{CAN_CONFIG["relative_image_path"]}'
VISUALIZE = True if CAN_CONFIG["visualize"] == 1 else False
TEST_FOLDER = CAN_CONFIG["test_folder"]
LABEL_FILE = CAN_CONFIG["label_file"]
CLASSIFIER = CAN_CONFIG["classifier"]  # choose between 'frac', 'sum_or_lim', 'long_expr', 'short_expr', and 'all'


def filter_formula(formula_tokens, mode):
    if mode == "frac":
        return "\\frac" in formula_tokens
    elif mode == "sum_or_lim":
        return "\\sum" in formula_tokens or "\\limit" in formula_tokens
    elif mode == "long_expr":
        return len(formula_tokens) >= 30
    elif mode == 'short_expr':
        return len(formula_tokens) <= 10
    return True


def levenshtein_distance(lst1, lst2):
    """
    Calculate Levenshtein distance between two lists
    """
    m = len(lst1)
    n = len(lst2)

    prev_row = [j for j in range(n + 1)]
    curr_row = [0] * (n + 1)
    for i in range(1, m + 1):
        curr_row[0] = i

        for j in range(1, n + 1):
            if lst1[i - 1] == lst2[j - 1]:
                curr_row[j] = prev_row[j - 1]
            else:
                curr_row[j] = 1 + min(
                    curr_row[j - 1],  # insertion
                    prev_row[j],      # deletion
                    prev_row[j - 1]   # substitution
                )

        prev_row = curr_row.copy()

    # prev_row equals the last computed row, and also covers the m == 0 case
    return prev_row[n]


def load_checkpoint(checkpoint_path, device, pretrained_backbone=True, backbone='densenet'):
    """
    Load checkpoint and return model and vocabulary
    """
    checkpoint = torch.load(checkpoint_path,
                            map_location=device,
                            weights_only=False)

    vocab = checkpoint.get('vocab')
    if vocab is None:
        # Try to load vocab from a separate file if not in checkpoint
        vocab_path = os.path.join(os.path.dirname(checkpoint_path),
                                  'hmer_vocab.pth')
        if os.path.exists(vocab_path):
            vocab_data = torch.load(vocab_path)
            vocab = Vocabulary()
            vocab.word2idx = vocab_data['word2idx']
            vocab.idx2word = vocab_data['idx2word']
            vocab.idx = vocab_data['idx']
            # Update special tokens
            vocab.pad_token = vocab.word2idx['<pad>']
            vocab.start_token = vocab.word2idx['<start>']
            vocab.end_token = vocab.word2idx['<end>']
            vocab.unk_token = vocab.word2idx['<unk>']
        else:
            raise ValueError(
                f"Vocabulary not found in checkpoint and {vocab_path} does not exist"
            )

    # Initialize model with parameters from checkpoint
    hidden_size = checkpoint.get('hidden_size', 256)
    embedding_dim = checkpoint.get('embedding_dim', 256)
    use_coverage = checkpoint.get('use_coverage', True)

    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=pretrained_backbone,
                             backbone_type=backbone).to(device)

    model.load_state_dict(checkpoint['model'])
    print(f"Loaded model from checkpoint {checkpoint_path}")

    return model, vocab


def recognize_single_image(model,
                           image_path,
                           vocab,
                           device,
                           max_length=150,
                           visualize_attention=False):
    """
    Recognize a handwritten mathematical expression from a single image using the CAN model
    """
    # Prepare image transform for grayscale images
    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        ToTensorV2()
    ])

    # Load and transform image
    processed_img, best_crop = process_img(image_path, convert_to_rgb=False)

    # Ensure image has the correct format for albumentations
    processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
    image_tensor = transform(
        image=processed_img)['image'].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        # Generate LaTeX using beam search
        predictions, attention_weights = model.recognize(
            image_tensor,
            max_length=max_length,
            start_token=vocab.start_token,
            end_token=vocab.end_token,
            beam_width=5  # Use beam search with width 5
        )

    # Convert indices to LaTeX tokens
    latex_tokens = []
    for idx in predictions:
        if idx == vocab.end_token:
            break
        if idx != vocab.start_token:  # Skip start token
            latex_tokens.append(vocab.idx2word[idx])

    # Join tokens to get complete LaTeX
    latex = ' '.join(latex_tokens)

    # Visualize attention if requested
    if visualize_attention and attention_weights is not None:
        visualize_attention_maps(processed_img, attention_weights,
                                 latex_tokens, best_crop)

    return latex


def visualize_attention_maps(orig_image,
                             attention_weights,
                             latex_tokens,
                             best_crop,
                             max_cols=4):
    """
    Visualize attention maps over the image for the CAN model
    """
    # Create PIL image from numpy array, then crop to the detected region
    orig_image = Image.fromarray(orig_image.squeeze())
    orig_image = orig_image.crop(best_crop)
    orig_w, orig_h = orig_image.size
    ratio = INPUT_HEIGHT / INPUT_WIDTH

    num_tokens = len(latex_tokens)
    num_cols = min(max_cols, num_tokens)
    num_rows = int(np.ceil(num_tokens / num_cols))

    fig, axes = plt.subplots(num_rows,
                             num_cols,
                             figsize=(num_cols * 3, int(num_rows * 6 * orig_h / orig_w)))
    axes = np.array(axes).reshape(-1)

    for i, (token, attn) in enumerate(zip(latex_tokens, attention_weights)):
        ax = axes[i]

        attn = attn[0:1].squeeze(0)
        attn_len = attn.shape[0]
        attn_w = int(np.sqrt(attn_len / ratio))
        attn_h = int(np.sqrt(attn_len * ratio))

        # Resize to (orig_h, interpolated_w)
        attn = attn.view(1, 1, attn_h, attn_w)
        interp_w = int(orig_h / ratio)

        attn = F.interpolate(attn, size=(orig_h, interp_w), mode='bilinear', align_corners=False)
        attn = attn.squeeze().cpu().numpy()

        # Fix aspect ratio mismatch
        if interp_w > orig_w:
            # Center crop width
            start = (interp_w - orig_w) // 2
            attn = attn[:, start:start + orig_w]
        elif interp_w < orig_w:
            # Stretch to fit width
            attn = cv2.resize(attn, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)

        ax.imshow(orig_image)
        ax.imshow(attn, cmap='jet', alpha=0.4)
        ax.set_title(f'{token}', fontsize=10 * 8 * orig_h / orig_w)
        ax.axis('off')

    for j in range(i + 1, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.savefig('attention_maps_can.png', bbox_inches='tight', dpi=150)
    plt.close()


def evaluate_model(model,
                   test_folder,
                   label_file,
                   vocab,
                   device,
                   max_length=150,
                   batch_size=32):
    """
    Evaluate CAN model on the test set
    """
    df = pd.read_csv(label_file,
                     sep='\t',
                     header=None,
                     names=['filename', 'label'])

    # Check image file format
    if os.path.exists(test_folder):
        img_files = os.listdir(test_folder)
        if img_files:
            # Get the extension of the first file
            extension = os.path.splitext(img_files[0])[1]
            # Add extension to filenames if not present
            df['filename'] = df['filename'].apply(
                lambda x: x if os.path.splitext(x)[1] else x + extension)

    annotations = dict(zip(df['filename'], df['label']))

    model.eval()

    correct = 0
    err1 = 0
    err2 = 0
    err3 = 0
    total = 0

    transform = A.Compose([
        A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
        ToTensorV2()
    ])

    results = {}

    for image_path, gt_latex in tqdm(annotations.items(), desc="Evaluating"):
        gt_latex: str = gt_latex
        if not filter_formula(gt_latex.split(), CLASSIFIER):
            continue
        file_path = os.path.join(test_folder, image_path)

        try:
            processed_img, _ = process_img(file_path, convert_to_rgb=False)

            # Ensure image has the correct format for albumentations
            processed_img = np.expand_dims(processed_img, axis=-1)  # [H, W, 1]
            image_tensor = transform(
                image=processed_img)['image'].unsqueeze(0).to(device)

            with torch.no_grad():
                predictions, _ = model.recognize(
                    image_tensor,
                    max_length=max_length,
                    start_token=vocab.start_token,
                    end_token=vocab.end_token,
                    beam_width=5  # Use beam search
                )

            # Convert indices to LaTeX tokens
            pred_latex_tokens = []
            for idx in predictions:
                if idx == vocab.end_token:
                    break
                if idx != vocab.start_token:  # Skip start token
                    pred_latex_tokens.append(vocab.idx2word[idx])

            pred_latex = ' '.join(pred_latex_tokens)

            gt_latex_tokens = gt_latex.split()
            edit_distance = levenshtein_distance(pred_latex_tokens,
                                                 gt_latex_tokens)

            if edit_distance == 0:
                correct += 1
            elif edit_distance == 1:
                err1 += 1
            elif edit_distance == 2:
                err2 += 1
            elif edit_distance == 3:
                err3 += 1

            total += 1

            # Save result
            results[image_path] = {
                'ground_truth': gt_latex,
                'prediction': pred_latex,
                'edit_distance': edit_distance
            }
        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    # Calculate accuracy metrics
    exprate = round(correct / total, 4) if total > 0 else 0
    exprate_leq1 = round((correct + err1) / total, 4) if total > 0 else 0
    exprate_leq2 = round(
        (correct + err1 + err2) / total, 4) if total > 0 else 0
    exprate_leq3 = round(
        (correct + err1 + err2 + err3) / total, 4) if total > 0 else 0

    print(f"Exact match rate: {exprate:.4f}")
    print(f"Edit distance ≤ 1: {exprate_leq1:.4f}")
    print(f"Edit distance ≤ 2: {exprate_leq2:.4f}")
    print(f"Edit distance ≤ 3: {exprate_leq3:.4f}")

    # Save results to file
    with open('evaluation_results_can.json', 'w', encoding='utf-8') as f:
        json.dump(
            {
                'accuracy': {
                    'exprate': exprate,
                    'exprate_leq1': exprate_leq1,
                    'exprate_leq2': exprate_leq2,
                    'exprate_leq3': exprate_leq3
                },
                'results': results
            },
            f,
            indent=4)

    return {
        'exprate': exprate,
        'exprate_leq1': exprate_leq1,
        'exprate_leq2': exprate_leq2,
        'exprate_leq3': exprate_leq3
    }, results


def main(mode):
    device = DEVICE
    print(f'Using device: {device}')

    checkpoint_path = CHECKPOINT_PATH
    backbone = BACKBONE_TYPE
    pretrained_backbone = PRETRAINED_BACKBONE

    # For single mode
    image_path = IMAGE_PATH
    visualize = VISUALIZE

    # For evaluation mode
    test_folder = TEST_FOLDER
    label_file = LABEL_FILE

    # Load model and vocabulary
    model, vocab = load_checkpoint(checkpoint_path, device, pretrained_backbone=pretrained_backbone, backbone=backbone)

    if mode == 'single':
        if image_path is None:
            raise ValueError('Image path is required for single mode')

        latex = recognize_single_image(model,
                                       image_path,
                                       vocab,
                                       device,
                                       visualize_attention=visualize)
        print(f'Recognized LaTeX: {latex}')

    elif mode == 'evaluate':
        if test_folder is None or label_file is None:
            raise ValueError(
                'Test folder and annotation file are required for evaluate mode'
            )

        metrics, results = evaluate_model(model, test_folder, label_file,
                                          vocab, device)
        print(f"##### Score of {CLASSIFIER} expression type: #####")
        print(f'Evaluation metrics: {metrics}')


if __name__ == '__main__':
    # Ensure Vocabulary is safe for serialization
    torch.serialization.add_safe_globals([Vocabulary])

    # Run the main function
    main(MODE)
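As a quick sanity check on the ExpRate bookkeeping, the snippet below calls `levenshtein_distance` on token lists the same way `evaluate_model` does. The prediction/ground-truth pair is invented, and it assumes the script is run from the repository root with `config.json` in place, since importing `can_eval` reads it at module load.

from models.can.can_eval import levenshtein_distance

# Hypothetical prediction vs. ground truth, already split into LaTeX tokens
pred = "x ^ { 2 } + 1".split()
gt = "x ^ { 2 } + 2".split()

d = levenshtein_distance(pred, gt)
print(d)  # 1 -> counts toward ExpRate <= 1 but not toward the exact match rate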
models/can/can_trainer.py
ADDED
@@ -0,0 +1,336 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import time
import wandb
from datetime import datetime
from tqdm.auto import tqdm

from models.can.can import CAN, create_can_model
from models.can.can_dataloader import create_dataloaders_for_can, Vocabulary

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import random

import json

with open("config.json", "r") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]


# Global constants
BASE_DIR = CAN_CONFIG["base_dir"]
SEED = CAN_CONFIG["seed"]
CHECKPOINT_DIR = CAN_CONFIG["checkpoint_dir"]
PRETRAINED_BACKBONE = True if CAN_CONFIG["pretrained_backbone"] == 1 else False
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
CHECKPOINT_NAME = f'{BACKBONE_TYPE}_can_best.pth' if PRETRAINED_BACKBONE == False else f'p_{BACKBONE_TYPE}_can_best.pth'
BATCH_SIZE = CAN_CONFIG["batch_size"]

HIDDEN_SIZE = CAN_CONFIG["hidden_size"]
EMBEDDING_DIM = CAN_CONFIG["embedding_dim"]
USE_COVERAGE = True if CAN_CONFIG["use_coverage"] == 1 else False
LAMBDA_COUNT = CAN_CONFIG["lambda_count"]

LR = CAN_CONFIG["lr"]
EPOCHS = CAN_CONFIG["epochs"]
GRAD_CLIP = CAN_CONFIG["grad_clip"]
PRINT_FREQ = CAN_CONFIG["print_freq"]

T = CAN_CONFIG["t"]
T_MULT = CAN_CONFIG["t_mult"]

PROJECT_NAME = f'final-hmer-can-{BACKBONE_TYPE}-pretrained' if PRETRAINED_BACKBONE == True else f'final-hmer-can-{BACKBONE_TYPE}'
NUM_WORKERS = cfg["can"]["num_workers"]
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class RandomMorphology(A.ImageOnlyTransform):
    """Randomly erode or dilate the strokes of a grayscale image."""

    def __init__(self, p=0.5, kernel_size=3):
        super(RandomMorphology, self).__init__(p=p)
        self.kernel_size = kernel_size

    def apply(self, img, **params):
        op = random.choice(['erode', 'dilate'])
        kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
        if op == 'erode':
            return cv2.erode(img, kernel, iterations=1)
        else:
            return cv2.dilate(img, kernel, iterations=1)


# Custom transforms for CAN model (grayscale images)
train_transforms = A.Compose([
    A.Rotate(limit=5, p=0.25, border_mode=cv2.BORDER_REPLICATE),
    A.ElasticTransform(alpha=100,
                       sigma=7,
                       p=0.5,
                       interpolation=cv2.INTER_CUBIC),
    RandomMorphology(p=0.5, kernel_size=2),
    A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
    ToTensorV2()
])


def train_epoch(model,
                train_loader,
                optimizer,
                device,
                grad_clip=5.0,
                lambda_count=0.01,
                print_freq=10):
    """
    Train the model for one epoch
    """
    model.train()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_count_loss = 0.0
    batch_count = 0

    for i, (images, captions, caption_lengths,
            count_targets) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        batch_count += 1
        images = images.to(device)
        captions = captions.to(device)
        count_targets = count_targets.to(device)

        # Forward pass
        outputs, count_vectors = model(images,
                                       captions,
                                       teacher_forcing_ratio=0.5)

        # Calculate loss
        loss, cls_loss, counting_loss = model.calculate_loss(
            outputs=outputs,
            targets=captions,
            count_vectors=count_vectors,
            count_targets=count_targets,
            lambda_count=lambda_count)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # Update weights
        optimizer.step()

        # Track losses
        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_count_loss += counting_loss.item()

        # Print progress
        if i % print_freq == 0 and i > 0:
            print(
                f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}, '
                f'Cls Loss: {cls_loss.item():.4f}, Count Loss: {counting_loss.item():.4f}'
            )

    return total_loss / batch_count, total_cls_loss / batch_count, total_count_loss / batch_count


def validate(model, val_loader, device, lambda_count=0.01):
    """
    Validate the model
    """
    model.eval()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_count_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for i, (images, captions, caption_lengths,
                count_targets) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            batch_count += 1
            images = images.to(device)
            captions = captions.to(device)
            count_targets = count_targets.to(device)

            # Forward pass
            outputs, count_vectors = model(
                images, captions,
                teacher_forcing_ratio=0.0)  # No teacher forcing in validation

            # Calculate loss
            loss, cls_loss, counting_loss = model.calculate_loss(
                outputs=outputs,
                targets=captions,
                count_vectors=count_vectors,
                count_targets=count_targets,
                lambda_count=lambda_count)

            # Track losses
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_count_loss += counting_loss.item()

    return total_loss / batch_count, total_cls_loss / batch_count, total_count_loss / batch_count


def main():
    # Configuration
    dataset_dir = BASE_DIR
    seed = SEED
    checkpoints_dir = CHECKPOINT_DIR
    checkpoint_name = CHECKPOINT_NAME
    batch_size = BATCH_SIZE

    # Model parameters
    hidden_size = HIDDEN_SIZE
    embedding_dim = EMBEDDING_DIM
    use_coverage = USE_COVERAGE
    lambda_count = LAMBDA_COUNT

    # Training parameters
    lr = LR
    epochs = EPOCHS
    grad_clip = GRAD_CLIP
    print_freq = PRINT_FREQ

    # Scheduler parameters
    T_0 = T
    T_mult = T_MULT

    # Set random seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Create checkpoint directory
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Set device
    device = DEVICE
    print(f'Using device: {device}')

    # Create dataloaders
    train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
        base_dir=dataset_dir, batch_size=batch_size, num_workers=NUM_WORKERS)

    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")
    print(f"Vocabulary size: {len(vocab)}")

    # Create model
    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=PRETRAINED_BACKBONE,
                             backbone_type=BACKBONE_TYPE).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=T_0,
                                                               T_mult=T_mult)

    # Initialize wandb
    run_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    wandb.init(project=PROJECT_NAME,
               name=run_name,
               config={
                   'seed': seed,
                   'batch_size': batch_size,
                   'hidden_size': hidden_size,
                   'embedding_dim': embedding_dim,
                   'use_coverage': use_coverage,
                   'lambda_count': lambda_count,
                   'lr': lr,
                   'epochs': epochs,
                   'grad_clip': grad_clip,
                   'T_0': T_0,
                   'T_mult': T_mult
               })

    # Training loop
    best_val_loss = float('inf')

    for epoch in tqdm(range(epochs)):
        curr_lr = scheduler.get_last_lr()[0]
        print(f'Epoch {epoch+1:03}/{epochs:03}')
        t1 = time.time()

        # Train
        train_loss, train_cls_loss, train_count_loss = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            device=device,
            grad_clip=grad_clip,
            lambda_count=lambda_count,
            print_freq=print_freq)

        # Validate
        val_loss, val_cls_loss, val_count_loss = validate(
            model=model,
            val_loader=val_loader,
            device=device,
            lambda_count=lambda_count)

        # Update learning rate
        scheduler.step()
        t2 = time.time()

        # Print stats
        print(
            f'Train - Total Loss: {train_loss:.4f}, Class Loss: {train_cls_loss:.4f}, Count Loss: {train_count_loss:.4f}'
        )
        print(
            f'Val - Total Loss: {val_loss:.4f}, Class Loss: {val_cls_loss:.4f}, Count Loss: {val_count_loss:.4f}'
        )
        print(f'Time: {t2 - t1:.2f}s, Learning Rate: {curr_lr:.6f}')

        # Log metrics to wandb
        wandb.log({
            'train_loss': train_loss,
            'train_cls_loss': train_cls_loss,
            'train_count_loss': train_count_loss,
            'val_loss': val_loss,
            'val_cls_loss': val_cls_loss,
            'val_count_loss': val_count_loss,
            'learning_rate': curr_lr,
            'epoch': epoch
        })

        # Save checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'val_loss': val_loss,
                'vocab': vocab
            }
            torch.save(checkpoint, os.path.join(checkpoints_dir,
                                                checkpoint_name))
            print('Model saved!')

    print('Training completed!')


if __name__ == "__main__":
    main()
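For convenience, here is a sketch of the "can" block of config.json that all three scripts read. The key names are taken from the code above; every value is a placeholder assumption (several mirror defaults that appear in the code, e.g. hidden_size=256, lambda_count=0.01, grad_clip=5.0), and config.example.json is a hypothetical output name.

# Writes a hypothetical config.example.json covering every key the three
# scripts read; replace the placeholder values with your own settings.
import json

example_cfg = {
    "can": {
        # can_dataloader.py
        "input_height": 128,
        "input_width": 1024,
        "base_dir": "data/CROHME",
        "batch_size": 32,
        "num_workers": 4,
        # can_trainer.py
        "seed": 42,
        "checkpoint_dir": "checkpoints",
        "pretrained_backbone": 0,
        "backbone_type": "densenet",
        "hidden_size": 256,
        "embedding_dim": 256,
        "use_coverage": 1,
        "lambda_count": 0.01,
        "lr": 1e-4,
        "epochs": 100,
        "grad_clip": 5.0,
        "print_freq": 100,
        "t": 10,
        "t_mult": 2,
        # can_eval.py
        "mode": "evaluate",                          # 'single' or 'evaluate'
        "test_folder": "data/CROHME/2019/img",
        "relative_image_path": "example.png",
        "label_file": "data/CROHME/2019/caption.txt",
        "visualize": 0,
        "classifier": "all",                         # 'frac', 'sum_or_lim', 'long_expr', 'short_expr', or 'all'
    }
}

with open("config.example.json", "w") as f:
    json.dump(example_cfg, f, indent=4)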