# model_ocr.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import editdistance

# Import config and char_indexer
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
from data_handler_ocr import CharIndexer
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model


class CNN_Backbone(nn.Module):
    """
    CNN feature extractor for OCR.

    Produces features suitable for an RNN: the final adaptive pooling collapses
    the height dimension to 1 while preserving width, and ``forward`` returns
    the features in sequence-first layout ``(W_prime, N, output_channels)``.
    """

    def __init__(self, input_channels=1, output_channels=512):
        super(CNN_Backbone, self).__init__()
        # NOTE: the layer order inside this Sequential determines the
        # state_dict keys ("cnn.<index>.weight", ...). Changing it breaks
        # loading of previously saved checkpoints.
        self.cnn = nn.Sequential(
            # First block: H -> H/2, W -> W/2 (e.g. H: 32 -> 16)
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Second block: H -> H/2, W -> W/2 (e.g. H: 16 -> 8)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Third block (two conv layers)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Halves the height only. With width stride 1 and padding 1 the
            # width grows by one column (W' -> W' + 1), keeping more time steps.
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),
            # Fourth block
            nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Collapse whatever height remains to 1 while preserving width, so
            # each remaining column becomes one RNN time step.
            nn.AdaptiveAvgPool2d((1, None)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract a feature sequence from a batch of images.

        Args:
            x: Image batch of shape (N, C, H, W), e.g. (B, 1, 32, W_img).

        Returns:
            Tensor of shape (W_prime, N, output_channels), the sequence-first
            layout expected by the ``batch_first=False`` LSTM in ``CRNN``.
        """
        conv_features = self.cnn(x)  # (N, C_out, 1, W_prime)
        conv_features = conv_features.squeeze(2)  # (N, C_out, W_prime)
        return conv_features.permute(2, 0, 1)  # (W_prime, N, C_out)


class BidirectionalLSTM(nn.Module):
    """
    Bidirectional LSTM layer for sequence modeling.

    Input and output are sequence-first: (sequence_length, batch_size, features).
    The output feature size is ``hidden_size * 2`` (forward + backward directions).
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
        super(BidirectionalLSTM, self).__init__()
        # nn.LSTM only applies dropout between stacked layers and emits a
        # UserWarning when dropout > 0 with a single layer, so disable it there.
        effective_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional=True,
            dropout=effective_dropout,
            batch_first=False,  # expects (sequence_length, batch_size, input_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)  # discard the final (h_n, c_n) states
        return output


class CRNN(nn.Module):
    """
    Convolutional Recurrent Neural Network for OCR.

    Combines a CNN for feature extraction, a bidirectional LSTM for sequence
    modeling, and a final linear layer producing per-time-step class logits.
    """

    def __init__(self, num_classes: int, cnn_output_channels: int = 512,
                 rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
        super(CRNN, self).__init__()
        self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
        # The LSTM consumes one CNN column (cnn_output_channels features) per step.
        self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
        # A bidirectional LSTM outputs hidden_size * 2 features per step.
        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch of shape (N, C, H, W), e.g. (B, 1, 32, W_img).

        Returns:
            Raw logits of shape (W_prime, N, num_classes).
        """
        conv_features = self.cnn(x)  # (W_prime, N, C_out)
        rnn_features = self.rnn(conv_features)  # (W_prime, N, rnn_hidden_size * 2)
        # nn.Linear acts on the last dim, i.e. independently per time step.
        return self.fc(rnn_features)  # (W_prime, N, num_classes)


# --- Decoding Function ---
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
    """
    Perform greedy (best-path) decoding on CTC output.

    Args:
        output: Raw logits of shape (sequence_length, batch_size, num_classes).
        char_indexer: Converts each per-frame index sequence to text via
            ``decode``, which is expected to remove blanks and collapse
            repeated indices.

    Returns:
        One decoded string per batch element.
    """
    # log_softmax is monotonic per frame, so argmax over raw logits picks the
    # same class; the normalization is unnecessary here.
    best_paths = output.argmax(dim=2).transpose(0, 1).cpu().tolist()  # (N, T)
    return [char_indexer.decode(path) for path in best_paths]


# --- CTC helpers shared by training and evaluation ---
def _ctc_input_lengths(output: torch.Tensor) -> torch.Tensor:
    """Return a (batch_size,) long tensor filled with the output length T.

    The model output has a single time dimension T (dim 0) shared by every
    sample in the batch, so each sample is given the full length T.
    """
    return torch.full(
        size=(output.shape[1],),
        fill_value=output.shape[0],
        dtype=torch.long,
        device=output.device,
    )


def _ctc_loss(criterion: nn.CTCLoss, output: torch.Tensor, targets: torch.Tensor,
              target_lengths: torch.Tensor) -> torch.Tensor:
    """Compute CTC loss from raw (T, N, C) logits (CTCLoss needs log-probabilities)."""
    log_probs = F.log_softmax(output, dim=2)
    return criterion(log_probs, targets, _ctc_input_lengths(output), target_lengths)


def _split_targets(targets: torch.Tensor, target_lengths: list[int]) -> list[list[int]]:
    """Split a CTC target tensor into one list of label indices per sample.

    Accepts both layouts ``nn.CTCLoss`` supports: a 1-D concatenation of all
    targets, or a 2-D padded tensor of shape (batch_size, max_target_length).
    """
    targets = targets.cpu()
    if targets.dim() == 2:
        return [targets[i, :length].tolist() for i, length in enumerate(target_lengths)]
    sequences = []
    offset = 0
    for length in target_lengths:
        sequences.append(targets[offset:offset + length].tolist())
        offset += length
    return sequences


def _labels_to_text(char_indexer: CharIndexer, labels: list[int]) -> str:
    """Convert ground-truth label indices to text without CTC repeat-collapsing.

    ``char_indexer.decode`` is used for CTC best paths and is expected to
    collapse consecutive duplicates; applied to a ground-truth sequence that
    would turn e.g. "hello" into "helo". Decoding one index at a time keeps
    legitimate repeated characters. (Assumes ``decode([idx])`` yields the
    character for ``idx`` -- TODO confirm against CharIndexer.)
    """
    return "".join(char_indexer.decode([label]) for label in labels)


# --- Evaluation Function ---
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
    """
    Evaluate an OCR model on a dataloader.

    Switches the model to eval mode (it is not switched back) and runs without
    gradients.

    Args:
        model: CRNN-style model returning (T, N, num_classes) logits.
        dataloader: Yields (inputs, targets, _, target_lengths) batches; targets
            may be 1-D concatenated or 2-D padded.
        char_indexer: Provides ``blank_token_idx`` and ``decode``.
        device: Device to run on.

    Returns:
        (avg_loss, char_error_rate, exact_match_accuracy). avg_loss is the mean
        per-sample CTC loss; all three are 0.0 when no samples were seen.
    """
    model.eval()
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)

    total_loss = 0.0
    num_samples = 0
    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():
        for inputs, targets, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            target_lengths_device = target_lengths.to(device)

            output = model(inputs)
            loss = _ctc_loss(criterion, output, targets, target_lengths_device)

            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            all_predictions.extend(ctc_greedy_decode(output, char_indexer))
            all_ground_truths.extend(
                _labels_to_text(char_indexer, labels)
                for labels in _split_targets(targets, target_lengths.cpu().tolist())
            )

    # Average over samples actually processed (correct even with drop_last).
    avg_loss = total_loss / num_samples if num_samples else 0.0

    # Character Error Rate: total edit distance over total ground-truth chars.
    cer_sum = sum(editdistance.eval(pred, gt) for pred, gt in zip(all_predictions, all_ground_truths))
    total_chars = sum(len(gt) for gt in all_ground_truths)
    char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0

    # Exact Match Accuracy (word/line level).
    exact_match_accuracy = (
        accuracy_score(all_ground_truths, all_predictions) if all_ground_truths else 0.0
    )

    return avg_loss, char_error_rate, exact_match_accuracy


# --- Training Function ---
def train_ocr_model(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
                    char_indexer: CharIndexer, epochs: int, device: str,
                    progress_callback=None) -> tuple[nn.Module, dict]:
    """
    Train the OCR model with CTC loss, evaluating on ``test_loader`` after each epoch.

    Uses Adam (lr=0.001) and ReduceLROnPlateau on the test loss
    (mode='min', factor=0.8, patience=5).

    Args:
        model: CRNN-style model returning (T, N, num_classes) logits.
        train_loader: Yields (images, targets, _, target_lengths) batches.
        test_loader: Evaluation loader passed to ``evaluate_model``.
        char_indexer: Provides ``blank_token_idx`` and ``decode``.
        epochs: Number of training epochs.
        device: Device to train on.
        progress_callback: Optional callable invoked after each epoch as
            ``progress_callback(fraction_done, text=...)``.

    Returns:
        (model, training_history), where training_history maps 'train_loss',
        'test_loss', 'test_cer' and 'test_exact_match_accuracy' to per-epoch lists.
    """
    model.to(device)  # move before building the optimizer (canonical order)
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # fixed initial LR
    # Reduce LR when the test loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)

    model.train()

    training_history = {
        'train_loss': [],
        'test_loss': [],
        'test_cer': [],
        'test_exact_match_accuracy': [],
    }

    for epoch in range(epochs):
        running_loss = 0.0
        num_samples = 0
        pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
        for images, texts_encoded, _, text_lengths in pbar_train:
            images = images.to(device)
            texts_encoded = texts_encoded.to(device)
            text_lengths = text_lengths.to(device)

            optimizer.zero_grad()
            outputs = model(images)  # (sequence_length_from_cnn, batch_size, num_classes)
            loss = _ctc_loss(criterion, outputs, texts_encoded, text_lengths)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_size = images.size(0)
            running_loss += batch_loss * batch_size  # weight by batch size for a per-sample mean
            num_samples += batch_size
            pbar_train.set_postfix(loss=batch_loss)

        # Average over samples actually processed (correct even with drop_last).
        epoch_train_loss = running_loss / num_samples if num_samples else 0.0
        training_history['train_loss'].append(epoch_train_loss)

        # evaluate_model switches the model to eval mode itself.
        test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
        training_history['test_loss'].append(test_loss)
        training_history['test_cer'].append(test_cer)
        training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)

        scheduler.step(test_loss)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
              f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")

        if progress_callback:
            progress_val = (epoch + 1) / epochs
            progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")

        model.train()  # back to training mode after evaluation

    return model, training_history


def save_ocr_model(model: nn.Module, path: str):
    """Save the state dictionary of the trained OCR model to ``path``."""
    torch.save(model.state_dict(), path)
    print(f"OCR model saved to {path}")


def load_ocr_model(model: nn.Module, path: str):
    """
    Load a trained OCR model's state dictionary into ``model`` and set eval mode.

    The checkpoint is deserialized onto the CPU first (so GPU-trained weights
    load on CPU-only machines); ``load_state_dict`` then copies the values into
    the model's existing parameters on whatever device they live.
    """
    # NOTE(security): torch.load is pickle-based; only load trusted checkpoints
    # (consider weights_only=True on PyTorch versions that support it).
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    model.eval()
    print(f"OCR model loaded from {path}")