# model_ocr.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import editdistance
# Project-local imports: config constants, the character indexer, and image utilities
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
from data_handler_ocr import CharIndexer
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
class CNN_Backbone(nn.Module):
"""
CNN feature extractor for OCR. Designed to produce features suitable for RNN.
Output feature map should have height 1 after the final pooling/reduction.
"""
def __init__(self, input_channels=1, output_channels=512):
super(CNN_Backbone, self).__init__()
self.cnn = nn.Sequential(
# First block
nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
# Second block
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
# Third block (with two conv layers)
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
            # This MaxPool2d halves the height only: stride (2, 1) keeps the width
            # resolution, and padding (0, 1) widens it by one column
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # H: 8 -> 4, W: W/4 -> W/4 + 1
# Fourth block
nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
# This AdaptiveAvgPool2d makes sure the height dimension becomes 1
# while preserving the width. This is crucial for RNN input.
nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
# Pass through the CNN layers
conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
# Squeeze the height dimension (which is 1)
# This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
conv_features = conv_features.squeeze(2)
# Permute for RNN input: (sequence_length, batch_size, input_size)
# This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
conv_features = conv_features.permute(2, 0, 1)
# Return the CNN features, ready for the RNN layer in CRNN
return conv_features
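
# A quick shape sanity check for the backbone (illustrative only, not used by
# the pipeline): a 32-pixel-high dummy batch should come out height-collapsed,
# as (W', N, C) with W' roughly W/4 + 1 after the asymmetric pooling above.
def _demo_cnn_backbone_shapes():
    backbone = CNN_Backbone(input_channels=1, output_channels=512)
    dummy = torch.randn(2, 1, 32, 128)  # (N, C, H, W)
    features = backbone(dummy)
    print(features.shape)  # torch.Size([33, 2, 512]) for W=128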
class BidirectionalLSTM(nn.Module):
"""Bidirectional LSTM layer for sequence modeling."""
def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
super(BidirectionalLSTM, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
bidirectional=True, dropout=dropout, batch_first=False)
# batch_first=False expects input as (sequence_length, batch_size, input_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)  # LSTM returns (output, (h_n, c_n)); only the per-step outputs are needed
return output
class CRNN(nn.Module):
"""
Convolutional Recurrent Neural Network for OCR.
Combines CNN for feature extraction, LSTMs for sequence modeling,
and a final linear layer for character prediction.
"""
    def __init__(self, num_classes: int, cnn_output_channels: int = 512,
                 rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
super(CRNN, self).__init__()
self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
# Input to LSTM is the number of channels from the CNN output
        self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
# Output of bidirectional LSTM is hidden_size * 2
self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (N, C, H, W) e.g., (B, 1, 32, W_img)
# 1. Pass through the CNN to extract features
conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
# 2. Pass CNN features through the RNN (LSTM)
rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
# 3. Pass RNN features through the final fully connected layer
# Apply the linear layer to each time step independently
# output will be (W_prime, N, num_classes)
output = self.fc(rnn_features)
return output
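
# Illustrative check (not part of the pipeline) that CRNN emits one logit
# vector per horizontal position: the output is (T, N, num_classes), which is
# exactly the layout nn.CTCLoss expects after log_softmax.
def _demo_crnn_shapes():
    model = CRNN(num_classes=80)
    dummy = torch.randn(2, 1, 32, 128)
    logits = model(dummy)
    print(logits.shape)  # torch.Size([33, 2, 80]) for W=128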
# --- Decoding Function ---
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
"""
Performs greedy decoding on the CTC output.
output: (sequence_length, batch_size, num_classes) - raw logits
"""
    # log_softmax is monotonic, so argmax over log-probabilities equals argmax
    # over the raw logits; it is applied here for clarity
    log_probs = F.log_softmax(output, dim=2)
    # Permute to (batch_size, sequence_length, num_classes) and take the best class per time step
    predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
decoded_texts = []
for seq in predicted_indices:
# Use char_indexer's decode method, which handles blank removal and duplicate collapse
decoded_texts.append(char_indexer.decode(seq.tolist()))
return decoded_texts
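
# For reference, a minimal sketch (an assumption, since CharIndexer lives in
# data_handler_ocr) of the CTC collapse rule that char_indexer.decode is
# expected to implement: merge adjacent repeats first, then drop blanks.
def _demo_ctc_collapse(indices: list[int], blank: int = 0) -> list[int]:
    collapsed, prev = [], None
    for idx in indices:
        if idx != prev and idx != blank:
            collapsed.append(idx)
        prev = idx
    return collapsed
# _demo_ctc_collapse([0, 3, 3, 0, 3, 5, 5]) -> [3, 3, 5]: repeats merge unless
# separated by a blank.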
# --- Evaluation Function ---
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
model.eval()
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
total_loss = 0
all_predictions = []
all_ground_truths = []
with torch.no_grad():
        for inputs, targets_concat, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets_concat = targets_concat.to(device)  # 1D tensor of concatenated label indices
            target_lengths_tensor = target_lengths.to(device)
output = model(inputs)
outputs_seq_len_for_ctc = torch.full(
size=(output.shape[1],),
fill_value=output.shape[0],
dtype=torch.long,
device=device
)
# CTC Loss calculation requires log_softmax on the output logits
log_probs_for_loss = F.log_softmax(output, dim=2)
            # CTCLoss accepts targets as one 1D concatenated tensor plus per-sample lengths
            loss = criterion(log_probs_for_loss, targets_concat, outputs_seq_len_for_ctc, target_lengths_tensor)
total_loss += loss.item() * inputs.size(0)
decoded_preds = ctc_greedy_decode(output, char_indexer)
all_predictions.extend(decoded_preds)
            # Rebuild each ground-truth string by slicing the concatenated targets.
            # Note: this assumes char_indexer.decode also maps raw label indices
            # back to text correctly when applied to ground-truth sequences.
            ground_truths_batch = []
            offset = 0
            for length in target_lengths.cpu().tolist():
                segment = targets_concat[offset : offset + length].tolist()
                ground_truths_batch.append(char_indexer.decode(segment))
                offset += length
all_ground_truths.extend(ground_truths_batch)
avg_loss = total_loss / len(dataloader.dataset)
# Calculate Character Error Rate (CER)
cer_sum = 0
total_chars = 0
for pred, gt in zip(all_predictions, all_ground_truths):
cer_sum += editdistance.eval(pred, gt)
total_chars += len(gt)
char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
# Calculate Exact Match Accuracy (Word-level Accuracy)
exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
return avg_loss, char_error_rate, exact_match_accuracy
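
# A toy example (illustrative) of the CER arithmetic above: one substitution
# against a 5-character ground truth gives an edit distance of 1, so CER = 0.2.
def _demo_cer():
    pred, gt = "he1lo", "hello"
    print(editdistance.eval(pred, gt) / len(gt))  # 0.2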
# --- Training Function ---
def train_ocr_model(model: nn.Module, train_loader: DataLoader,
test_loader: DataLoader, char_indexer: CharIndexer,
epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
"""
Trains the OCR model using CTC loss.
"""
# CTCLoss needs the blank token index
criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
# Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)
model.to(device) # Ensure model is on the correct device
model.train() # Set model to training mode
training_history = {
'train_loss': [],
'test_loss': [],
'test_cer': [],
'test_exact_match_accuracy': []
}
for epoch in range(epochs):
running_loss = 0.0
pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
for images, texts_encoded, _, text_lengths in pbar_train:
images = images.to(device)
# Ensure target tensors are on the correct device for CTCLoss calculation
texts_encoded = texts_encoded.to(device)
text_lengths = text_lengths.to(device)
optimizer.zero_grad() # Clear gradients from previous step
outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
# `outputs.shape[0]` is the actual sequence length (T) produced by the model.
# CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
outputs_seq_len_for_ctc = torch.full(
size=(outputs.shape[1],), # batch_size
fill_value=outputs.shape[0], # actual sequence length (T) from model output
dtype=torch.long,
device=device
)
# CTC Loss calculation requires log_softmax on the output logits
log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
# Use outputs_seq_len_for_ctc for the input_lengths argument
loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
loss.backward() # Backpropagate
optimizer.step() # Update model weights
running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
pbar_train.set_postfix(loss=loss.item())
epoch_train_loss = running_loss / len(train_loader.dataset)
training_history['train_loss'].append(epoch_train_loss)
        # Evaluate on the test set; evaluate_model switches to eval mode
        # internally, and train mode is restored below
        test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
training_history['test_loss'].append(test_loss)
training_history['test_cer'].append(test_cer)
training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
# Adjust learning rate based on test loss
scheduler.step(test_loss)
print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
if progress_callback:
# Update progress bar with current epoch and key metrics
progress_val = (epoch + 1) / epochs
progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
model.train() # Set model back to training mode after evaluation
return model, training_history
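
# A minimal end-to-end sketch (illustrative only): it wires train_ocr_model to
# a toy in-memory dataset. The collate function and the stub indexer below are
# assumptions standing in for data_handler_ocr's real batching and CharIndexer;
# they only mirror how the loops above unpack batches as
# (images, concatenated_targets, _, target_lengths).
def _demo_train_on_toy_data(device: str = "cpu"):
    from torch.utils.data import Dataset

    class _ToyOCRDataset(Dataset):
        """Random 32-pixel-high images paired with short random label sequences."""
        def __init__(self, n: int = 8, num_classes: int = 10):
            self.images = torch.randn(n, 1, 32, 96)
            self.labels = [torch.randint(1, num_classes, (4,)) for _ in range(n)]
        def __len__(self):
            return len(self.images)
        def __getitem__(self, i):
            return self.images[i], self.labels[i]

    def _collate(batch):
        images = torch.stack([item[0] for item in batch])
        targets = torch.cat([item[1] for item in batch])  # 1D concatenated labels
        lengths = torch.tensor([len(item[1]) for item in batch], dtype=torch.long)
        return images, targets, None, lengths

    class _StubIndexer:
        """Stand-in for CharIndexer: blank index 0, digits as 'characters'."""
        blank_token_idx = 0
        def decode(self, indices):
            out, prev = [], None
            for idx in indices:
                if idx != prev and idx != self.blank_token_idx:
                    out.append(str(idx))
                prev = idx
            return "".join(out)

    loader = DataLoader(_ToyOCRDataset(), batch_size=4, collate_fn=_collate)
    model = CRNN(num_classes=10)
    return train_ocr_model(model, loader, loader, _StubIndexer(), epochs=1, device=device)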
def save_ocr_model(model: nn.Module, path: str):
"""Saves the state dictionary of the trained OCR model."""
torch.save(model.state_dict(), path)
print(f"OCR model saved to {path}")
def load_ocr_model(model: nn.Module, path: str):
"""
Loads a trained OCR model's state dictionary.
Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
"""
model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
model.eval() # Set to evaluation mode
print(f"OCR model loaded from {path}")