marianeft committed on
Commit
15dba6b
·
verified ·
1 Parent(s): 0c48050

Training Model Complete

Files changed (5)
  1. app.py +227 -219
  2. config.py +13 -77
  3. data_handler_ocr.py +165 -151
  4. model_ocr.py +285 -286
  5. utils_ocr.py +60 -161
app.py CHANGED
@@ -1,219 +1,227 @@
1
- # -*- coding: utf-8 -*-
2
- # app.py
3
-
4
- import os
5
- # CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
6
- # This MUST be the first thing, before any other imports or Streamlit calls
7
- os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
8
-
9
- import streamlit as st
10
- import pandas as pd
11
- import numpy as np
12
- from PIL import Image
13
- import torch
14
- import torch.nn.functional as F # Added F for log_softmax in inference
15
- import torchvision.transforms as transforms
16
- import traceback # For detailed error logging
17
-
18
- # Import all necessary configuration values from config.py
19
- from config import (
20
- IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
21
- TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
22
- MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
23
- )
24
-
25
- # Import classes and functions from data_handler_ocr.py and model_ocr.py
26
- from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
27
- from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
28
- from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model # Ensure these are imported if needed
29
-
30
- # --- Global Variables ---
31
- # These will hold the model and char_indexer instance after training or loading
32
- trained_ocr_model = None
33
- char_indexer = None
34
- training_history = None
35
- # Determine the device (GPU if available, else CPU)
36
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
-
38
- # --- Streamlit App Setup ---
39
- st.set_page_config(layout="wide", page_title="Handwritten Name OCR App") # Changed to wide layout for better display
40
-
41
- st.title("📝 Handwritten Name Recognition (OCR) App") # Updated title for consistency
42
- st.markdown("""
43
- This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
44
- Optical Character Recognition (OCR) on handwritten names. You can upload an image
45
- of a handwritten name for prediction or train a new model using the provided dataset.
46
-
47
- **Note:** Training a robust OCR model can be time-consuming.
48
- """)
49
-
50
- # --- Initialize CharIndexer ---
51
- # CRITICAL FIX: Initialize CharIndexer with VOCABULARY and BLANK_TOKEN_SYMBOL
52
- # This resolves the ValueError: "Blank token symbol '95' not found..."
53
- char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
54
-
55
- # --- Model Loading / Initialization ---
56
- @st.cache_resource # Cache the model to prevent reloading on every rerun
57
- def get_and_load_ocr_model_cached(num_classes, model_path):
58
- """
59
- Initializes the OCR model and attempts to load a pre-trained model.
60
- If no pre-trained model exists, a new model instance is returned.
61
- """
62
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
63
-
64
- if os.path.exists(model_path):
65
- st.sidebar.info("Loading pre-trained OCR model...")
66
- try:
67
- # Load model to CPU first, then move to device
68
- model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
69
- st.sidebar.success("OCR model loaded successfully!")
70
- except Exception as e:
71
- st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
72
- # If loading fails, re-initialize an untrained model
73
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
74
- else:
75
- st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
76
-
77
- return model_instance
78
-
79
- # Get the model instance
80
- ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
81
- # Determine the device (GPU if available, else CPU)
82
- ocr_model.to(device)
83
- ocr_model.eval() # Set model to evaluation mode for inference by default
84
-
85
- # --- Sidebar for Model Training ---
86
- st.sidebar.header("Model Training (Optional)")
87
- st.sidebar.markdown("If you want to train a new model or no model is found:")
88
-
89
- # Initialize Streamlit widgets outside the button block
90
- training_progress_bar = st.sidebar.empty() # Placeholder for progress bar in sidebar
91
- status_text = st.sidebar.empty() # Placeholder for status messages in sidebar
92
-
93
- if st.sidebar.button("📊 Train New OCR Model"): # Keep button in sidebar as per user's last provided code
94
- # Clear previous messages/widgets if button is clicked again
95
- training_progress_bar.progress(0) # Reset progress bar
96
- training_progress_bar.empty()
97
- status_text.empty() # Clear status text
98
-
99
- # Check for existence of CSVs and image directories
100
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
101
- status_text.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
102
- elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
103
- status_text.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
104
- "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
105
- else:
106
- status_text.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
107
-
108
- # Define the progress bar instance here for the callback
109
- training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
110
-
111
- def update_progress_callback_sidebar(value, text):
112
- """Callback function to update Streamlit progress bar in sidebar."""
113
- training_progress_bar_instance.progress(int(value * 100))
114
- status_text.text(text) # Update status text in sidebar
115
-
116
- try:
117
- train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
118
- status_text.success("Training and Test DataFrames loaded successfully.")
119
-
120
- char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
121
- status_text.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
122
-
123
- # Pass the limits to create_ocr_dataloaders
124
- train_loader, test_loader = create_ocr_dataloaders(
125
- train_df, test_df, char_indexer, BATCH_SIZE
126
- )
127
- status_text.success("DataLoaders created successfully.")
128
-
129
- ocr_model_for_training = CRNN(num_classes=NUM_CLASSES) # Create a new instance for training
130
- ocr_model_for_training.to(device)
131
- status_text.info(f"CRNN model initialized and moved to {device}.")
132
-
133
- status_text.write("Training in progress... This may take a while.")
134
- trained_ocr_model, training_history = train_ocr_model(
135
- model=ocr_model_for_training, # Pass the new instance
136
- train_loader=train_loader,
137
- test_loader=test_loader,
138
- char_indexer=char_indexer, # Pass char_indexer for CER calculation
139
- epochs=NUM_EPOCHS,
140
- device=device,
141
- progress_callback=update_progress_callback_sidebar # Pass the sidebar callback
142
- )
143
- status_text.success("OCR model training finished!")
144
- update_progress_callback_sidebar(1.0, "Training complete!")
145
-
146
- os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
147
- save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
148
- status_text.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
149
-
150
- # Display training history chart in the main section, not sidebar
151
- if training_history:
152
- st.subheader("Training History Plots")
153
- history_df = pd.DataFrame({
154
- 'Epoch': range(1, len(training_history['train_loss']) + 1),
155
- 'Train Loss': training_history['train_loss'],
156
- 'Test Loss': training_history['test_loss'],
157
- 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
158
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
159
- })
160
-
161
- st.markdown("**Loss over Epochs**")
162
- st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
163
- st.caption("Lower loss indicates better model performance.")
164
-
165
- st.markdown("**Character Error Rate (CER) over Epochs**")
166
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
167
- st.caption("Lower CER indicates fewer character errors (0% is perfect).")
168
-
169
- st.markdown("**Exact Match Accuracy over Epochs**")
170
- st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
171
- st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
172
-
173
- st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
174
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
175
- st.caption("CER should decrease, Accuracy should increase.")
176
-
177
- except Exception as e:
178
- status_text.error(f"An error occurred during training: {e}")
179
- status_text.exception(e) # Display full traceback in Streamlit
180
- update_progress_callback_sidebar(0.0, "Training failed!")
181
-
182
- # --- Main Content: Name Prediction ---
183
- st.header("Predict Your Handwritten Name")
184
- st.markdown("Upload a clear image of a single handwritten name or word.")
185
-
186
- uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])
187
-
188
- if uploaded_file is not None:
189
- try:
190
- # Open the uploaded image
191
- image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
192
- # Use use_container_width for deprecation warning fix
193
- st.image(image_pil, caption="Uploaded Image", use_container_width=True)
194
- st.write("---")
195
- st.write("Processing and Recognizing...")
196
-
197
- # Preprocess the image for the model using utils_ocr function
198
- processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
199
-
200
- trained_ocr_model.eval() # Ensure model is in evaluation mode
201
- with torch.no_grad(): # Disable gradient calculation for inference
202
- output = trained_ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
203
- predicted_texts = ctc_greedy_decode(output, char_indexer)
204
- predicted_text = predicted_texts[0] # Get the first (and only) prediction
205
-
206
- st.success(f"Recognized Text: **{predicted_text}**")
207
-
208
- except Exception as e:
209
- st.error(f"Error processing image or recognizing text: {e}")
210
- st.info("💡 **Tips for best results:**\n"
211
- "- Ensure the handwritten text is clear and on a clean background.\n"
212
- "- Only include one name/word per image.\n"
213
- "- The model is trained on specific characters. Unusual symbols might not be recognized.")
214
- st.exception(e) # Display full traceback in Streamlit
215
-
216
- st.markdown("""
217
- ---
218
- *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
219
- """)
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # app.py
3
+
4
+ import os
5
+ # Disable Streamlit file watcher to prevent conflicts with PyTorch
6
+ os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
7
+
8
+ import streamlit as st
9
+ import pandas as pd
10
+ import numpy as np
11
+ from PIL import Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+ import traceback
16
+
17
+ # Import all necessary configuration values from config.py
18
+ from config import (
19
+ IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
20
+ TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
21
+ MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
22
+ )
23
+
24
+ # Import classes and functions from data_handler_ocr.py and model_ocr.py
25
+ from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
26
+ from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
27
+ from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
28
+
29
+
30
+ # --- Global Variables ---
31
+ ocr_model = None
32
+ char_indexer = None
33
+ training_history = None
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # --- Streamlit App Setup ---
37
+ st.set_page_config(layout="wide", page_title="Handwritten Name OCR App",)
38
+
39
+
40
+ st.title("📝 Handwritten Name Recognition (OCR) App")
41
+ st.markdown("""
42
+ This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
43
+ Optical Character Recognition (OCR) on handwritten names. You can upload an image
44
+ of a handwritten name for prediction or train a new model using the provided dataset.
45
+
46
+ **Note:** Training a robust OCR model can be time-consuming.
47
+ """)
48
+
49
+ # --- Initialize CharIndexer ---
50
+ # This initializes char_indexer once when the script starts
51
+ char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
52
+
53
+ # --- Model Loading / Initialization ---
54
+ @st.cache_resource # Cache the model to prevent reloading on every rerun
55
+ def get_and_load_ocr_model_cached(num_classes, model_path):
56
+ """
57
+ Initializes the OCR model and attempts to load a pre-trained model.
58
+ If no pre-trained model exists, a new model instance is returned.
59
+ """
60
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
61
+
62
+ if os.path.exists(model_path):
63
+ st.sidebar.info("Loading pre-trained OCR model...")
64
+ try:
65
+ # Load model to CPU first, then move to device
66
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
67
+ st.sidebar.success("OCR model loaded successfully!")
68
+ except Exception as e:
69
+ st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
70
+ # If loading fails, re-initialize an untrained model
71
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
72
+ else:
73
+ st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
74
+
75
+ return model_instance
76
+
77
+ # Get the model instance and assign it to the global 'ocr_model'
78
+ ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
79
+ # Ensure the model is on the correct device for inference
80
+ ocr_model.to(device)
81
+ ocr_model.eval() # Set model to evaluation mode for inference by default
82
+
83
+
84
+ # --- Sidebar for Model Training ---
85
+ st.sidebar.header("Train OCR Model")
86
+ st.sidebar.write("Click the button below to start training the OCR model.")
87
+
88
+ # Progress bar and label for training in the sidebar
89
+ progress_bar_sidebar = st.sidebar.progress(0)
90
+ progress_label_sidebar = st.sidebar.empty()
91
+
92
+ def update_progress_callback_sidebar(value, text):
93
+ progress_bar_sidebar.progress(int(value * 100))
94
+ progress_label_sidebar.text(text)
95
+
96
+ if st.sidebar.button("📊 Start Training"):
97
+ progress_bar_sidebar.progress(0)
98
+ progress_label_sidebar.empty()
99
+ st.empty()
100
+
101
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
102
+ st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
103
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
104
+ st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
105
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
106
+ else:
107
+ st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
108
+
109
+ try:
110
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
111
+ st.sidebar.success("Training and Test DataFrames loaded successfully.")
112
+
113
+ st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
114
+
115
+ train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
116
+ st.sidebar.success("DataLoaders created successfully.")
117
+
118
+ ocr_model.train()
119
+
120
+ st.sidebar.write("Training in progress... This may take a while.")
121
+ ocr_model, training_history = train_ocr_model(
122
+ model=ocr_model,
123
+ train_loader=train_loader,
124
+ test_loader=test_loader,
125
+ char_indexer=char_indexer,
126
+ epochs=NUM_EPOCHS,
127
+ device=device,
128
+ progress_callback=update_progress_callback_sidebar
129
+ )
130
+ st.sidebar.success("OCR model training finished!")
131
+ update_progress_callback_sidebar(1.0, "Training complete!")
132
+
133
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
134
+ save_ocr_model(ocr_model, MODEL_SAVE_PATH)
135
+ st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
136
+
137
+ except Exception as e:
138
+ st.sidebar.error(f"An error occurred during training: {e}")
139
+ st.exception(e)
140
+ update_progress_callback_sidebar(0.0, "Training failed!")
141
+
142
+ # --- Sidebar for Model Loading ---
143
+ st.sidebar.header("Load Pre-trained Model")
144
+ st.sidebar.write("If you have a saved model, you can load it here instead of training.")
145
+
146
+ if st.sidebar.button("💾 Load Model"):
147
+ if os.path.exists(MODEL_SAVE_PATH):
148
+ try:
149
+ loaded_model = CRNN(num_classes=char_indexer.num_classes)
150
+ load_ocr_model(loaded_model, MODEL_SAVE_PATH)
151
+ loaded_model.to(device)
152
+
153
+ st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
154
+ except Exception as e:
155
+ st.sidebar.error(f"Error loading model: {e}")
156
+ st.exception(e)
157
+ else:
158
+ st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
159
+
160
+ # --- Main Content: Prediction Section and Training History ---
161
+
162
+ # Display training history chart
163
+ if training_history:
164
+ st.subheader("Training History Plots")
165
+ history_df = pd.DataFrame({
166
+ 'Epoch': range(1, len(training_history['train_loss']) + 1),
167
+ 'Train Loss': training_history['train_loss'],
168
+ 'Test Loss': training_history['test_loss'],
169
+ 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
170
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
171
+ })
172
+
173
+ st.markdown("**Loss over Epochs**")
174
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
175
+ st.caption("Lower loss indicates better model performance.")
176
+
177
+ st.markdown("**Character Error Rate (CER) over Epochs**")
178
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
179
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
180
+
181
+ st.markdown("**Exact Match Accuracy over Epochs**")
182
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
183
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
184
+
185
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
186
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
187
+ st.caption("CER should decrease, Accuracy should increase.")
188
+ st.write("---") # Separator after charts
189
+
190
+
191
+ # Predict on a New Image
192
+
193
+ if ocr_model is None:
194
+ st.warning("Please train or load a model before attempting prediction.")
195
+ else:
196
+ uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
197
+
198
+ if uploaded_file is not None:
199
+ try:
200
+ image_pil = Image.open(uploaded_file).convert('L')
201
+ st.image(image_pil, caption="Uploaded Image", use_container_width=True)
202
+ st.write("---")
203
+ st.write("Processing and Recognizing...")
204
+
205
+ processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
206
+
207
+ ocr_model.eval()
208
+ with torch.no_grad():
209
+ output = ocr_model(processed_image_tensor)
210
+
211
+ predicted_texts = ctc_greedy_decode(output, char_indexer)
212
+ predicted_text = predicted_texts[0]
213
+
214
+ st.success(f"Recognized Text: **{predicted_text}**")
215
+
216
+ except Exception as e:
217
+ st.error(f"Error processing image or recognizing text: {e}")
218
+ st.info("💡 **Tips for best results:**\n"
219
+ "- Ensure the handwritten text is clear and on a clean background.\n"
220
+ "- Only include one name/word per image.\n"
221
+ "- The model is trained on specific characters. Unusual symbols might not be recognized.")
222
+ st.exception(e)
223
+
224
+ st.markdown("""
225
+ ---
226
+ *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
227
+ """)
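
A quick way to sanity-check the inference path in app.py outside Streamlit: build a dummy (1, 1, IMG_HEIGHT, W) tensor in place of a preprocessed upload, run the CRNN, and greedy-decode the (T, N, num_classes) output. This is only a shape-level sketch that reuses the repo's own modules; with untrained weights the decoded string is meaningless and only the shapes matter.

import torch

from config import IMG_HEIGHT, VOCABULARY, BLANK_TOKEN_SYMBOL
from data_handler_ocr import CharIndexer
from model_ocr import CRNN, ctc_greedy_decode

char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
model = CRNN(num_classes=char_indexer.num_classes)
model.eval()

# Stand-in for preprocess_user_image_for_ocr(...): batch of 1, 1 channel, fixed height, arbitrary width.
dummy_input = torch.rand(1, 1, IMG_HEIGHT, 128)

with torch.no_grad():
    logits = model(dummy_input)          # (T, N=1, num_classes), per CRNN.forward
print(logits.shape)

predicted_texts = ctc_greedy_decode(logits, char_indexer)
print(predicted_texts[0])                # arbitrary text until the model is trained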
config.py CHANGED
@@ -1,4 +1,3 @@
1
- <<<<<<< HEAD
2
  # config.py
3
 
4
  import os
@@ -8,8 +7,8 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
8
  DATA_DIR = os.path.join(BASE_DIR, 'data')
9
  MODELS_DIR = os.path.join(BASE_DIR, 'models')
10
 
11
- TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
12
- TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
13
 
14
  TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
15
  TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
@@ -17,26 +16,13 @@ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
17
  MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
18
 
19
  # --- Character Set and OCR Configuration ---
20
- # This character set MUST cover all characters present in your dataset.
21
- # Add any special characters if needed.
22
- # The order here is crucial as it defines the indices for your characters.
23
  CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
24
-
25
- # Define the character for the blank token. It MUST NOT be in CHARS.
26
- BLANK_TOKEN_SYMBOL = 'Þ'
27
-
28
- # Construct the full vocabulary string. It's conventional to put the blank token last.
29
- # This VOCABULARY string is what you pass to CharIndexer.
30
  VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
31
-
32
- # NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
33
  NUM_CLASSES = len(VOCABULARY)
34
-
35
- # BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
36
- # Since we appended it last, its index will be len(CHARS).
37
  BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
38
 
39
- # --- Sanity Checks (Highly Recommended) ---
40
  if BLANK_TOKEN == -1:
41
  raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
42
  if BLANK_TOKEN >= NUM_CLASSES:
@@ -48,65 +34,15 @@ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
48
 
49
 
50
  # --- Image Preprocessing Parameters ---
51
- IMG_HEIGHT = 32
 
52
 
53
  # --- Training Parameters ---
54
- BATCH_SIZE = 64
 
 
 
 
 
 
55
  LEARNING_RATE = 0.001
56
- =======
57
- # config.py
58
-
59
- import os
60
-
61
- # --- Paths ---
62
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
63
- DATA_DIR = os.path.join(BASE_DIR, 'data')
64
- MODELS_DIR = os.path.join(BASE_DIR, 'models')
65
-
66
- TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
67
- TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
68
-
69
- TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
70
- TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
71
-
72
- MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
73
-
74
- # --- Character Set and OCR Configuration ---
75
- # This character set MUST cover all characters present in your dataset.
76
- # Add any special characters if needed.
77
- # The order here is crucial as it defines the indices for your characters.
78
- CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
79
-
80
- # Define the character for the blank token. It MUST NOT be in CHARS.
81
- BLANK_TOKEN_SYMBOL = 'Þ'
82
-
83
- # Construct the full vocabulary string. It's conventional to put the blank token last.
84
- # This VOCABULARY string is what you pass to CharIndexer.
85
- VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
86
-
87
- # NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
88
- NUM_CLASSES = len(VOCABULARY)
89
-
90
- # BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
91
- # Since we appended it last, its index will be len(CHARS).
92
- BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
93
-
94
- # --- Sanity Checks (Highly Recommended) ---
95
- if BLANK_TOKEN == -1:
96
- raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
97
- if BLANK_TOKEN >= NUM_CLASSES:
98
- raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
99
-
100
- print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
101
- print(f"Vocabulary Length: {len(VOCABULARY)}")
102
- print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
103
-
104
-
105
- # --- Image Preprocessing Parameters ---
106
- IMG_HEIGHT = 32
107
-
108
- # --- Training Parameters ---
109
- BATCH_SIZE = 64
110
- LEARNING_RATE = 0.001
111
- >>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
112
- NUM_EPOCHS = 3
 
 
1
  # config.py
2
 
3
  import os
 
7
  DATA_DIR = os.path.join(BASE_DIR, 'data')
8
  MODELS_DIR = os.path.join(BASE_DIR, 'models')
9
 
10
+ TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
11
+ TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
12
 
13
  TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
14
  TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
 
16
  MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
17
 
18
  # --- Character Set and OCR Configuration ---
 
 
 
19
  CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
20
+ BLANK_TOKEN_SYMBOL = 'Þ'
 
 
 
 
 
21
  VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
 
 
22
  NUM_CLASSES = len(VOCABULARY)
 
 
 
23
  BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
24
 
25
+ # --- Sanity Checks ---
26
  if BLANK_TOKEN == -1:
27
  raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
28
  if BLANK_TOKEN >= NUM_CLASSES:
 
34
 
35
 
36
  # --- Image Preprocessing Parameters ---
37
+ IMG_HEIGHT = 32 # Target height for all input images to the model
38
+ MAX_IMG_WIDTH = 1024 # Adjust this value based on your typical image widths and available RAM
39
 
40
  # --- Training Parameters ---
41
+ BATCH_SIZE = 10
42
+
43
+ # NEW: Dataset Limits
44
+ TRAIN_SAMPLES_LIMIT = 1000
45
+ TEST_SAMPLES_LIMIT = 1000
46
+
47
+ NUM_EPOCHS = 5
48
  LEARNING_RATE = 0.001
 
 
data_handler_ocr.py CHANGED
@@ -1,151 +1,165 @@
1
- #data_handler_ocr.py
2
-
3
- import pandas as pd
4
- import torch
5
- from torch.utils.data import Dataset, DataLoader
6
- from torchvision import transforms
7
- import os
8
- from PIL import Image
9
- import numpy as np
10
- import torch.nn.functional as F
11
-
12
- # Import utility functions and config
13
- from config import VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
14
- from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
15
-
16
- class CharIndexer:
17
- """Manages character-to-index and index-to-character mappings."""
18
- def __init__(self, vocabulary_string: str, blank_token_symbol: str):
19
- self.chars = sorted(list(set(vocabulary_string)))
20
- self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
21
- self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
22
-
23
- if blank_token_symbol not in self.char_to_idx:
24
- raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
25
-
26
- self.blank_token_idx = self.char_to_idx[blank_token_symbol]
27
- self.num_classes = len(self.chars)
28
-
29
- if self.blank_token_idx >= self.num_classes:
30
- raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
31
-
32
- print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
33
- print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
34
-
35
- def encode(self, text: str) -> list[int]:
36
- """Converts a text string to a list of integer indices."""
37
- encoded_list = []
38
- for char in text:
39
- if char in self.char_to_idx:
40
- encoded_list.append(self.char_to_idx[char])
41
- else:
42
- print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
43
- encoded_list.append(self.blank_token_idx)
44
- return encoded_list
45
-
46
- def decode(self, indices: list[int]) -> str:
47
- """Converts a list of integer indices back to a text string."""
48
- decoded_text = []
49
- for i, idx in enumerate(indices):
50
- if idx == self.blank_token_idx:
51
- continue
52
- if i > 0 and indices[i-1] == idx:
53
- continue
54
- if idx in self.idx_to_char:
55
- decoded_text.append(self.idx_to_char[idx])
56
- else:
57
- print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
58
-
59
- return "".join(decoded_text)
60
-
61
- class OCRDataset(Dataset):
62
- """
63
- Custom PyTorch Dataset for the Handwritten Name Recognition task.
64
- Loads images and their corresponding text labels.
65
- """
66
- def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
67
- self.data = dataframe
68
- self.char_indexer = char_indexer
69
- self.image_dir = image_dir
70
-
71
- if transform is None:
72
- self.transform = transforms.Compose([
73
-
74
- transforms.Lambda(lambda img: binarize_image(img)),
75
- transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)),
76
- transforms.ToTensor(),
77
- transforms.Lambda(normalize_image_for_model)
78
- ])
79
- else:
80
- self.transform = transform
81
-
82
-
83
- def __len__(self) -> int:
84
- return len(self.data)
85
-
86
- def __getitem__(self, idx):
87
- raw_filename_entry = self.data.loc[idx, 'FILENAME']
88
- ground_truth_text = self.data.loc[idx, 'IDENTITY']
89
-
90
- filename_only = raw_filename_entry.split(',')[0].strip()
91
-
92
- img_path = os.path.join(self.image_dir, filename_only)
93
- ground_truth_text = str(ground_truth_text)
94
-
95
- try:
96
- image = load_image_as_grayscale(img_path)
97
- except FileNotFoundError:
98
- print(f"Error: Image file not found at {img_path}. Please check your dataset and config.py paths.")
99
- raise
100
-
101
- if self.transform:
102
- image = self.transform(image)
103
-
104
- image_width = image.shape[2]
105
-
106
- text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
107
- text_length = len(text_encoded)
108
-
109
- return image, text_encoded, image_width, text_length
110
-
111
- def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
112
- """
113
- Custom collate function for the DataLoader to handle variable-width images
114
- and variable-length text sequences for CTC loss.
115
- """
116
- images, texts, image_widths, text_lengths = zip(*batch)
117
-
118
- max_batch_width = max(image_widths)
119
- padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
120
- images_batch = torch.stack(padded_images, 0)
121
-
122
- texts_batch = torch.cat(texts, 0)
123
- text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
124
- image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
125
-
126
- return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
127
-
128
-
129
- def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
130
- """
131
- Loads training and testing dataframes.
132
- Assumes CSVs have 'FILENAME' and 'IDENTITY' columns and are comma-delimited with no header.
133
- """
134
- train_df = pd.read_csv(train_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
135
- test_df = pd.read_csv(test_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
136
- return train_df, test_df
137
-
138
- def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
139
- char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
140
- """
141
- Creates PyTorch DataLoader objects for OCR training and testing datasets,
142
- using specific image directories for train/test.
143
- """
144
- train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR)
145
- test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR)
146
-
147
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
148
- num_workers=0, collate_fn=ocr_collate_fn)
149
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
150
- num_workers=0, collate_fn=ocr_collate_fn)
151
- return train_loader, test_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #data_handler_ocr.py
2
+
3
+ import pandas as pd
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torchvision import transforms
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+
12
+ # Import utility functions and config
13
+ from config import (
14
+ VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
15
+ TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
16
+ TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
17
+ )
18
+ from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
19
+
20
+ class CharIndexer:
21
+ """Manages character-to-index and index-to-character mappings."""
22
+ def __init__(self, vocabulary_string: str, blank_token_symbol: str):
23
+ self.chars = sorted(list(set(vocabulary_string)))
24
+ self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
25
+ self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
26
+
27
+ if blank_token_symbol not in self.char_to_idx:
28
+ raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
29
+
30
+ self.blank_token_idx = self.char_to_idx[blank_token_symbol]
31
+ self.num_classes = len(self.chars)
32
+
33
+ if self.blank_token_idx >= self.num_classes:
34
+ raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
35
+
36
+ print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
37
+ print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
38
+
39
+ def encode(self, text: str) -> list[int]:
40
+ """Converts a text string to a list of integer indices."""
41
+ encoded_list = []
42
+ for char in text:
43
+ if char in self.char_to_idx:
44
+ encoded_list.append(self.char_to_idx[char])
45
+ else:
46
+ print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
47
+ encoded_list.append(self.blank_token_idx)
48
+ return encoded_list
49
+
50
+ def decode(self, indices: list[int]) -> str:
51
+ """Converts a list of integer indices back to a text string."""
52
+ decoded_text = []
53
+ for i, idx in enumerate(indices):
54
+ if idx == self.blank_token_idx:
55
+ continue # Skip blank tokens
56
+
57
+ if i > 0 and indices[i-1] == idx:
58
+ continue
59
+
60
+ if idx in self.idx_to_char:
61
+ decoded_text.append(self.idx_to_char[idx])
62
+ else:
63
+ print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
64
+
65
+ return "".join(decoded_text)
66
+
67
+ class OCRDataset(Dataset):
68
+ """
69
+ Custom PyTorch Dataset for the Handwritten Name Recognition task.
70
+ Loads images and their corresponding text labels.
71
+ """
72
+ def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
73
+ self.data = dataframe
74
+ self.char_indexer = char_indexer
75
+ self.image_dir = image_dir
76
+
77
+ if transform is None:
78
+ self.transform = transforms.Compose([
79
+ transforms.Lambda(lambda img: binarize_image(img)),
80
+ transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)), # Resize image to fixed height
81
+ transforms.ToTensor(), # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
82
+ transforms.Lambda(normalize_image_for_model) # Normalize pixel values to [-1, 1]
83
+ ])
84
+ else:
85
+ self.transform = transform
86
+
87
+
88
+ def __len__(self) -> int:
89
+ return len(self.data)
90
+
91
+ def __getitem__(self, idx):
92
+ raw_filename_entry = self.data.loc[idx, 'FILENAME']
93
+ ground_truth_text = self.data.loc[idx, 'IDENTITY']
94
+
95
+ filename = raw_filename_entry.split(',')[0].strip()
96
+ img_path = os.path.join(self.image_dir, filename)
97
+ ground_truth_text = str(ground_truth_text)
98
+
99
+ try:
100
+ image = load_image_as_grayscale(img_path) # Returns PIL Image 'L'
101
+ except FileNotFoundError:
102
+ print(f"Error: Image file not found at {img_path}. Check your dataset and config.py paths.")
103
+ raise
104
+
105
+ if self.transform:
106
+ image = self.transform(image)
107
+
108
+ image_width = image.shape[2] # Assuming image is (C, H, W) after transform
109
+
110
+ text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
111
+ text_length = len(text_encoded)
112
+
113
+ return image, text_encoded, image_width, text_length
114
+
115
+ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
116
+ """
117
+ Custom collate function for the DataLoader to handle variable-width images
118
+ and variable-length text sequences for CTC loss.
119
+ """
120
+ images, texts, image_widths, text_lengths = zip(*batch)
121
+
122
+ max_batch_width = max(image_widths)
123
+ padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
124
+ images_batch = torch.stack(padded_images, 0)
125
+
126
+ texts_batch = torch.cat(texts, 0)
127
+ text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
128
+ image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
129
+
130
+ return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
131
+
132
+
133
+ def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
134
+ """
135
+ Loads training and testing dataframes.
136
+ Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
137
+ Applies dataset limits from config.py.
138
+ """
139
+ train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
140
+ test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')
141
+
142
+ # Apply limits if they are set (not 0)
143
+ if TRAIN_SAMPLES_LIMIT > 0:
144
+ train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
145
+ print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
146
+ if TEST_SAMPLES_LIMIT > 0:
147
+ test_df = test_df.head(TEST_SAMPLES_LIMIT)
148
+ print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")
149
+
150
+ return train_df, test_df
151
+
152
+ def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
153
+ char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
154
+ """
155
+ Creates PyTorch DataLoader objects for OCR training and testing datasets,
156
+ using specific image directories for train/test.
157
+ """
158
+ train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
159
+ test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)
160
+
161
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
162
+ num_workers=0, collate_fn=ocr_collate_fn)
163
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
164
+ num_workers=0, collate_fn=ocr_collate_fn)
165
+ return train_loader, test_loader
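
To see what ocr_collate_fn hands to the training loop, two fake samples of different widths are enough: images are right-padded to the widest one in the batch, while the encoded labels are concatenated into a single flat tensor (the form nn.CTCLoss accepts together with target_lengths). A minimal sketch, assuming the module above is importable:

import torch
from data_handler_ocr import ocr_collate_fn

# Two fake (image, encoded_text, image_width, text_length) tuples, as returned by OCRDataset.__getitem__.
sample_a = (torch.zeros(1, 32, 100), torch.tensor([5, 6, 7]), 100, 3)
sample_b = (torch.zeros(1, 32, 140), torch.tensor([8, 9]), 140, 2)

images, targets, widths, lengths = ocr_collate_fn([sample_a, sample_b])
print(images.shape)      # torch.Size([2, 1, 32, 140]) -- both padded to the widest image
print(targets)           # tensor([5, 6, 7, 8, 9])      -- labels concatenated, not padded
print(widths, lengths)   # tensor([100, 140]) tensor([3, 2])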
model_ocr.py CHANGED
@@ -1,286 +1,285 @@
1
- # model_ocr.py
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- import torch.optim as optim
7
- from torch.utils.data import DataLoader # Keep DataLoader for type hinting
8
- from tqdm import tqdm
9
- from sklearn.metrics import accuracy_score
10
- import editdistance
11
-
12
- # Import config and char_indexer
13
- from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
14
- from data_handler_ocr import CharIndexer
15
- from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
16
-
17
-
18
- class CNN_Backbone(nn.Module):
19
- """
20
- CNN feature extractor for OCR. Designed to produce features suitable for RNN.
21
- Output feature map should have height 1 after the final pooling/reduction.
22
- """
23
- def __init__(self, input_channels=1, output_channels=512):
24
- super(CNN_Backbone, self).__init__()
25
- self.cnn = nn.Sequential(
26
- # First block
27
- nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
28
- nn.ReLU(True),
29
- nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
30
-
31
- # Second block
32
- nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
33
- nn.ReLU(True),
34
- nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
35
-
36
- # Third block (with two conv layers)
37
- nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
38
- nn.ReLU(True),
39
- nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
40
- nn.ReLU(True),
41
- # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
42
- nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
43
-
44
- # Fourth block
45
- nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
46
- nn.ReLU(True),
47
- # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
48
- # while preserving the width. This is crucial for RNN input.
49
- nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
50
- )
51
-
52
- def forward(self, x: torch.Tensor) -> torch.Tensor:
53
- # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
54
-
55
- # Pass through the CNN layers
56
- conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
57
-
58
- # Squeeze the height dimension (which is 1)
59
- # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
60
- conv_features = conv_features.squeeze(2)
61
-
62
- # Permute for RNN input: (sequence_length, batch_size, input_size)
63
- # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
64
- conv_features = conv_features.permute(2, 0, 1)
65
-
66
- # Return the CNN features, ready for the RNN layer in CRNN
67
- return conv_features
68
-
69
- class BidirectionalLSTM(nn.Module):
70
- """Bidirectional LSTM layer for sequence modeling."""
71
- def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
72
- super(BidirectionalLSTM, self).__init__()
73
- self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
74
- bidirectional=True, dropout=dropout, batch_first=False)
75
- # batch_first=False expects input as (sequence_length, batch_size, input_size)
76
-
77
- def forward(self, x: torch.Tensor) -> torch.Tensor:
78
- output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
79
- return output
80
-
81
- class CRNN(nn.Module):
82
- """
83
- Convolutional Recurrent Neural Network for OCR.
84
- Combines CNN for feature extraction, LSTMs for sequence modeling,
85
- and a final linear layer for character prediction.
86
- """
87
- def __init__(self, num_classes: int, cnn_output_channels: int = 512,
88
- rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
89
- super(CRNN, self).__init__()
90
- self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
91
- # Input to LSTM is the number of channels from the CNN output
92
- self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
93
- # Output of bidirectional LSTM is hidden_size * 2
94
- self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)
95
-
96
- def forward(self, x: torch.Tensor) -> torch.Tensor:
97
- # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
98
-
99
- # 1. Pass through the CNN to extract features
100
- conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
101
-
102
- # 2. Pass CNN features through the RNN (LSTM)
103
- rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
104
-
105
- # 3. Pass RNN features through the final fully connected layer
106
- # Apply the linear layer to each time step independently
107
- # output will be (W_prime, N, num_classes)
108
- output = self.fc(rnn_features)
109
-
110
- return output
111
-
112
-
113
- # --- Decoding Function ---
114
- def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
115
- """
116
- Performs greedy decoding on the CTC output.
117
- output: (sequence_length, batch_size, num_classes) - raw logits
118
- """
119
- # Apply log_softmax to get probabilities for argmax
120
- log_probs = F.log_softmax(output, dim=2)
121
-
122
- # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
123
- # This gives us the index of the most probable character at each time step for each sample in the batch.
124
- predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
125
-
126
- decoded_texts = []
127
- for seq in predicted_indices:
128
- # Use char_indexer's decode method, which handles blank removal and duplicate collapse
129
- decoded_texts.append(char_indexer.decode(seq.tolist())) # Convert numpy array to list
130
- return decoded_texts
131
-
132
- # --- Evaluation Function ---
133
- def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
134
- model.eval() # Set model to evaluation mode
135
- # CTCLoss needs the blank token index, which is available from char_indexer
136
- criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
137
- total_loss = 0
138
- all_predictions = []
139
- all_ground_truths = []
140
-
141
- with torch.no_grad(): # Disable gradient calculation for evaluation
142
- for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
143
- inputs = inputs.to(device)
144
- targets_padded = targets_padded.to(device)
145
- target_lengths = target_lengths.to(device)
146
-
147
- output = model(inputs) # (seq_len, batch_size, num_classes)
148
-
149
- # Calculate input_lengths for CTCLoss. This is the sequence length produced by the CNN/RNN.
150
- # It's the `output.shape[0]` (sequence_length) for each item in the batch.
151
- outputs_seq_len_for_ctc = torch.full(
152
- size=(output.shape[1],), # batch_size
153
- fill_value=output.shape[0], # actual sequence length (T) from model output
154
- dtype=torch.long,
155
- device=device
156
- )
157
-
158
- # CTC Loss calculation requires log_softmax on the output logits
159
- log_probs_for_loss = F.log_softmax(output, dim=2) # (T, N, C)
160
-
161
- loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths)
162
- total_loss += loss.item() * inputs.size(0) # Multiply by batch size for correct average
163
-
164
- # Decode predictions for metrics
165
- decoded_preds = ctc_greedy_decode(output, char_indexer)
166
-
167
- # Reconstruct ground truths from encoded tensors
168
- ground_truths = []
169
- # Loop through each sample in the batch
170
- for i in range(targets_padded.size(0)):
171
- # Extract the actual target sequence for the i-th sample using its length
172
- # Convert to list before passing to char_indexer.decode
173
- ground_truths.append(char_indexer.decode(targets_padded[i, :target_lengths[i]].tolist()))
174
-
175
- all_predictions.extend(decoded_preds)
176
- all_ground_truths.extend(ground_truths)
177
-
178
- avg_loss = total_loss / len(dataloader.dataset)
179
-
180
- # Calculate Character Error Rate (CER)
181
- cer_sum = 0
182
- total_chars = 0
183
- for pred, gt in zip(all_predictions, all_ground_truths):
184
- cer_sum += editdistance.eval(pred, gt)
185
- total_chars += len(gt)
186
- char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
187
-
188
- # Calculate Exact Match Accuracy (Word-level Accuracy)
189
- exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
190
-
191
- return avg_loss, char_error_rate, exact_match_accuracy
192
-
193
- # --- Training Function ---
194
- def train_ocr_model(model: nn.Module, train_loader: DataLoader,
195
- test_loader: DataLoader, char_indexer: CharIndexer,
196
- epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
197
- """
198
- Trains the OCR model using CTC loss.
199
- """
200
- # CTCLoss needs the blank token index
201
- criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
202
- optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
203
- # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
204
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5) # Removed verbose=True
205
-
206
- model.to(device) # Ensure model is on the correct device
207
- model.train() # Set model to training mode
208
-
209
- training_history = {
210
- 'train_loss': [],
211
- 'test_loss': [],
212
- 'test_cer': [],
213
- 'test_exact_match_accuracy': []
214
- }
215
-
216
- for epoch in range(epochs):
217
- running_loss = 0.0
218
- pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
219
- for images, texts_encoded, _, text_lengths in pbar_train:
220
- images = images.to(device)
221
- # Ensure target tensors are on the correct device for CTCLoss calculation
222
- texts_encoded = texts_encoded.to(device)
223
- text_lengths = text_lengths.to(device)
224
-
225
- optimizer.zero_grad() # Clear gradients from previous step
226
- outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
227
-
228
- # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
229
- # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
230
- outputs_seq_len_for_ctc = torch.full(
231
- size=(outputs.shape[1],), # batch_size
232
- fill_value=outputs.shape[0], # actual sequence length (T) from model output
233
- dtype=torch.long,
234
- device=device
235
- )
236
-
237
- # CTC Loss calculation requires log_softmax on the output logits
238
- log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
239
-
240
- # Use outputs_seq_len_for_ctc for the input_lengths argument
241
- loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
242
- loss.backward() # Backpropagate
243
- optimizer.step() # Update model weights
244
-
245
- running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
246
- pbar_train.set_postfix(loss=loss.item())
247
-
248
- epoch_train_loss = running_loss / len(train_loader.dataset)
249
- training_history['train_loss'].append(epoch_train_loss)
250
-
251
- # Evaluate on test set using the dedicated function
252
- # Ensure model is in eval mode before calling evaluate_model
253
- model.eval()
254
- test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
255
- training_history['test_loss'].append(test_loss)
256
- training_history['test_cer'].append(test_cer)
257
- training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
258
-
259
- # Adjust learning rate based on test loss (this is where scheduler.step() is called)
260
- scheduler.step(test_loss)
261
-
262
- print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
263
- f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
264
-
265
- if progress_callback:
266
- # Update progress bar with current epoch and key metrics
267
- progress_val = (epoch + 1) / epochs
268
- progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
269
-
270
- model.train() # Set model back to training mode after evaluation
271
-
272
- return model, training_history
273
-
274
- def save_ocr_model(model: nn.Module, path: str):
275
- """Saves the state dictionary of the trained OCR model."""
276
- torch.save(model.state_dict(), path)
277
- print(f"OCR model saved to {path}")
278
-
279
- def load_ocr_model(model: nn.Module, path: str):
280
- """
281
- Loads a trained OCR model's state dictionary.
282
- Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
283
- """
284
- model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
285
- model.eval() # Set to evaluation mode
286
- print(f"OCR model loaded from {path}")
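
A point that is easy to misread in both train_ocr_model and evaluate_model: the input_lengths passed to nn.CTCLoss is not the image width but the time dimension T that the CNN actually produced (output.shape[0]), repeated once per batch item with torch.full. A self-contained illustration of that call, with made-up shapes and labels (blank index 95 assumed from config.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

T, N, C, blank = 20, 2, 96, 95                           # frames, batch, classes, blank index
logits = torch.randn(T, N, C)                            # what the CRNN forward pass returns
log_probs = F.log_softmax(logits, dim=2)

targets = torch.tensor([5, 6, 7, 8, 9])                  # two labels concatenated: lengths 3 and 2
target_lengths = torch.tensor([3, 2])
input_lengths = torch.full((N,), T, dtype=torch.long)    # every sample contributes all T frames

loss = nn.CTCLoss(blank=blank, zero_infinity=True)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())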
 
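evaluate_model reports a corpus-level character error rate: the summed edit distance over all prediction/ground-truth pairs divided by the total number of ground-truth characters, not an average of per-word rates. A small worked example with the editdistance package (names invented for illustration):

import editdistance

predictions   = ["JOHN", "MARY"]
ground_truths = ["JOHN", "MARIE"]

errors = sum(editdistance.eval(p, g) for p, g in zip(predictions, ground_truths))
total_chars = sum(len(g) for g in ground_truths)
print(errors, total_chars, errors / total_chars)   # 2 9 0.2222...  (exact match accuracy would be 0.5)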
1
+ # model_ocr.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.optim as optim
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ from sklearn.metrics import accuracy_score
9
+ import editdistance
10
+
11
+ # Import config and char_indexer
12
+ from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
13
+ from data_handler_ocr import CharIndexer
14
+ from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
15
+
16
+
17
+ class CNN_Backbone(nn.Module):
18
+ """
19
+ CNN feature extractor for OCR. Designed to produce features suitable for RNN.
20
+ Output feature map should have height 1 after the final pooling/reduction.
21
+ """
22
+ def __init__(self, input_channels=1, output_channels=512):
23
+ super(CNN_Backbone, self).__init__()
24
+ self.cnn = nn.Sequential(
25
+ # First block
26
+ nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
27
+ nn.ReLU(True),
28
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
29
+
30
+ # Second block
31
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
32
+ nn.ReLU(True),
33
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
34
+
35
+ # Third block (with two conv layers)
36
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
37
+ nn.ReLU(True),
38
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
39
+ nn.ReLU(True),
40
+ # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
41
+ nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
42
+
43
+ # Fourth block
44
+ nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
45
+ nn.ReLU(True),
46
+ # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
47
+ # while preserving the width. This is crucial for RNN input.
48
+ nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
49
+ )
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
53
+
54
+ # Pass through the CNN layers
55
+ conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
56
+
57
+ # Squeeze the height dimension (which is 1)
58
+ # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
59
+ conv_features = conv_features.squeeze(2)
60
+
61
+ # Permute for RNN input: (sequence_length, batch_size, input_size)
62
+ # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
63
+ conv_features = conv_features.permute(2, 0, 1)
64
+
65
+ # Return the CNN features, ready for the RNN layer in CRNN
66
+ return conv_features
67
+
68
+ class BidirectionalLSTM(nn.Module):
69
+ """Bidirectional LSTM layer for sequence modeling."""
70
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
71
+ super(BidirectionalLSTM, self).__init__()
72
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
73
+ bidirectional=True, dropout=dropout, batch_first=False)
74
+ # batch_first=False expects input as (sequence_length, batch_size, input_size)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ output, _ = self.lstm(x) # nn.LSTM returns (output, (h_n, c_n)); only the output sequence is kept
78
+ return output
79
+
80
+ class CRNN(nn.Module):
81
+ """
82
+ Convolutional Recurrent Neural Network for OCR.
83
+ Combines CNN for feature extraction, LSTMs for sequence modeling,
84
+ and a final linear layer for character prediction.
85
+ """
86
+ def __init__(self, num_classes: int, cnn_output_channels: int = 512,
87
+ rnn_hidden_size: int = 256, rnn_num_layers: int = 2): # Corrected parameter name
88
+ super(CRNN, self).__init__()
89
+ self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
90
+ # Input to LSTM is the number of channels from the CNN output
91
+ self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers) # Corrected usage
92
+ # Output of bidirectional LSTM is hidden_size * 2
93
+ self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
94
+
95
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
96
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
97
+
98
+ # 1. Pass through the CNN to extract features
99
+ conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
100
+
101
+ # 2. Pass CNN features through the RNN (LSTM)
102
+ rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
103
+
104
+ # 3. Pass RNN features through the final fully connected layer
105
+ # Apply the linear layer to each time step independently
106
+ # output will be (W_prime, N, num_classes)
107
+ output = self.fc(rnn_features)
108
+
109
+ return output
110
+
111
+
112
+ # --- Decoding Function ---
113
+ def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
114
+ """
115
+ Performs greedy decoding on the CTC output.
116
+ output: (sequence_length, batch_size, num_classes) - raw logits
117
+ """
118
+ # Apply log_softmax to get log-probabilities; argmax over them equals argmax over the raw logits
119
+ log_probs = F.log_softmax(output, dim=2)
120
+
121
+ # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
122
+ predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
123
+
124
+ decoded_texts = []
125
+ for seq in predicted_indices:
126
+ # Use char_indexer's decode method, which handles blank removal and duplicate collapse
127
+ decoded_texts.append(char_indexer.decode(seq.tolist()))
128
+ return decoded_texts
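# Toy decode example (a sketch, assuming the usual CTC rule of collapsing repeats before
# dropping blanks, which is what CharIndexer.decode is described as doing above): argmax
# indices [blank, 3, 3, blank, 5, 5, blank] collapse to [blank, 3, blank, 5], and removing
# the blanks leaves the two characters mapped to indices 3 and 5. A doubled letter in a
# label is only recovered when the model emits a blank between the repeats.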
129
+
130
+ # --- Evaluation Function ---
131
+ def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
132
+ model.eval()
133
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
134
+ total_loss = 0
135
+ all_predictions = []
136
+ all_ground_truths = []
137
+
138
+ with torch.no_grad():
139
+ for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
140
+ inputs = inputs.to(device)
141
+ targets_padded = targets_padded.to(device)
142
+ target_lengths_tensor = target_lengths.to(device)
143
+
144
+ output = model(inputs)
145
+
146
+ outputs_seq_len_for_ctc = torch.full(
147
+ size=(output.shape[1],),
148
+ fill_value=output.shape[0],
149
+ dtype=torch.long,
150
+ device=device
151
+ )
152
+
153
+ # CTC Loss calculation requires log_softmax on the output logits
154
+ log_probs_for_loss = F.log_softmax(output, dim=2)
155
+
156
+ # CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
157
+ loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
158
+ total_loss += loss.item() * inputs.size(0)
159
+
160
+ decoded_preds = ctc_greedy_decode(output, char_indexer)
161
+ all_predictions.extend(decoded_preds)
162
+
163
+ ground_truths_batch = []
164
+ current_idx_in_concatenated_targets = 0
165
+
166
+ target_lengths_list = target_lengths.cpu().tolist()
167
+
168
+ for i in range(inputs.size(0)):
169
+ length = target_lengths_list[i]
170
+
171
+ current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
172
+ ground_truths_batch.append(char_indexer.decode(current_target_segment))
173
+ current_idx_in_concatenated_targets += length
174
+
175
+ all_ground_truths.extend(ground_truths_batch)
176
+
177
+ avg_loss = total_loss / len(dataloader.dataset)
178
+
179
+ # Calculate Character Error Rate (CER)
180
+ cer_sum = 0
181
+ total_chars = 0
182
+ for pred, gt in zip(all_predictions, all_ground_truths):
183
+ cer_sum += editdistance.eval(pred, gt)
184
+ total_chars += len(gt)
185
+ char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
186
+
187
+ # Calculate Exact Match Accuracy (Word-level Accuracy)
188
+ exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
189
+
190
+ return avg_loss, char_error_rate, exact_match_accuracy
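# Worked CER example (a sketch): for prediction "JON" against ground truth "JOHN",
# editdistance.eval("JON", "JOHN") is 1 and len("JOHN") is 4, so the pair contributes
# 1/4 = 0.25; the reported CER is the summed edit distance over all pairs divided by the
# total number of ground-truth characters. Exact-match accuracy counts the same pair as 0
# because the two strings are not identical.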
191
+
192
+ # --- Training Function ---
193
+ def train_ocr_model(model: nn.Module, train_loader: DataLoader,
194
+ test_loader: DataLoader, char_indexer: CharIndexer,
195
+ epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
196
+ """
197
+ Trains the OCR model using CTC loss.
198
+ """
199
+ # CTCLoss needs the blank token index
200
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
201
+ optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
202
+ # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
203
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5) # Removed verbose=True
204
+
205
+ model.to(device) # Ensure model is on the correct device
206
+ model.train() # Set model to training mode
207
+
208
+ training_history = {
209
+ 'train_loss': [],
210
+ 'test_loss': [],
211
+ 'test_cer': [],
212
+ 'test_exact_match_accuracy': []
213
+ }
214
+
215
+ for epoch in range(epochs):
216
+ running_loss = 0.0
217
+ pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
218
+ for images, texts_encoded, _, text_lengths in pbar_train:
219
+ images = images.to(device)
220
+ # Ensure target tensors are on the correct device for CTCLoss calculation
221
+ texts_encoded = texts_encoded.to(device)
222
+ text_lengths = text_lengths.to(device)
223
+
224
+ optimizer.zero_grad() # Clear gradients from previous step
225
+ outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
226
+
227
+ # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
228
+ # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
229
+ outputs_seq_len_for_ctc = torch.full(
230
+ size=(outputs.shape[1],), # batch_size
231
+ fill_value=outputs.shape[0], # actual sequence length (T) from model output
232
+ dtype=torch.long,
233
+ device=device
234
+ )
235
+
236
+ # CTC Loss calculation requires log_softmax on the output logits
237
+ log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
238
+
239
+ # Use outputs_seq_len_for_ctc for the input_lengths argument
240
+ loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
241
+ loss.backward() # Backpropagate
242
+ optimizer.step() # Update model weights
243
+
244
+ running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
245
+ pbar_train.set_postfix(loss=loss.item())
246
+
247
+ epoch_train_loss = running_loss / len(train_loader.dataset)
248
+ training_history['train_loss'].append(epoch_train_loss)
249
+
250
+ # Evaluate on test set using the dedicated function
251
+ # Ensure model is in eval mode before calling evaluate_model
252
+ model.eval()
253
+ test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
254
+ training_history['test_loss'].append(test_loss)
255
+ training_history['test_cer'].append(test_cer)
256
+ training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
257
+
258
+ # Adjust learning rate based on test loss
259
+ scheduler.step(test_loss)
260
+
261
+ print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
262
+ f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
263
+
264
+ if progress_callback:
265
+ # Update progress bar with current epoch and key metrics
266
+ progress_val = (epoch + 1) / epochs
267
+ progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
268
+
269
+ model.train() # Set model back to training mode after evaluation
270
+
271
+ return model, training_history
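# Minimal training-call sketch (assumptions: train_loader / test_loader come from
# create_ocr_dataloaders in data_handler_ocr.py and char_indexer is built as in app.py;
# NUM_CLASSES, NUM_EPOCHS and MODEL_SAVE_PATH are the config.py values):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = CRNN(num_classes=NUM_CLASSES)
#   model, history = train_ocr_model(model, train_loader, test_loader, char_indexer,
#                                    epochs=NUM_EPOCHS, device=device)
#   save_ocr_model(model, MODEL_SAVE_PATH)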
272
+
273
+ def save_ocr_model(model: nn.Module, path: str):
274
+ """Saves the state dictionary of the trained OCR model."""
275
+ torch.save(model.state_dict(), path)
276
+ print(f"OCR model saved to {path}")
277
+
278
+ def load_ocr_model(model: nn.Module, path: str):
279
+ """
280
+ Loads a trained OCR model's state dictionary.
281
+ Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
282
+ """
283
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
284
+ model.eval() # Set to evaluation mode
285
+ print(f"OCR model loaded from {path}")
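A minimal inference sketch for this file (a sketch only: image_tensor is assumed to be a
preprocessed batch of shape (1, 1, IMG_HEIGHT, W), and char_indexer an already-built
CharIndexer, as constructed in app.py):

    import torch
    from config import NUM_CLASSES, MODEL_SAVE_PATH
    from model_ocr import CRNN, load_ocr_model, ctc_greedy_decode

    model = CRNN(num_classes=NUM_CLASSES, cnn_output_channels=512,
                 rnn_hidden_size=256, rnn_num_layers=2)
    load_ocr_model(model, MODEL_SAVE_PATH)      # loads weights onto CPU and calls model.eval()

    with torch.no_grad():
        logits = model(image_tensor)            # (T, 1, NUM_CLASSES)
    predicted_text = ctc_greedy_decode(logits, char_indexer)[0]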
 
utils_ocr.py CHANGED
@@ -1,184 +1,83 @@
1
- <<<<<<< HEAD
2
  #utils_ocr.py
3
 
4
  import cv2
5
- from matplotlib.pylab import f
6
  import numpy as np
7
  from PIL import Image
8
  import torch
9
- from torchvision import transforms
 
10
 
11
- # --- Image Preprocessing for OCR ---
 
 
 
12
 
13
  def load_image_as_grayscale(image_path: str) -> Image.Image:
14
  """Loads an image from path and converts it to grayscale PIL Image."""
15
- # Use PIL for robust image loading and conversion to grayscale 'L' mode
16
- img = Image.open(image_path).convert('L')
17
- return img
18
-
19
- def binarize_image(image_pil: Image.Image) -> Image.Image:
20
- """Binarizes a grayscale PIL Image (black and white)."""
21
- # Convert PIL to OpenCV format (numpy array)
22
- img_np = np.array(image_pil)
23
- # Apply Otsu's thresholding for adaptive binarization
24
- _, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
25
- # Invert colors: Handwritten text usually dark on light. OCR models often
26
- # prefer light text on dark background. Check your training data's style.
27
- # This example assumes dark text on light background and inverts to white text on black.
28
- img_bin = 255 - img_bin
29
- return Image.fromarray(img_bin)
30
 
31
- def resize_image_for_ocr(image_pil: Image.Image, target_height: int) -> Image.Image:
32
  """
33
- Resizes a PIL Image to a target height while maintaining aspect ratio.
34
- Pads width if necessary to avoid distortion.
35
  """
36
- original_width, original_height = image_pil.size
37
- # Calculate new width based on target height and original aspect ratio
38
- new_width = int(original_width * (target_height / original_height))
39
- resized_img = image_pil.resize((new_width, target_height), Image.LANCZOS)
40
- return resized_img
 
 
 
41
 
42
- def normalize_image_for_model(image_pil: Image.Image) -> torch.Tensor:
43
  """
44
- Converts a PIL Image to a PyTorch Tensor and normalizes pixel values.
 
45
  """
46
- # Convert to tensor (scales to 0-1 automatically)
47
- tensor_transform = transforms.ToTensor()
48
- img_tensor = tensor_transform(image_pil)
49
- # For grayscale images, mean and std are single values.
50
- # Adjust normalization values if your training data uses different ones.
51
- img_tensor = transforms.Normalize((0.5,), (0.5,))(img_tensor) # Normalize to [-1, 1]
52
- return img_tensor
 
 
 
 
 
 
 
53
 
54
- def preprocess_user_image_for_ocr(uploaded_image_pil: Image.Image, target_height: int) -> torch.Tensor:
55
  """
56
- Combines all preprocessing steps for a single user-uploaded image
57
- to prepare it for the OCR model.
 
58
  """
59
- # Ensure it's grayscale
60
- img_gray = uploaded_image_pil.convert('L')
61
-
62
- # Binarize
63
- img_bin = binarize_image(img_gray)
64
-
65
- # Resize (maintain aspect ratio)
66
- img_resized = resize_image_for_ocr(img_bin, target_height)
67
-
68
- # Normalize and convert to tensor
69
- img_tensor = normalize_image_for_model(img_resized)
70
-
71
- # Add batch dimension: (C, H, W) -> (1, C, H, W)
72
- img_tensor = img_tensor.unsqueeze(0)
73
-
74
  return img_tensor
75
 
76
- def pad_image_tensor(image_tensor: torch.Tensor, max_width: int) -> torch.Tensor:
77
  """
78
- Pads a single image tensor to a max_width with zeros.
79
- Input tensor shape: (C, H, W)
80
- Output tensor shape: (C, H, max_width)
81
  """
82
- C, H, W = image_tensor.shape
83
- if W > max_width:
84
- # If image is wider than max_width, you might want to crop or resize it.
85
- # For this example, we'll just return a warning or clip.
86
- # A more robust solution might split text lines or use a different resizing strategy.
87
- print(f"Warning: Image width {W} exceeds max_width {max_width}. Cropping.")
88
- return image_tensor[:, :, :max_width] # Simple cropping
89
- padding = max_width - W
90
- # Pad on the right (P_left, P_right, P_top, P_bottom)
91
- padded_tensor = f.pad(image_tensor, (0, padding), 'constant', 0)
92
- =======
93
- #utils_ocr.py
94
-
95
- import cv2
96
- from matplotlib.pylab import f
97
- import numpy as np
98
- from PIL import Image
99
- import torch
100
- from torchvision import transforms
101
-
102
- # --- Image Preprocessing for OCR ---
103
-
104
- def load_image_as_grayscale(image_path: str) -> Image.Image:
105
- """Loads an image from path and converts it to grayscale PIL Image."""
106
- # Use PIL for robust image loading and conversion to grayscale 'L' mode
107
- img = Image.open(image_path).convert('L')
108
- return img
109
-
110
- def binarize_image(image_pil: Image.Image) -> Image.Image:
111
- """Binarizes a grayscale PIL Image (black and white)."""
112
- # Convert PIL to OpenCV format (numpy array)
113
- img_np = np.array(image_pil)
114
- # Apply Otsu's thresholding for adaptive binarization
115
- _, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
116
- # Invert colors: Handwritten text usually dark on light. OCR models often
117
- # prefer light text on dark background. Check your training data's style.
118
- # This example assumes dark text on light background and inverts to white text on black.
119
- img_bin = 255 - img_bin
120
- return Image.fromarray(img_bin)
121
-
122
- def resize_image_for_ocr(image_pil: Image.Image, target_height: int) -> Image.Image:
123
- """
124
- Resizes a PIL Image to a target height while maintaining aspect ratio.
125
- Pads width if necessary to avoid distortion.
126
- """
127
- original_width, original_height = image_pil.size
128
- # Calculate new width based on target height and original aspect ratio
129
- new_width = int(original_width * (target_height / original_height))
130
- resized_img = image_pil.resize((new_width, target_height), Image.LANCZOS)
131
- return resized_img
132
-
133
- def normalize_image_for_model(image_pil: Image.Image) -> torch.Tensor:
134
- """
135
- Converts a PIL Image to a PyTorch Tensor and normalizes pixel values.
136
- """
137
- # Convert to tensor (scales to 0-1 automatically)
138
- tensor_transform = transforms.ToTensor()
139
- img_tensor = tensor_transform(image_pil)
140
- # For grayscale images, mean and std are single values.
141
- # Adjust normalization values if your training data uses different ones.
142
- img_tensor = transforms.Normalize((0.5,), (0.5,))(img_tensor) # Normalize to [-1, 1]
143
- return img_tensor
144
-
145
- def preprocess_user_image_for_ocr(uploaded_image_pil: Image.Image, target_height: int) -> torch.Tensor:
146
- """
147
- Combines all preprocessing steps for a single user-uploaded image
148
- to prepare it for the OCR model.
149
- """
150
- # Ensure it's grayscale
151
- img_gray = uploaded_image_pil.convert('L')
152
-
153
- # Binarize
154
- img_bin = binarize_image(img_gray)
155
-
156
- # Resize (maintain aspect ratio)
157
- img_resized = resize_image_for_ocr(img_bin, target_height)
158
-
159
- # Normalize and convert to tensor
160
- img_tensor = normalize_image_for_model(img_resized)
161
-
162
- # Add batch dimension: (C, H, W) -> (1, C, H, W)
163
- img_tensor = img_tensor.unsqueeze(0)
164
-
165
- return img_tensor
166
-
167
- def pad_image_tensor(image_tensor: torch.Tensor, max_width: int) -> torch.Tensor:
168
- """
169
- Pads a single image tensor to a max_width with zeros.
170
- Input tensor shape: (C, H, W)
171
- Output tensor shape: (C, H, max_width)
172
- """
173
- C, H, W = image_tensor.shape
174
- if W > max_width:
175
- # If image is wider than max_width, you might want to crop or resize it.
176
- # For this example, we'll just return a warning or clip.
177
- # A more robust solution might split text lines or use a different resizing strategy.
178
- print(f"Warning: Image width {W} exceeds max_width {max_width}. Cropping.")
179
- return image_tensor[:, :, :max_width] # Simple cropping
180
- padding = max_width - W
181
- # Pad on the right (P_left, P_right, P_top, P_bottom)
182
- padded_tensor = f.pad(image_tensor, (0, padding), 'constant', 0)
183
- >>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
184
- return padded_tensor
 
 
1
  #utils_ocr.py
2
 
3
  import cv2
 
4
  import numpy as np
5
  from PIL import Image
6
  import torch
7
+ import torchvision.transforms as transforms
8
+ import os
9
 
10
+ # Import config for IMG_HEIGHT and MAX_IMG_WIDTH
11
+ from config import IMG_HEIGHT, MAX_IMG_WIDTH
12
+
13
+ # --- Image Preprocessing Functions ---
14
 
15
  def load_image_as_grayscale(image_path: str) -> Image.Image:
16
  """Loads an image from path and converts it to grayscale PIL Image."""
17
+ if not os.path.exists(image_path):
18
+ raise FileNotFoundError(f"Image not found at: {image_path}")
19
+ return Image.open(image_path).convert('L') # 'L' for grayscale
20
 
21
+ def binarize_image(img: Image.Image) -> Image.Image:
22
  """
23
+ Binarizes a grayscale PIL Image using Otsu's method.
24
+ Returns a PIL Image.
25
  """
26
+ # Convert PIL Image to OpenCV format (numpy array)
27
+ img_np = np.array(img)
28
+
29
+ # Apply Otsu's binarization
30
+ _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
31
+
32
+ # Convert back to PIL Image
33
+ return Image.fromarray(binary_img)
34
 
35
+ def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
36
  """
37
+ Resizes a PIL Image to a fixed height while maintaining aspect ratio.
38
+ Also ensures the width does not exceed MAX_IMG_WIDTH.
39
  """
40
+ width, height = img.size
41
+
42
+ # Calculate new width based on target height, maintaining aspect ratio
43
+ new_width = int(width * (img_height / height))
44
+
45
+ if new_width > MAX_IMG_WIDTH:
46
+ new_width = MAX_IMG_WIDTH
47
+ resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
48
+ if resized_img.width > MAX_IMG_WIDTH:
49
+ # Keep only the leftmost MAX_IMG_WIDTH pixels (defensive; the clamp above already bounds the width)
50
+ resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
51
+ return resized_img
52
+
53
+ return img.resize((new_width, img_height), Image.Resampling.LANCZOS) # Use LANCZOS for high-quality downsampling
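# Worked example (a sketch, assuming MAX_IMG_WIDTH is at least 128): a 200x50 crop with
# img_height = 32 gives new_width = int(200 * (32 / 50)) = 128, so the result is 128x32.
# The width is clamped only when the computed value exceeds MAX_IMG_WIDTH; the crop branch
# above is defensive and is not expected to trigger once the width has been clamped.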
54
 
55
+ def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
56
  """
57
+ Normalizes a torch.Tensor image (grayscale) for input into the model.
58
+ Puts pixel values in range [-1, 1].
59
+ Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
60
  """
61
+ # Formula: (pixel_value - mean) / std_dev
62
+ # For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
63
+ img_tensor = (img_tensor - 0.5) / 0.5
64
  return img_tensor
65
 
66
+ def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
67
  """
68
+ Applies all necessary preprocessing steps to a user-uploaded PIL Image
69
+ to prepare it for the OCR model.
 
70
  """
71
+ # Define a transformation pipeline similar to the dataset, but including ToTensor
72
+ transform_pipeline = transforms.Compose([
73
+ transforms.Lambda(lambda img: binarize_image(img)), # PIL Image -> PIL Image
74
+ # Use the updated resize function that also handles MAX_IMG_WIDTH
75
+ transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)), # PIL Image -> PIL Image
76
+ transforms.ToTensor(), # PIL Image -> Tensor [0, 1]
77
+ transforms.Lambda(normalize_image_for_model) # Tensor [0, 1] -> Tensor [-1, 1]
78
+ ])
79
+
80
+ processed_image = transform_pipeline(image_pil)
81
+
82
+ # Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
83
+ return processed_image.unsqueeze(0)
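A minimal end-to-end usage sketch tying utils_ocr.py to the model (a sketch only: uploaded_pil
is assumed to be a user-supplied PIL.Image, and model / char_indexer the loaded CRNN and
CharIndexer, as in app.py):

    import torch
    from config import IMG_HEIGHT
    from utils_ocr import preprocess_user_image_for_ocr
    from model_ocr import ctc_greedy_decode

    # binarize_image expects a single-channel image, so convert to grayscale first
    image_tensor = preprocess_user_image_for_ocr(uploaded_pil.convert('L'), IMG_HEIGHT)

    model.eval()
    with torch.no_grad():
        logits = model(image_tensor)            # (T, 1, num_classes)
    predicted_text = ctc_greedy_decode(logits, char_indexer)[0]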