Spaces: Build error
Training Model Complete
Browse files
- app.py +227 -219
- config.py +13 -77
- data_handler_ocr.py +165 -151
- model_ocr.py +285 -286
- utils_ocr.py +60 -161
app.py
CHANGED
@@ -1,219 +1,227 @@
# -*- coding: utf-8 -*-
# app.py

import os
# Disable Streamlit file watcher to prevent conflicts with PyTorch
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"

import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import traceback

# Import all necessary configuration values from config.py
from config import (
    IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
    TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
    MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
)

# Import classes and functions from data_handler_ocr.py and model_ocr.py
from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model


# --- Global Variables ---
ocr_model = None
char_indexer = None
training_history = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")

st.title("Handwritten Name Recognition (OCR) App")
st.markdown("""
This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
Optical Character Recognition (OCR) on handwritten names. You can upload an image
of a handwritten name for prediction or train a new model using the provided dataset.

**Note:** Training a robust OCR model can be time-consuming.
""")

# --- Initialize CharIndexer ---
# This initializes char_indexer once when the script starts
char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)

# --- Model Loading / Initialization ---
@st.cache_resource  # Cache the model to prevent reloading on every rerun
def get_and_load_ocr_model_cached(num_classes, model_path):
    """
    Initializes the OCR model and attempts to load a pre-trained model.
    If no pre-trained model exists, a new model instance is returned.
    """
    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)

    if os.path.exists(model_path):
        st.sidebar.info("Loading pre-trained OCR model...")
        try:
            # Load model to CPU first, then move to device
            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            st.sidebar.success("OCR model loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
            # If loading fails, re-initialize an untrained model
            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    else:
        st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")

    return model_instance

# Get the model instance and assign it to the global 'ocr_model'
ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
# Ensure the model is on the correct device for inference
ocr_model.to(device)
ocr_model.eval()  # Set model to evaluation mode for inference by default


# --- Sidebar for Model Training ---
st.sidebar.header("Train OCR Model")
st.sidebar.write("Click the button below to start training the OCR model.")

# Progress bar and label for training in the sidebar
progress_bar_sidebar = st.sidebar.progress(0)
progress_label_sidebar = st.sidebar.empty()

def update_progress_callback_sidebar(value, text):
    progress_bar_sidebar.progress(int(value * 100))
    progress_label_sidebar.text(text)

if st.sidebar.button("Start Training"):
    progress_bar_sidebar.progress(0)
    progress_label_sidebar.empty()
    st.empty()

    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
        st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
    elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
        st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
                           "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
    else:
        st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")

        try:
            train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
            st.sidebar.success("Training and Test DataFrames loaded successfully.")

            st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")

            train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
            st.sidebar.success("DataLoaders created successfully.")

            ocr_model.train()

            st.sidebar.write("Training in progress... This may take a while.")
            ocr_model, training_history = train_ocr_model(
                model=ocr_model,
                train_loader=train_loader,
                test_loader=test_loader,
                char_indexer=char_indexer,
                epochs=NUM_EPOCHS,
                device=device,
                progress_callback=update_progress_callback_sidebar
            )
            st.sidebar.success("OCR model training finished!")
            update_progress_callback_sidebar(1.0, "Training complete!")

            os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
            save_ocr_model(ocr_model, MODEL_SAVE_PATH)
            st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")

        except Exception as e:
            st.sidebar.error(f"An error occurred during training: {e}")
            st.exception(e)
            update_progress_callback_sidebar(0.0, "Training failed!")

# --- Sidebar for Model Loading ---
st.sidebar.header("Load Pre-trained Model")
st.sidebar.write("If you have a saved model, you can load it here instead of training.")

if st.sidebar.button("Load Model"):
    if os.path.exists(MODEL_SAVE_PATH):
        try:
            loaded_model = CRNN(num_classes=char_indexer.num_classes)
            load_ocr_model(loaded_model, MODEL_SAVE_PATH)
            loaded_model.to(device)

            st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
        except Exception as e:
            st.sidebar.error(f"Error loading model: {e}")
            st.exception(e)
    else:
        st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")

# --- Main Content: Prediction Section and Training History ---

# Display training history chart
if training_history:
    st.subheader("Training History Plots")
    history_df = pd.DataFrame({
        'Epoch': range(1, len(training_history['train_loss']) + 1),
        'Train Loss': training_history['train_loss'],
        'Test Loss': training_history['test_loss'],
        'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
        'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
    })

    st.markdown("**Loss over Epochs**")
    st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
    st.caption("Lower loss indicates better model performance.")

    st.markdown("**Character Error Rate (CER) over Epochs**")
    st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
    st.caption("Lower CER indicates fewer character errors (0% is perfect).")

    st.markdown("**Exact Match Accuracy over Epochs**")
    st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
    st.caption("Higher exact match accuracy indicates more perfectly recognized names.")

    st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
    st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
    st.caption("CER should decrease, Accuracy should increase.")
    st.write("---")  # Separator after charts


# Predict on a New Image

if ocr_model is None:
    st.warning("Please train or load a model before attempting prediction.")
else:
    uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg", "jfif"])

    if uploaded_file is not None:
        try:
            image_pil = Image.open(uploaded_file).convert('L')
            st.image(image_pil, caption="Uploaded Image", use_container_width=True)
            st.write("---")
            st.write("Processing and Recognizing...")

            processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)

            ocr_model.eval()
            with torch.no_grad():
                output = ocr_model(processed_image_tensor)

            predicted_texts = ctc_greedy_decode(output, char_indexer)
            predicted_text = predicted_texts[0]

            st.success(f"Recognized Text: **{predicted_text}**")

        except Exception as e:
            st.error(f"Error processing image or recognizing text: {e}")
            st.info("**Tips for best results:**\n"
                    "- Ensure the handwritten text is clear and on a clean background.\n"
                    "- Only include one name/word per image.\n"
                    "- The model is trained on specific characters. Unusual symbols might not be recognized.")
            st.exception(e)

st.markdown("""
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
""")
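For a quick smoke test outside Streamlit, the same prediction path can be exercised directly. This is an illustrative sketch only: it assumes a trained checkpoint already exists at MODEL_SAVE_PATH, and "example.jpg" is a placeholder file name, not part of the repository.

# Illustrative, non-Streamlit prediction path using the modules from this commit.
import torch
from PIL import Image
from config import IMG_HEIGHT, MODEL_SAVE_PATH, VOCABULARY, BLANK_TOKEN_SYMBOL
from data_handler_ocr import CharIndexer
from model_ocr import CRNN, load_ocr_model, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr

indexer = CharIndexer(VOCABULARY, BLANK_TOKEN_SYMBOL)
model = CRNN(num_classes=indexer.num_classes)
load_ocr_model(model, MODEL_SAVE_PATH)           # loads weights onto CPU and sets eval()
x = preprocess_user_image_for_ocr(Image.open("example.jpg").convert('L'), IMG_HEIGHT)
with torch.no_grad():
    print(ctc_greedy_decode(model(x), indexer)[0])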
config.py
CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 # config.py

 import os
@@ -8,8 +7,8 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 DATA_DIR = os.path.join(BASE_DIR, 'data')
 MODELS_DIR = os.path.join(BASE_DIR, 'models')

-TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images'
-TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images'
+TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images')
+TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images')

 TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
 TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
@@ -17,26 +16,13 @@ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
 MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')

 # --- Character Set and OCR Configuration ---
-# This character set MUST cover all characters present in your dataset.
-# Add any special characters if needed.
-# The order here is crucial as it defines the indices for your characters.
 CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
-
-# Define the character for the blank token. It MUST NOT be in CHARS.
 BLANK_TOKEN_SYMBOL = 'Γ'
-
-# Construct the full vocabulary string. It's conventional to put the blank token last.
-# This VOCABULARY string is what you pass to CharIndexer.
 VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
-
-# NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
 NUM_CLASSES = len(VOCABULARY)
-
-# BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
-# Since we appended it last, its index will be len(CHARS).
 BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)

-# --- Sanity Checks
+# --- Sanity Checks ---
 if BLANK_TOKEN == -1:
     raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
 if BLANK_TOKEN >= NUM_CLASSES:
@@ -48,65 +34,15 @@ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")


 # --- Image Preprocessing Parameters ---
-IMG_HEIGHT = 32
+IMG_HEIGHT = 32  # Target height for all input images to the model
+MAX_IMG_WIDTH = 1024  # Adjust this value based on your typical image widths and available RAM

 # --- Training Parameters ---
-BATCH_SIZE =
+BATCH_SIZE = 10
+
+# NEW: Dataset Limits
+TRAIN_SAMPLES_LIMIT = 1000
+TEST_SAMPLES_LIMIT = 1000
+
+NUM_EPOCHS = 5
 LEARNING_RATE = 0.001
-=======
-# config.py
-
-import os
-
-# --- Paths ---
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-DATA_DIR = os.path.join(BASE_DIR, 'data')
-MODELS_DIR = os.path.join(BASE_DIR, 'models')
-
-TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
-TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
-
-TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
-TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
-
-MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
-
-# --- Character Set and OCR Configuration ---
-# This character set MUST cover all characters present in your dataset.
-# Add any special characters if needed.
-# The order here is crucial as it defines the indices for your characters.
-CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
-
-# Define the character for the blank token. It MUST NOT be in CHARS.
-BLANK_TOKEN_SYMBOL = 'Γ'
-
-# Construct the full vocabulary string. It's conventional to put the blank token last.
-# This VOCABULARY string is what you pass to CharIndexer.
-VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
-
-# NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
-NUM_CLASSES = len(VOCABULARY)
-
-# BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
-# Since we appended it last, its index will be len(CHARS).
-BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
-
-# --- Sanity Checks (Highly Recommended) ---
-if BLANK_TOKEN == -1:
-    raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
-if BLANK_TOKEN >= NUM_CLASSES:
-    raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
-
-print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
-print(f"Vocabulary Length: {len(VOCABULARY)}")
-print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
-
-
-# --- Image Preprocessing Parameters ---
-IMG_HEIGHT = 32
-
-# --- Training Parameters ---
-BATCH_SIZE = 64
-LEARNING_RATE = 0.001
->>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
-NUM_EPOCHS = 3
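As a quick sanity check on the configuration above (illustrative only; the expected numbers assume CHARS is exactly the printable-ASCII range written above):

# Verify the derived vocabulary sizes from the project root.
from config import CHARS, NUM_CLASSES, BLANK_TOKEN

print(len(CHARS))    # 95 if CHARS is the full printable-ASCII range shown above
print(NUM_CLASSES)   # len(CHARS) + 1, since the blank symbol is appended last
print(BLANK_TOKEN)   # index of the blank symbol, i.e. len(CHARS)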
data_handler_ocr.py
CHANGED
@@ -1,151 +1,165 @@
# data_handler_ocr.py

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import torch.nn.functional as F

# Import utility functions and config
from config import (
    VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT,
    TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
    TRAIN_SAMPLES_LIMIT, TEST_SAMPLES_LIMIT
)
from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model

class CharIndexer:
    """Manages character-to-index and index-to-character mappings."""
    def __init__(self, vocabulary_string: str, blank_token_symbol: str):
        self.chars = sorted(list(set(vocabulary_string)))
        self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
        self.idx_to_char = {i: char for i, char in enumerate(self.chars)}

        if blank_token_symbol not in self.char_to_idx:
            raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")

        self.blank_token_idx = self.char_to_idx[blank_token_symbol]
        self.num_classes = len(self.chars)

        if self.blank_token_idx >= self.num_classes:
            raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")

        print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
        print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")

    def encode(self, text: str) -> list[int]:
        """Converts a text string to a list of integer indices."""
        encoded_list = []
        for char in text:
            if char in self.char_to_idx:
                encoded_list.append(self.char_to_idx[char])
            else:
                print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
                encoded_list.append(self.blank_token_idx)
        return encoded_list

    def decode(self, indices: list[int]) -> str:
        """Converts a list of integer indices back to a text string."""
        decoded_text = []
        for i, idx in enumerate(indices):
            if idx == self.blank_token_idx:
                continue  # Skip blank tokens

            if i > 0 and indices[i-1] == idx:
                continue

            if idx in self.idx_to_char:
                decoded_text.append(self.idx_to_char[idx])
            else:
                print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")

        return "".join(decoded_text)

class OCRDataset(Dataset):
    """
    Custom PyTorch Dataset for the Handwritten Name Recognition task.
    Loads images and their corresponding text labels.
    """
    def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
        self.data = dataframe
        self.char_indexer = char_indexer
        self.image_dir = image_dir

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Lambda(lambda img: binarize_image(img)),
                transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)),  # Resize image to fixed height
                transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor (H, W) -> (1, H, W), scales to [0,1]
                transforms.Lambda(normalize_image_for_model)  # Normalize pixel values to [-1, 1]
            ])
        else:
            self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        raw_filename_entry = self.data.loc[idx, 'FILENAME']
        ground_truth_text = self.data.loc[idx, 'IDENTITY']

        filename = raw_filename_entry.split(',')[0].strip()
        img_path = os.path.join(self.image_dir, filename)
        ground_truth_text = str(ground_truth_text)

        try:
            image = load_image_as_grayscale(img_path)  # Returns PIL Image 'L'
        except FileNotFoundError:
            print(f"Error: Image file not found at {img_path}. Skipping this item.")
            raise

        if self.transform:
            image = self.transform(image)

        image_width = image.shape[2]  # Assuming image is (C, H, W) after transform

        text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
        text_length = len(text_encoded)

        return image, text_encoded, image_width, text_length

def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Custom collate function for the DataLoader to handle variable-width images
    and variable-length text sequences for CTC loss.
    """
    images, texts, image_widths, text_lengths = zip(*batch)

    max_batch_width = max(image_widths)
    padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
    images_batch = torch.stack(padded_images, 0)

    texts_batch = torch.cat(texts, 0)
    text_lengths_tensor = torch.tensor(list(text_lengths), dtype=torch.long)
    image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)

    return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor


def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Loads training and testing dataframes.
    Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
    Applies dataset limits from config.py.
    """
    train_df = pd.read_csv(train_csv_path, encoding='ISO-8859-1')
    test_df = pd.read_csv(test_csv_path, encoding='ISO-8859-1')

    # Apply limits if they are set (not 0)
    if TRAIN_SAMPLES_LIMIT > 0:
        train_df = train_df.head(TRAIN_SAMPLES_LIMIT)
        print(f"Limited training data to {TRAIN_SAMPLES_LIMIT} samples.")
    if TEST_SAMPLES_LIMIT > 0:
        test_df = test_df.head(TEST_SAMPLES_LIMIT)
        print(f"Limited test data to {TEST_SAMPLES_LIMIT} samples.")

    return train_df, test_df

def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
                           char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
    """
    Creates PyTorch DataLoader objects for OCR training and testing datasets,
    using specific image directories for train/test.
    """
    train_dataset = OCRDataset(train_df, char_indexer, TRAIN_IMAGES_DIR)
    test_dataset = OCRDataset(test_df, char_indexer, TEST_IMAGES_DIR)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, collate_fn=ocr_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=0, collate_fn=ocr_collate_fn)
    return train_loader, test_loader
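A minimal round trip with CharIndexer illustrates the CTC-style decoding rule used above (repeated indices collapse unless separated by the blank token). This is an illustrative sketch, assuming config.py exposes VOCABULARY and BLANK_TOKEN_SYMBOL as in this commit:

from config import VOCABULARY, BLANK_TOKEN_SYMBOL
from data_handler_ocr import CharIndexer

indexer = CharIndexer(VOCABULARY, BLANK_TOKEN_SYMBOL)
ids = indexer.encode("ANNA")                          # four indices, the middle two identical
print(indexer.decode(ids))                            # "ANA" -- adjacent duplicates are collapsed
blank = indexer.blank_token_idx
print(indexer.decode(ids[:2] + [blank] + ids[2:]))    # "ANNA" -- a blank breaks the repetition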
model_ocr.py
CHANGED
@@ -1,286 +1,285 @@
# model_ocr.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import editdistance

# Import config and char_indexer
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
from data_handler_ocr import CharIndexer
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model


class CNN_Backbone(nn.Module):
    """
    CNN feature extractor for OCR. Designed to produce features suitable for RNN.
    Output feature map should have height 1 after the final pooling/reduction.
    """
    def __init__(self, input_channels=1, output_channels=512):
        super(CNN_Backbone, self).__init__()
        self.cnn = nn.Sequential(
            # First block
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H: 32 -> 16, W: W_in -> W_in/2

            # Second block
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H: 16 -> 8, W: W_in/2 -> W_in/4

            # Third block (with two conv layers)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)

            # Fourth block
            nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
            # while preserving the width. This is crucial for RNN input.
            nn.AdaptiveAvgPool2d((1, None))  # Output height 1, preserve width
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W) e.g., (B, 1, 32, W_img)

        # Pass through the CNN layers
        conv_features = self.cnn(x)  # Output: (N, cnn_out_channels, 1, W_prime)

        # Squeeze the height dimension (which is 1)
        # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
        conv_features = conv_features.squeeze(2)

        # Permute for RNN input: (sequence_length, batch_size, input_size)
        # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
        conv_features = conv_features.permute(2, 0, 1)

        # Return the CNN features, ready for the RNN layer in CRNN
        return conv_features

class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM layer for sequence modeling."""
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
        super(BidirectionalLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            bidirectional=True, dropout=dropout, batch_first=False)
        # batch_first=False expects input as (sequence_length, batch_size, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)  # [0] returns the output, [1] returns (h_n, c_n)
        return output

class CRNN(nn.Module):
    """
    Convolutional Recurrent Neural Network for OCR.
    Combines CNN for feature extraction, LSTMs for sequence modeling,
    and a final linear layer for character prediction.
    """
    def __init__(self, num_classes: int, cnn_output_channels: int = 512,
                 rnn_hidden_size: int = 256, rnn_num_layers: int = 2):  # Corrected parameter name
        super(CRNN, self).__init__()
        self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
        # Input to LSTM is the number of channels from the CNN output
        self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)  # Corrected usage
        # Output of bidirectional LSTM is hidden_size * 2
        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)  # Final linear layer for classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W) e.g., (B, 1, 32, W_img)

        # 1. Pass through the CNN to extract features
        conv_features = self.cnn(x)  # Output: (W_prime, N, C_out) after permute in CNN_Backbone

        # 2. Pass CNN features through the RNN (LSTM)
        rnn_features = self.rnn(conv_features)  # Output: (W_prime, N, rnn_hidden_size * 2)

        # 3. Pass RNN features through the final fully connected layer
        # Apply the linear layer to each time step independently
        # output will be (W_prime, N, num_classes)
        output = self.fc(rnn_features)

        return output


# --- Decoding Function ---
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
    """
    Performs greedy decoding on the CTC output.
    output: (sequence_length, batch_size, num_classes) - raw logits
    """
    # Apply log_softmax to get probabilities for argmax
    log_probs = F.log_softmax(output, dim=2)

    # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
    predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()

    decoded_texts = []
    for seq in predicted_indices:
        # Use char_indexer's decode method, which handles blank removal and duplicate collapse
        decoded_texts.append(char_indexer.decode(seq.tolist()))
    return decoded_texts

# --- Evaluation Function ---
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
    model.eval()
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
    total_loss = 0
    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():
        for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets_padded = targets_padded.to(device)
            target_lengths_tensor = target_lengths.to(device)

            output = model(inputs)

            outputs_seq_len_for_ctc = torch.full(
                size=(output.shape[1],),
                fill_value=output.shape[0],
                dtype=torch.long,
                device=device
            )

            # CTC Loss calculation requires log_softmax on the output logits
            log_probs_for_loss = F.log_softmax(output, dim=2)

            # CTCLoss expects targets_padded as a 1D tensor and target_lengths_tensor as corresponding lengths
            loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths_tensor)
            total_loss += loss.item() * inputs.size(0)

            decoded_preds = ctc_greedy_decode(output, char_indexer)
            all_predictions.extend(decoded_preds)

            ground_truths_batch = []
            current_idx_in_concatenated_targets = 0

            target_lengths_list = target_lengths.cpu().tolist()

            for i in range(inputs.size(0)):
                length = target_lengths_list[i]

                current_target_segment = targets_padded[current_idx_in_concatenated_targets : current_idx_in_concatenated_targets + length].tolist()
                ground_truths_batch.append(char_indexer.decode(current_target_segment))
                current_idx_in_concatenated_targets += length

            all_ground_truths.extend(ground_truths_batch)

    avg_loss = total_loss / len(dataloader.dataset)

    # Calculate Character Error Rate (CER)
    cer_sum = 0
    total_chars = 0
    for pred, gt in zip(all_predictions, all_ground_truths):
        cer_sum += editdistance.eval(pred, gt)
        total_chars += len(gt)
    char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0

    # Calculate Exact Match Accuracy (Word-level Accuracy)
    exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)

    return avg_loss, char_error_rate, exact_match_accuracy

# --- Training Function ---
def train_ocr_model(model: nn.Module, train_loader: DataLoader,
                    test_loader: DataLoader, char_indexer: CharIndexer,
                    epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
    """
    Trains the OCR model using CTC loss.
    """
    # CTCLoss needs the blank token index
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Using a fixed LR for now
    # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)  # Removed verbose=True

    model.to(device)  # Ensure model is on the correct device
    model.train()  # Set model to training mode

    training_history = {
        'train_loss': [],
        'test_loss': [],
        'test_cer': [],
        'test_exact_match_accuracy': []
    }

    for epoch in range(epochs):
        running_loss = 0.0
        pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
        for images, texts_encoded, _, text_lengths in pbar_train:
            images = images.to(device)
            # Ensure target tensors are on the correct device for CTCLoss calculation
            texts_encoded = texts_encoded.to(device)
            text_lengths = text_lengths.to(device)

            optimizer.zero_grad()  # Clear gradients from previous step
            outputs = model(images)  # (sequence_length_from_cnn, batch_size, num_classes)

            # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
            # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
            outputs_seq_len_for_ctc = torch.full(
                size=(outputs.shape[1],),  # batch_size
                fill_value=outputs.shape[0],  # actual sequence length (T) from model output
                dtype=torch.long,
                device=device
            )

            # CTC Loss calculation requires log_softmax on the output logits
            log_probs_for_loss = F.log_softmax(outputs, dim=2)  # (T, N, C)

            # Use outputs_seq_len_for_ctc for the input_lengths argument
            loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
            loss.backward()  # Backpropagate
            optimizer.step()  # Update model weights

            running_loss += loss.item() * images.size(0)  # Multiply by batch size for correct average
            pbar_train.set_postfix(loss=loss.item())

        epoch_train_loss = running_loss / len(train_loader.dataset)
        training_history['train_loss'].append(epoch_train_loss)

        # Evaluate on test set using the dedicated function
        # Ensure model is in eval mode before calling evaluate_model
        model.eval()
        test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
        training_history['test_loss'].append(test_loss)
        training_history['test_cer'].append(test_cer)
        training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)

        # Adjust learning rate based on test loss
        scheduler.step(test_loss)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
              f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")

        if progress_callback:
            # Update progress bar with current epoch and key metrics
            progress_val = (epoch + 1) / epochs
            progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")

        model.train()  # Set model back to training mode after evaluation

    return model, training_history

def save_ocr_model(model: nn.Module, path: str):
    """Saves the state dictionary of the trained OCR model."""
    torch.save(model.state_dict(), path)
    print(f"OCR model saved to {path}")

def load_ocr_model(model: nn.Module, path: str):
    """
    Loads a trained OCR model's state dictionary.
    Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
    """
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))  # Always load to CPU first
    model.eval()  # Set to evaluation mode
    print(f"OCR model loaded from {path}")
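A minimal shape check for the CRNN defined above (illustrative sketch; the exact sequence length T depends on the CNN's width downsampling, and the model here has random, untrained weights):

import torch
from config import IMG_HEIGHT, NUM_CLASSES
from model_ocr import CRNN

model = CRNN(num_classes=NUM_CLASSES)
dummy = torch.randn(2, 1, IMG_HEIGHT, 128)   # (batch, channels, height, width)
with torch.no_grad():
    out = model(dummy)
print(out.shape)   # (T, 2, NUM_CLASSES): time steps first, as nn.CTCLoss expects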
utils_ocr.py
CHANGED
@@ -1,184 +1,83 @@
# utils_ocr.py

import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import os

# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
from config import IMG_HEIGHT, MAX_IMG_WIDTH

# --- Image Preprocessing Functions ---

def load_image_as_grayscale(image_path: str) -> Image.Image:
    """Loads an image from path and converts it to grayscale PIL Image."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    return Image.open(image_path).convert('L')  # 'L' for grayscale

def binarize_image(img: Image.Image) -> Image.Image:
    """
    Binarizes a grayscale PIL Image using Otsu's method.
    Returns a PIL Image.
    """
    # Convert PIL Image to OpenCV format (numpy array)
    img_np = np.array(img)

    # Apply Otsu's binarization
    _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Convert back to PIL Image
    return Image.fromarray(binary_img)

def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
    """
    Resizes a PIL Image to a fixed height while maintaining aspect ratio.
    Also ensures the width does not exceed MAX_IMG_WIDTH.
    """
    width, height = img.size

    # Calculate new width based on target height, maintaining aspect ratio
    new_width = int(width * (img_height / height))

    if new_width > MAX_IMG_WIDTH:
        new_width = MAX_IMG_WIDTH
        resized_img = img.resize((new_width, img_height), Image.Resampling.LANCZOS)
        if resized_img.width > MAX_IMG_WIDTH:
            # Crop the image from the left to MAX_IMG_WIDTH
            resized_img = resized_img.crop((0, 0, MAX_IMG_WIDTH, img_height))
        return resized_img

    return img.resize((new_width, img_height), Image.Resampling.LANCZOS)  # Use LANCZOS for high-quality downsampling

def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalizes a torch.Tensor image (grayscale) for input into the model.
    Puts pixel values in range [-1, 1].
    Assumes image is already a torch.Tensor with values in [0, 1] (e.g., after ToTensor).
    """
    # Formula: (pixel_value - mean) / std_dev
    # For [0, 1] to [-1, 1], mean = 0.5, std_dev = 0.5
    img_tensor = (img_tensor - 0.5) / 0.5
    return img_tensor

def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """
    Applies all necessary preprocessing steps to a user-uploaded PIL Image
    to prepare it for the OCR model.
    """
    # Define a transformation pipeline similar to the dataset, but including ToTensor
    transform_pipeline = transforms.Compose([
        transforms.Lambda(lambda img: binarize_image(img)),  # PIL Image -> PIL Image
        # Use the updated resize function that also handles MAX_IMG_WIDTH
        transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)),  # PIL Image -> PIL Image
        transforms.ToTensor(),  # PIL Image -> Tensor [0, 1]
        transforms.Lambda(normalize_image_for_model)  # Tensor [0, 1] -> Tensor [-1, 1]
    ])

    processed_image = transform_pipeline(image_pil)

    # Add a batch dimension (C, H, W) -> (1, C, H, W) for single image inference
    return processed_image.unsqueeze(0)
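A short illustrative usage of the preprocessing helpers above, chained through preprocess_user_image_for_ocr (the file name is a placeholder, not part of the repository):

from PIL import Image
from config import IMG_HEIGHT
from utils_ocr import preprocess_user_image_for_ocr

img = Image.open("sample_name.png").convert('L')   # placeholder path to a handwritten-name crop
tensor = preprocess_user_image_for_ocr(img, IMG_HEIGHT)
print(tensor.shape)                                # (1, 1, 32, W), with W capped at MAX_IMG_WIDTH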