marianeft committed on
Commit 38b6f7f · verified · 1 Parent(s): 7c7405f

Delete src

Files changed (5)
  1. src/app.py +0 -227
  2. src/config.py +0 -48
  3. src/data_handler_ocr.py +0 -165
  4. src/model_ocr.py +0 -286
  5. src/utils_ocr.py +0 -83
src/app.py DELETED
@@ -1,227 +0,0 @@
- # -*- coding: utf-8 -*-
- # app.py
-
- import os
- # Disable Streamlit file watcher to prevent conflicts with PyTorch
- os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
-
- import streamlit as st
- import pandas as pd
- import numpy as np
- from PIL import Image
- import torch
- import torch.nn.functional as F
- import torchvision.transforms as transforms
- import traceback
-
- # Import all necessary configuration values from config.py
- from config import (
-     IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
-     TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
-     MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
- )
-
- # Import classes and functions from data_handler_ocr.py and model_ocr.py
- from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
- from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
- from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
-
-
- # --- Global Variables ---
- ocr_model = None
- char_indexer = None
- training_history = None
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # --- Streamlit App Setup ---
- st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
-
-
- st.title("📝 Handwritten Name Recognition (OCR) App")
- st.markdown("""
- This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
- Optical Character Recognition (OCR) on handwritten names. You can upload an image
- of a handwritten name for prediction or train a new model using the provided dataset.
-
- **Note:** Training a robust OCR model can be time-consuming.
- """)
-
- # --- Initialize CharIndexer ---
- # This initializes char_indexer once when the script starts
- char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
-
- # --- Model Loading / Initialization ---
- @st.cache_resource  # Cache the model to prevent reloading on every rerun
- def get_and_load_ocr_model_cached(num_classes, model_path):
-     """
-     Initializes the OCR model and attempts to load a pre-trained model.
-     If no pre-trained model exists, a new model instance is returned.
-     """
-     model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
-
-     if os.path.exists(model_path):
-         st.sidebar.info("Loading pre-trained OCR model...")
-         try:
-             # Load model to CPU first, then move to device
-             model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
-             st.sidebar.success("OCR model loaded successfully!")
-         except Exception as e:
-             st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
-             # If loading fails, re-initialize an untrained model
-             model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
-     else:
-         st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
-
-     return model_instance
-
- # Get the model instance and assign it to the global 'ocr_model'
- ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
- # Ensure the model is on the correct device for inference
- ocr_model.to(device)
- ocr_model.eval()  # Set model to evaluation mode for inference by default
-
-
- # --- Sidebar for Model Training ---
- st.sidebar.header("Train OCR Model")
- st.sidebar.write("Click the button below to start training the OCR model.")
-
- # Progress bar and label for training in the sidebar
- progress_bar_sidebar = st.sidebar.progress(0)
- progress_label_sidebar = st.sidebar.empty()
-
- def update_progress_callback_sidebar(value, text):
-     progress_bar_sidebar.progress(int(value * 100))
-     progress_label_sidebar.text(text)
-
- if st.sidebar.button("📊 Start Training"):
-     progress_bar_sidebar.progress(0)
-     progress_label_sidebar.empty()
-     st.empty()
-
-     if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
-         st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
-     elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
-         st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
-                            "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
-     else:
-         st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
-
-         try:
-             train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
-             st.sidebar.success("Training and Test DataFrames loaded successfully.")
-
-             st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
-
-             train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
-             st.sidebar.success("DataLoaders created successfully.")
-
-             ocr_model.train()
-
-             st.sidebar.write("Training in progress... This may take a while.")
-             ocr_model, training_history = train_ocr_model(
-                 model=ocr_model,
-                 train_loader=train_loader,
-                 test_loader=test_loader,
-                 char_indexer=char_indexer,
-                 epochs=NUM_EPOCHS,
-                 device=device,
-                 progress_callback=update_progress_callback_sidebar
-             )
-             st.sidebar.success("OCR model training finished!")
-             update_progress_callback_sidebar(1.0, "Training complete!")
-
-             os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
-             save_ocr_model(ocr_model, MODEL_SAVE_PATH)
-             st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
-
-         except Exception as e:
-             st.sidebar.error(f"An error occurred during training: {e}")
-             st.exception(e)
-             update_progress_callback_sidebar(0.0, "Training failed!")
-
- # --- Sidebar for Model Loading ---
- st.sidebar.header("Load Pre-trained Model")
- st.sidebar.write("If you have a saved model, you can load it here instead of training.")
-
- if st.sidebar.button("💾 Load Model"):
-     if os.path.exists(MODEL_SAVE_PATH):
-         try:
-             loaded_model = CRNN(num_classes=char_indexer.num_classes)
-             load_ocr_model(loaded_model, MODEL_SAVE_PATH)
-             loaded_model.to(device)
-
-             st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
-         except Exception as e:
-             st.sidebar.error(f"Error loading model: {e}")
-             st.exception(e)
-     else:
-         st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
-
- # --- Main Content: Prediction Section and Training History ---
-
- # Display training history chart
- if training_history:
-     st.subheader("Training History Plots")
-     history_df = pd.DataFrame({
-         'Epoch': range(1, len(training_history['train_loss']) + 1),
-         'Train Loss': training_history['train_loss'],
-         'Test Loss': training_history['test_loss'],
-         'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
-         'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
-     })
-
-     st.markdown("**Loss over Epochs**")
-     st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
-     st.caption("Lower loss indicates better model performance.")
-
-     st.markdown("**Character Error Rate (CER) over Epochs**")
-     st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
-     st.caption("Lower CER indicates fewer character errors (0% is perfect).")
-
-     st.markdown("**Exact Match Accuracy over Epochs**")
-     st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
-     st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
-
-     st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
-     st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
-     st.caption("CER should decrease, Accuracy should increase.")
-     st.write("---")  # Separator after charts
-
-
- # Predict on a New Image
-
- if ocr_model is None:
-     st.warning("Please train or load a model before attempting prediction.")
- else:
-     uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
-
-     if uploaded_file is not None:
-         try:
-             image_pil = Image.open(uploaded_file).convert('L')
-             st.image(image_pil, caption="Uploaded Image", use_container_width=True)
-             st.write("---")
-             st.write("Processing and Recognizing...")
-
-             processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
-
-             ocr_model.eval()
-             with torch.no_grad():
-                 output = ocr_model(processed_image_tensor)
-
-             predicted_texts = ctc_greedy_decode(output, char_indexer)
-             predicted_text = predicted_texts[0]
-
-             st.success(f"Recognized Text: **{predicted_text}**")
-
-         except Exception as e:
-             st.error(f"Error processing image or recognizing text: {e}")
-             st.info("💡 **Tips for best results:**\n"
-                     "- Ensure the handwritten text is clear and on a clean background.\n"
-                     "- Only include one name/word per image.\n"
-                     "- The model is trained on specific characters. Unusual symbols might not be recognized.")
-             st.exception(e)
-
- st.markdown("""
- ---
- *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
- """)
src/config.py DELETED
@@ -1,48 +0,0 @@
- # config.py
-
- import os
-
- # --- Paths ---
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
- DATA_DIR = os.path.join(BASE_DIR, 'data')
- MODELS_DIR = os.path.join(BASE_DIR, 'models')
-
- TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
- TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
-
- TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
- TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
-
- MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
-
- # --- Character Set and OCR Configuration ---
- CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
- BLANK_TOKEN_SYMBOL = 'Þ'
- VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
- NUM_CLASSES = len(VOCABULARY)
- BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
-
- # --- Sanity Checks ---
- if BLANK_TOKEN == -1:
-     raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
- if BLANK_TOKEN >= NUM_CLASSES:
-     raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
-
- print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
- print(f"Vocabulary Length: {len(VOCABULARY)}")
- print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
-
-
- # --- Image Preprocessing Parameters ---
- IMG_HEIGHT = 32  # Target height for all input images to the model
- MAX_IMG_WIDTH = 1024  # Adjust this value based on your typical image widths and available RAM
-
- # --- Training Parameters ---
- BATCH_SIZE = 10
-
- # Dataset Limits
- TRAIN_SAMPLES_LIMIT = 1000
- TEST_SAMPLES_LIMIT = 1000
-
- NUM_EPOCHS = 5
- LEARNING_RATE = 0.001
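Importing the module runs the sanity checks and prints the summary once per process. With the character set above (the 95 printable ASCII characters plus the thorn blank), the numbers should work out as sketched below, assuming config.py is on the import path:

import config

# Expected console output on first import (follows from len(CHARS) == 95):
#   Config Loaded: NUM_CLASSES=96, BLANK_TOKEN_INDEX=95
#   Vocabulary Length: 96
#   Blank Symbol: 'Þ' at index 95
assert config.BLANK_TOKEN == config.NUM_CLASSES - 1  # blank sits at the end of VOCABULARY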
src/data_handler_ocr.py DELETED
@@ -1,165 +0,0 @@
- # data_handler_ocr.py
-
- import pandas as pd
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- import os
- from PIL import Image
- import numpy as np
- import torch.nn.functional as F
-
- # Import utility functions and config
- from config import (
-     VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
-     TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
-     TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
- )
- from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
-
- class CharIndexer:
-     """Manages character-to-index and index-to-character mappings."""
-     def __init__(self, vocabulary_string: str, blank_token_symbol: str):
-         self.chars = sorted(list(set(vocabulary_string)))
-         self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
-         self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
-
-         if blank_token_symbol not in self.char_to_idx:
-             raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
-
-         self.blank_token_idx = self.char_to_idx[blank_token_symbol]
-         self.num_classes = len(self.chars)
-
-         if self.blank_token_idx >= self.num_classes:
-             raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
-
-         print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
-         print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
-
-     def encode(self, text: str) -> list[int]:
-         """Converts a text string to a list of integer indices."""
-         encoded_list = []
-         for char in text:
-             if char in self.char_to_idx:
-                 encoded_list.append(self.char_to_idx[char])
-             else:
-                 print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
-                 encoded_list.append(self.blank_token_idx)
-         return encoded_list
-
-     def decode(self, indices: list[int]) -> str:
-         """Converts a list of integer indices back to a text string."""
-         decoded_text = []
-         for i, idx in enumerate(indices):
-             if idx == self.blank_token_idx:
-                 continue  # Skip blank tokens
-
-             if i > 0 and indices[i-1] == idx:
-                 continue
-
-             if idx in self.idx_to_char:
-                 decoded_text.append(self.idx_to_char[idx])
-             else:
-                 print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
-
-         return "".join(decoded_text)
-
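A small roundtrip sketch of the class above (the four-character vocabulary is illustrative), showing how decode collapses repeated indices and drops blanks, which is the contract CTC greedy decoding relies on:

indexer = CharIndexer(vocabulary_string="abcÞ", blank_token_symbol="Þ")
# sorted chars: ['a', 'b', 'c', 'Þ'] -> blank_token_idx == 3

print(indexer.encode("abba"))              # [0, 1, 1, 0]
# A raw CTC path 'a', 'a', blank, 'b', blank, 'a' decodes with repeats collapsed:
print(indexer.decode([0, 0, 3, 1, 3, 0]))  # 'aba'
# Repeats collapse unless a blank separates them, so encode/decode is not a strict inverse:
print(indexer.decode(indexer.encode("abba")))  # 'aba'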
- class OCRDataset(Dataset):
-     """
-     Custom PyTorch Dataset for the Handwritten Name Recognition task.
-     Loads images and their corresponding text labels.
-     """
-     def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
-         self.data = dataframe
-         self.char_indexer = char_indexer
-         self.image_dir = image_dir
-
-         if transform is None:
-             self.transform = transforms.Compose([
-                 transforms.Lambda(lambda img: binarize_image(img)),
-                 transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)),  # Resize image to fixed height
-                 transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
-                 transforms.Lambda(normalize_image_for_model)  # Normalize pixel values to [-1, 1]
-             ])
-         else:
-             self.transform = transform
-
-
-     def __len__(self) -> int:
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         raw_filename_entry = self.data.loc[idx, 'FILENAME']
-         ground_truth_text = self.data.loc[idx, 'IDENTITY']
-
-         filename = raw_filename_entry.split(',')[0].strip()
-         img_path = os.path.join(self.image_dir, filename)
-         ground_truth_text = str(ground_truth_text)
-
-         try:
-             image = load_image_as_grayscale(img_path)  # Returns PIL Image 'L'
-         except FileNotFoundError:
-             print(f"Error: Image file not found at {img_path}. Skipping this item.")
-             raise
-
-         if self.transform:
-             image = self.transform(image)
-
-         image_width = image.shape[2]  # Assuming image is (C, H, W) after transform
-
-         text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
-         text_length = len(text_encoded)
-
-         return image, text_encoded, image_width, text_length
-
- def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-     """
-     Custom collate function for the DataLoader to handle variable-width images
-     and variable-length text sequences for CTC loss.
-     """
-     images, texts, image_widths, text_lengths = zip(*batch)
-
-     max_batch_width = max(image_widths)
-     padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
-     images_batch = torch.stack(padded_images, 0)
-
-     texts_batch = torch.cat(texts, 0)
-     text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
-     image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
-
-     return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
-
-
- def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
-     """
-     Loads training and testing dataframes.
-     Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
-     Applies dataset limits from config.py.
-     """
-     train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
-     test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')
-
-     # Apply limits if they are set (not 0)
-     if TRAIN_SAMPLES_LIMIT > 0:
-         train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
-         print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
-     if TEST_SAMPLES_LIMIT > 0:
-         test_df = test_df.head(TEST_SAMPLES_LIMIT)
-         print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")
-
-     return train_df, test_df
-
- def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
-                            char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
-     """
-     Creates PyTorch DataLoader objects for OCR training and testing datasets,
-     using specific image directories for train/test.
-     """
-     train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
-     test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)
-
-     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
-                               num_workers=0, collate_fn=ocr_collate_fn)
-     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
-                              num_workers=0, collate_fn=ocr_collate_fn)
-     return train_loader, test_loader
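A quick sketch of what the collate function above produces for two synthetic samples of different widths (assumes only PyTorch and the function being in scope):

import torch

sample_a = (torch.zeros(1, 32, 100), torch.tensor([5, 6, 7]), 100, 3)
sample_b = (torch.zeros(1, 32, 140), torch.tensor([8, 9]), 140, 2)

images, texts, widths, lengths = ocr_collate_fn([sample_a, sample_b])
print(images.shape)  # torch.Size([2, 1, 32, 140]) -- right-padded to the widest image
print(texts)         # tensor([5, 6, 7, 8, 9])     -- targets concatenated, as nn.CTCLoss expects
print(widths)        # tensor([100, 140])
print(lengths)       # tensor([3, 2])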
src/model_ocr.py DELETED
@@ -1,286 +0,0 @@
- # model_ocr.py
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from tqdm import tqdm
- from sklearn.metrics import accuracy_score
- import editdistance
-
- # Import config and char_indexer
- from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
- from data_handler_ocr import CharIndexer
- from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
-
-
- class CNN_Backbone(nn.Module):
-     """
-     CNN feature extractor for OCR. Designed to produce features suitable for RNN.
-     Output feature map should have height 1 after the final pooling/reduction.
-     """
-     def __init__(self, input_channels=1, output_channels=512):
-         super(CNN_Backbone, self).__init__()
-         self.cnn = nn.Sequential(
-             # First block
-             nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             nn.MaxPool2d(kernel_size=2, stride=2),  # H: 32 -> 16, W: W_in -> W_in/2
-
-             # Second block
-             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             nn.MaxPool2d(kernel_size=2, stride=2),  # H: 16 -> 8, W: W_in/2 -> W_in/4
-
-             # Third block (with two conv layers)
-             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
-             nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
-
-             # Fourth block
-             nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
-             # while preserving the width. This is crucial for RNN input.
-             nn.AdaptiveAvgPool2d((1, None))  # Output height 1, preserve width
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
-
-         # Pass through the CNN layers
-         conv_features = self.cnn(x)  # Output: (N, cnn_out_channels, 1, W_prime)
-
-         # Squeeze the height dimension (which is 1)
-         # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
-         conv_features = conv_features.squeeze(2)
-
-         # Permute for RNN input: (sequence_length, batch_size, input_size)
-         # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
-         conv_features = conv_features.permute(2, 0, 1)
-
-         # Return the CNN features, ready for the RNN layer in CRNN
-         return conv_features
-
- class BidirectionalLSTM(nn.Module):
-     """Bidirectional LSTM layer for sequence modeling."""
-     def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
-         super(BidirectionalLSTM, self).__init__()
-         self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
-                             bidirectional=True, dropout=dropout, batch_first=False)
-         # batch_first=False expects input as (sequence_length, batch_size, input_size)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         output, _ = self.lstm(x)  # [0] returns the output, [1] returns (h_n, c_n)
-         return output
-
- class CRNN(nn.Module):
-     """
-     Convolutional Recurrent Neural Network for OCR.
-     Combines CNN for feature extraction, LSTMs for sequence modeling,
-     and a final linear layer for character prediction.
-     """
-     def __init__(self, num_classes: int, cnn_output_channels: int = 512,
-                  rnn_hidden_size: int = 256, rnn_num_layers: int = 2):  # Corrected parameter name
-         super(CRNN, self).__init__()
-         self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
-         # Input to LSTM is the number of channels from the CNN output
-         self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)  # Corrected usage
-         # Output of bidirectional LSTM is hidden_size * 2
-         self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)  # Final linear layer for classes
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
-
-         # 1. Pass through the CNN to extract features
-         conv_features = self.cnn(x)  # Output: (W_prime, N, C_out) after permute in CNN_Backbone
-
-         # 2. Pass CNN features through the RNN (LSTM)
-         rnn_features = self.rnn(conv_features)  # Output: (W_prime, N, rnn_hidden_size * 2)
-
-         # 3. Pass RNN features through the final fully connected layer
-         # Apply the linear layer to each time step independently
-         # output will be (W_prime, N, num_classes)
-         output = self.fc(rnn_features)
-
-         return output
-
-
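A shape sanity check for the stack above (assumes only PyTorch and the classes being in scope): the CNN collapses the height-32 input to height 1, and the permute hands the RNN a (W', N, C) sequence, so the logits come out time-major as nn.CTCLoss expects:

import torch

model = CRNN(num_classes=96)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 1, 32, 128))  # batch of 2, height 32, width 128

# Width shrinks through the two stride-2 pools, then the stride-(2, 1) pool
# adds its padding: 128 -> 64 -> 32 -> 33 time steps.
print(logits.shape)  # torch.Size([33, 2, 96])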
- # --- Decoding Function ---
- def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
-     """
-     Performs greedy decoding on the CTC output.
-     output: (sequence_length, batch_size, num_classes) - raw logits
-     """
-     # Apply log_softmax to get probabilities for argmax
-     log_probs = F.log_softmax(output, dim=2)
-
-     # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
-     predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
-
-     decoded_texts = []
-     for seq in predicted_indices:
-         # Use char_indexer's decode method, which handles blank removal and duplicate collapse
-         decoded_texts.append(char_indexer.decode(seq.tolist()))
-     return decoded_texts
-
- # --- Evaluation Function ---
- def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
-     model.eval()
-     criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
-     total_loss = 0
-     all_predictions = []
-     all_ground_truths = []
-
-     with torch.no_grad():
-         for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
-             inputs = inputs.to(device)
-             targets_padded = targets_padded.to(device)
-             target_lengths_tensor = target_lengths.to(device)
-
-             output = model(inputs)
-
-             outputs_seq_len_for_ctc = torch.full(
-                 size=(output.shape[1],),
-                 fill_value=output.shape[0],
-                 dtype=torch.long,
-                 device=device
-             )
-
-             # CTC Loss calculation requires log_softmax on the output logits
-             log_probs_for_loss = F.log_softmax(output, dim=2)
-
-             # CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
-             loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
-             total_loss += loss.item() * inputs.size(0)
-
-             decoded_preds = ctc_greedy_decode(output, char_indexer)
-             all_predictions.extend(decoded_preds)
-
-             ground_truths_batch = []
-             current_idx_in_concatenated_targets = 0
-
-             target_lengths_list = target_lengths.cpu().tolist()
-
-             for i in range(inputs.size(0)):
-                 length = target_lengths_list[i]
-
-                 current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
-                 ground_truths_batch.append(char_indexer.decode(current_target_segment))
-                 current_idx_in_concatenated_targets += length
-
-             all_ground_truths.extend(ground_truths_batch)
-
-     avg_loss = total_loss / len(dataloader.dataset)
-
-     # Calculate Character Error Rate (CER)
-     cer_sum = 0
-     total_chars = 0
-     for pred, gt in zip(all_predictions, all_ground_truths):
-         cer_sum += editdistance.eval(pred, gt)
-         total_chars += len(gt)
-     char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
-
-     # Calculate Exact Match Accuracy (Word-level Accuracy)
-     exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
-
-     return avg_loss, char_error_rate, exact_match_accuracy
-
- # --- Training Function ---
- def train_ocr_model(model: nn.Module, train_loader: DataLoader,
-                     test_loader: DataLoader, char_indexer: CharIndexer,
-                     epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
-     """
-     Trains the OCR model using CTC loss.
-     """
-     # CTCLoss needs the blank token index
-     criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
-     optimizer = optim.Adam(model.parameters(), lr=0.001)  # Using a fixed LR for now
-     # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
-     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)  # Removed verbose=True
-
-     model.to(device)  # Ensure model is on the correct device
-     model.train()  # Set model to training mode
-
-     training_history = {
-         'train_loss': [],
-         'test_loss': [],
-         'test_cer': [],
-         'test_exact_match_accuracy': []
-     }
-
-     for epoch in range(epochs):
-         running_loss = 0.0
-         pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
-         for images, texts_encoded, _, text_lengths in pbar_train:
-             images = images.to(device)
-             # Ensure target tensors are on the correct device for CTCLoss calculation
-             texts_encoded = texts_encoded.to(device)
-             text_lengths = text_lengths.to(device)
-
-             optimizer.zero_grad()  # Clear gradients from previous step
-             outputs = model(images)  # (sequence_length_from_cnn, batch_size, num_classes)
-
-             # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
-             # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
-             outputs_seq_len_for_ctc = torch.full(
-                 size=(outputs.shape[1],),  # batch_size
-                 fill_value=outputs.shape[0],  # actual sequence length (T) from model output
-                 dtype=torch.long,
-                 device=device
-             )
-
-             # CTC Loss calculation requires log_softmax on the output logits
-             log_probs_for_loss = F.log_softmax(outputs, dim=2)  # (T, N, C)
-
-             # Use outputs_seq_len_for_ctc for the input_lengths argument
-             loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
-             loss.backward()  # Backpropagate
-             optimizer.step()  # Update model weights
-
-             running_loss += loss.item() * images.size(0)  # Multiply by batch size for correct average
-             pbar_train.set_postfix(loss=loss.item())
-
-         epoch_train_loss = running_loss / len(train_loader.dataset)
-         training_history['train_loss'].append(epoch_train_loss)
-
-         # Evaluate on test set using the dedicated function
-         # Ensure model is in eval mode before calling evaluate_model
-         model.eval()
-         test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
-         training_history['test_loss'].append(test_loss)
-         training_history['test_cer'].append(test_cer)
-         training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
-
-         # Adjust learning rate based on test loss
-         scheduler.step(test_loss)
-
-         print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
-               f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
-
-         if progress_callback:
-             # Update progress bar with current epoch and key metrics
-             progress_val = (epoch + 1) / epochs
-             progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
-
-         model.train()  # Set model back to training mode after evaluation
-
-     return model, training_history
-
- def save_ocr_model(model: nn.Module, path: str):
-     """Saves the state dictionary of the trained OCR model."""
-     torch.save(model.state_dict(), path)
-     print(f"OCR model saved to {path}")
-
- def load_ocr_model(model: nn.Module, path: str):
-     """
-     Loads a trained OCR model's state dictionary.
-     Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
-     """
-     model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))  # Always load to CPU first
-     model.eval()  # Set to evaluation mode
-     print(f"OCR model loaded from {path}")
src/utils_ocr.py DELETED
@@ -1,83 +0,0 @@
- # utils_ocr.py
-
- import cv2
- import numpy as np
- from PIL import Image
- import torch
- import torchvision.transforms as transforms
- import os
-
- # Import config for IMG_HEIGHT and MAX_IMG_WIDTH
- from config import IMG_HEIGHT, MAX_IMG_WIDTH
-
- # --- Image Preprocessing Functions ---
-
- def load_image_as_grayscale(image_path: str) -> Image.Image:
-     """Loads an image from path and converts it to grayscale PIL Image."""
-     if not os.path.exists(image_path):
-         raise FileNotFoundError(f"Image not found at: {image_path}")
-     return Image.open(image_path).convert('L')  # 'L' for grayscale
-
- def binarize_image(img: Image.Image) -> Image.Image:
-     """
-     Binarizes a grayscale PIL Image using Otsu's method.
-     Returns a PIL Image.
-     """
-     # Convert PIL Image to OpenCV format (numpy array)
-     img_np = np.array(img)
-
-     # Apply Otsu's binarization
-     _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
-
-     # Convert back to PIL Image
-     return Image.fromarray(binary_img)
-
- def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
-     """
-     Resizes a PIL Image to a fixed height while maintaining aspect ratio.
-     Also ensures the width does not exceed MAX_IMG_WIDTH.
-     """
-     width, height = img.size
-
-     # Calculate new width based on target height, maintaining aspect ratio
-     new_width = int(width * (img_height / height))
-
-     if new_width > MAX_IMG_WIDTH:
-         new_width = MAX_IMG_WIDTH
-         resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
-         if resized_img.width > MAX_IMG_WIDTH:
-             # Crop the image from the left to MAX_IMG_WIDTH
-             resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
-         return resized_img
-
-     return img.resize((new_width, img_height), Image.Resampling.LANCZOS)  # Use LANCZOS for high-quality downsampling
-
- def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
-     """
-     Normalizes a torch.Tensor image (grayscale) for input into the model.
-     Puts pixel values in range [-1, 1].
-     Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
-     """
-     # Formula: (pixel_value - mean) / std_dev
-     # For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
-     img_tensor = (img_tensor - 0.5) / 0.5
-     return img_tensor
-
- def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
-     """
-     Applies all necessary preprocessing steps to a user-uploaded PIL Image
-     to prepare it for the OCR model.
-     """
-     # Define a transformation pipeline similar to the dataset, but including ToTensor
-     transform_pipeline = transforms.Compose([
-         transforms.Lambda(lambda img: binarize_image(img)),  # PIL Image -> PIL Image
-         # Use the updated resize function that also handles MAX_IMG_WIDTH
-         transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)),  # PIL Image -> PIL Image
-         transforms.ToTensor(),  # PIL Image -> Tensor [0, 1]
-         transforms.Lambda(normalize_image_for_model)  # Tensor [0, 1] -> Tensor [-1, 1]
-     ])
-
-     processed_image = transform_pipeline(image_pil)
-
-     # Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
-     return processed_image.unsqueeze(0)
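An end-to-end sketch of the pipeline above on a synthetic 64x200 grayscale image (NumPy noise standing in for a scanned name): binarize, resize to height 32 with aspect ratio kept, then normalize into the [-1, 1] tensor the model consumes:

import numpy as np
from PIL import Image

noise = np.random.randint(0, 256, size=(64, 200), dtype=np.uint8)
img = Image.fromarray(noise, mode='L')

tensor = preprocess_user_image_for_ocr(img, target_height=32)
print(tensor.shape)                              # torch.Size([1, 1, 32, 100]) -- width scaled with height
print(tensor.min().item(), tensor.max().item())  # -1.0 1.0 after normalization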