Spaces: Build error
Update app.py

app.py CHANGED
@@ -1,6 +1,11 @@
-<<<<<<< HEAD
+# -*- coding: utf-8 -*-
 # app.py
 
+import os
+# CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
+# This MUST be the first thing, before any other imports or Streamlit calls
+os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
+
 import streamlit as st
 import pandas as pd
 import numpy as np
@@ -8,20 +13,32 @@ from PIL import Image
 import torch
 import torch.nn.functional as F # Added F for log_softmax in inference
 import torchvision.transforms as transforms
-import os
 import traceback # For detailed error logging
 
-# Import custom modules
-from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
-                    TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
-from data_handler_ocr import CharIndexer, OCRDataset
+# Import all necessary configuration values from config.py
+from config import (
+    IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
+    TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
+    MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
+)
+
+# Import classes and functions from data_handler_ocr.py and model_ocr.py
+from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
 from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
-from utils_ocr import preprocess_user_image_for_ocr
+from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model # Ensure these are imported if needed
+
+# --- Global Variables ---
+# These will hold the model and char_indexer instance after training or loading
+trained_ocr_model = None
+char_indexer = None
+training_history = None
+# Determine the device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # --- Streamlit App Setup ---
-st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
+st.set_page_config(layout="wide", page_title="Handwritten Name OCR App") # Changed to wide layout for better display
 
-st.title("📝 Handwritten Name Recognition (OCR)")
+st.title("📝 Handwritten Name Recognition (OCR) App") # Updated title for consistency
 st.markdown("""
 This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
 Optical Character Recognition (OCR) on handwritten names. You can upload an image
@@ -31,11 +48,9 @@ st.markdown("""
 """)
 
 # --- Initialize CharIndexer ---
-# The CHARS variable should contain all possible characters your model can recognize.
-# Make sure it's comprehensive based on your dataset.
-char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
-# For robustness, it's best to always use char_indexer.num_classes
-# If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes
+# CRITICAL FIX: Initialize CharIndexer with VOCABULARY and BLANK_TOKEN_SYMBOL
+# This resolves the ValueError: "Blank token symbol '95' not found..."
+char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
 
 # --- Model Loading / Initialization ---
 @st.cache_resource # Cache the model to prevent reloading on every rerun
@@ -64,7 +79,6 @@ def get_and_load_ocr_model_cached(num_classes, model_path):
 # Get the model instance
 ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
 # Determine the device (GPU if available, else CPU)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ocr_model.to(device)
 ocr_model.eval() # Set model to evaluation mode for inference by default
 
@@ -73,102 +87,97 @@ st.sidebar.header("Model Training (Optional)")
 st.sidebar.markdown("If you want to train a new model or no model is found:")
 
 # Initialize Streamlit widgets outside the button block
-training_progress_bar = st.sidebar.empty() # Placeholder for progress bar
-status_text = st.sidebar.empty() # Placeholder for status messages
+training_progress_bar = st.sidebar.empty() # Placeholder for progress bar in sidebar
+status_text = st.sidebar.empty() # Placeholder for status messages in sidebar
 
-if st.sidebar.button("🚀 Train New OCR Model"):
+if st.sidebar.button("🚀 Train New OCR Model"): # Keep button in sidebar as per user's last provided code
     # Clear previous messages/widgets if button is clicked again
+    training_progress_bar.progress(0) # Reset progress bar
     training_progress_bar.empty()
-    status_text.empty()
+    status_text.empty() # Clear status text
 
     # Check for existence of CSVs and image directories
-    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
-       not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
-        status_text.error(f"""Dataset files or image directories not found.
-        Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
-        and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
+    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
+        status_text.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
+    elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
+        status_text.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
+                            "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
     else:
-        status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
+        status_text.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
 
+        # Define the progress bar instance here for the callback
         training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
 
+        def update_progress_callback_sidebar(value, text):
+            """Callback function to update Streamlit progress bar in sidebar."""
+            training_progress_bar_instance.progress(int(value * 100))
+            status_text.text(text) # Update status text in sidebar
+
        try:
-            train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
-            test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
-
-            # Define standard image transforms for consistency
-            train_transform = transforms.Compose([
-                transforms.Resize((IMG_HEIGHT, 100)), # Resize to fixed height, width will be 100 (adjust as needed for variable width)
-                transforms.ToTensor(), # Converts PIL Image to PyTorch Tensor (H, W) -> (C, H, W), normalizes to [0,1]
-            ])
-            test_transform = transforms.Compose([
-                transforms.Resize((IMG_HEIGHT, 100)), # Same transformation as train
-                transforms.ToTensor(),
-            ])
-
-            # Create dataset instances
-            train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
-            test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)
-
-            # Create DataLoader instances
-            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # num_workers=0 for Windows
-            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
-
-            # Train the model, passing the progress callback
+            train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
+            status_text.success("Training and Test DataFrames loaded successfully.")
+
+            char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
+            status_text.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
+
+            # Pass the limits to create_ocr_dataloaders
+            train_loader, test_loader = create_ocr_dataloaders(
+                train_df, test_df, char_indexer, BATCH_SIZE
+            )
+            status_text.success("DataLoaders created successfully.")
+
+            ocr_model_for_training = CRNN(num_classes=NUM_CLASSES) # Create a new instance for training
+            ocr_model_for_training.to(device)
+            status_text.info(f"CRNN model initialized and moved to {device}.")
+
+            status_text.write("Training in progress... This may take a while.")
             trained_ocr_model, training_history = train_ocr_model(
-                ocr_model, # Pass the initialized model instance
-                train_loader,
-                test_loader,
-                char_indexer, # Pass char_indexer for CER calculation
+                model=ocr_model_for_training, # Pass the new instance
+                train_loader=train_loader,
+                test_loader=test_loader,
+                char_indexer=char_indexer, # Pass char_indexer for CER calculation
                 epochs=NUM_EPOCHS,
                 device=device,
-                progress_callback=training_progress_bar_instance.progress # Pass the instance's progress method
+                progress_callback=update_progress_callback_sidebar # Pass the sidebar callback
             )
+            status_text.success("OCR model training finished!")
+            update_progress_callback_sidebar(1.0, "Training complete!")
 
-            # Ensure the directory for saving the model exists
             os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
            save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
-            status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
-
-            # Display training history chart
-            st.sidebar.subheader("Training History Plots")
-
-            history_df = pd.DataFrame({
-                'Epoch': range(1, len(training_history['train_loss']) + 1),
-                'Train Loss': training_history['train_loss'],
-                'Test Loss': training_history['test_loss'],
-                'Test CER (%)': [cer * 100 for cer in training_history['test_cer']], # Convert CER to percentage for display
-                'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']] # Convert to percentage
-            })
-
-            # Plot 1: Training and Test Loss
-            st.sidebar.markdown("**Loss over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
-            )
-            st.sidebar.caption("Lower loss indicates better model performance.")
-
-            # Plot 2: Character Error Rate (CER)
-            st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Test CER (%)']]
-            )
-            st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
-
-            # Plot 3: Exact Match Accuracy
-            st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
-            )
-            st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")
-
-            # Update the global model instance to the newly trained one for immediate inference
-            ocr_model = trained_ocr_model
-            ocr_model.eval()
+            status_text.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
+
+            # Display training history chart in the main section, not sidebar
+            if training_history:
+                st.subheader("Training History Plots")
+                history_df = pd.DataFrame({
+                    'Epoch': range(1, len(training_history['train_loss']) + 1),
+                    'Train Loss': training_history['train_loss'],
+                    'Test Loss': training_history['test_loss'],
+                    'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
+                    'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
+                })
+
+                st.markdown("**Loss over Epochs**")
+                st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
+                st.caption("Lower loss indicates better model performance.")
+
+                st.markdown("**Character Error Rate (CER) over Epochs**")
+                st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
+                st.caption("Lower CER indicates fewer character errors (0% is perfect).")
+
+                st.markdown("**Exact Match Accuracy over Epochs**")
+                st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
+                st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
+
+                st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
+                st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
+                st.caption("CER should decrease, Accuracy should increase.")
 
         except Exception as e:
             status_text.error(f"An error occurred during training: {e}")
-            st.sidebar.text(traceback.format_exc()) # Show full traceback for debugging
+            status_text.exception(e) # Display full traceback in Streamlit
+            update_progress_callback_sidebar(0.0, "Training failed!")
 
 # --- Main Content: Name Prediction ---
 st.header("Predict Your Handwritten Name")
@@ -180,22 +189,19 @@ if uploaded_file is not None:
     try:
         # Open the uploaded image
         image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
-        st.image(image_pil, caption="Uploaded Image", use_column_width=True)
+        # Use use_container_width for deprecation warning fix
+        st.image(image_pil, caption="Uploaded Image", use_container_width=True)
         st.write("---")
         st.write("Processing and Recognizing...")
 
         # Preprocess the image for the model using utils_ocr function
         processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
-
-        # Make prediction
-        ocr_model.eval() # Ensure model is in evaluation mode
+
+        trained_ocr_model.eval() # Ensure model is in evaluation mode
         with torch.no_grad(): # Disable gradient calculation for inference
-            output = ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
-
-            # ctc_greedy_decode expects (sequence_length, batch_size, num_classes)
-            # It returns a list of strings, so get the first element for single image inference.
-            predicted_texts = ctc_greedy_decode(output, char_indexer)
-            predicted_text = predicted_texts[0] # Get the first (and only) prediction
+            output = trained_ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
+            predicted_texts = ctc_greedy_decode(output, char_indexer)
+            predicted_text = predicted_texts[0] # Get the first (and only) prediction
 
         st.success(f"Recognized Text: **{predicted_text}**")
 
@@ -205,222 +211,9 @@
                 "- Ensure the handwritten text is clear and on a clean background.\n"
                 "- Only include one name/word per image.\n"
                 "- The model is trained on specific characters. Unusual symbols might not be recognized.")
-        st.text(traceback.format_exc())
+        st.exception(e) # Display full traceback in Streamlit
 
 st.markdown("""
 ---
 *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
-=======
-# app.py
-
-import streamlit as st
-import pandas as pd
-import numpy as np
-from PIL import Image
-import torch
-import torch.nn.functional as F # Added F for log_softmax in inference
-import torchvision.transforms as transforms
-import os
-import traceback # For detailed error logging
-
-# Import custom modules
-from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
-                    TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
-from data_handler_ocr import CharIndexer, OCRDataset
-from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
-from utils_ocr import preprocess_user_image_for_ocr
-
-# --- Streamlit App Setup ---
-st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
-
-st.title("📝 Handwritten Name Recognition (OCR)")
-st.markdown("""
-This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
-Optical Character Recognition (OCR) on handwritten names. You can upload an image
-of a handwritten name for prediction or train a new model using the provided dataset.
-
-**Note:** Training a robust OCR model can be time-consuming.
-""")
-
-# --- Initialize CharIndexer ---
-# The CHARS variable should contain all possible characters your model can recognize.
-# Make sure it's comprehensive based on your dataset.
-char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
-# For robustness, it's best to always use char_indexer.num_classes
-# If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes
-
-# --- Model Loading / Initialization ---
-@st.cache_resource # Cache the model to prevent reloading on every rerun
-def get_and_load_ocr_model_cached(num_classes, model_path):
-    """
-    Initializes the OCR model and attempts to load a pre-trained model.
-    If no pre-trained model exists, a new model instance is returned.
-    """
-    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
-
-    if os.path.exists(model_path):
-        st.sidebar.info("Loading pre-trained OCR model...")
-        try:
-            # Load model to CPU first, then move to device
-            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
-            st.sidebar.success("OCR model loaded successfully!")
-        except Exception as e:
-            st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
-            # If loading fails, re-initialize an untrained model
-            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
-    else:
-        st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
-
-    return model_instance
-
-# Get the model instance
-ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
-# Determine the device (GPU if available, else CPU)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-ocr_model.to(device)
-ocr_model.eval() # Set model to evaluation mode for inference by default
-
-# --- Sidebar for Model Training ---
-st.sidebar.header("Model Training (Optional)")
-st.sidebar.markdown("If you want to train a new model or no model is found:")
-
-# Initialize Streamlit widgets outside the button block
-training_progress_bar = st.sidebar.empty() # Placeholder for progress bar
-status_text = st.sidebar.empty() # Placeholder for status messages
-
-if st.sidebar.button("🚀 Train New OCR Model"):
-    # Clear previous messages/widgets if button is clicked again
-    training_progress_bar.empty()
-    status_text.empty()
-
-    # Check for existence of CSVs and image directories
-    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
-       not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
-        status_text.error(f"""Dataset files or image directories not found.
-        Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
-        and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
-    else:
-        status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
-
-        training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
-
-        try:
-            train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
-            test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
-
-            # Define standard image transforms for consistency
-            train_transform = transforms.Compose([
-                transforms.Resize((IMG_HEIGHT, 100)), # Resize to fixed height, width will be 100 (adjust as needed for variable width)
-                transforms.ToTensor(), # Converts PIL Image to PyTorch Tensor (H, W) -> (C, H, W), normalizes to [0,1]
-            ])
-            test_transform = transforms.Compose([
-                transforms.Resize((IMG_HEIGHT, 100)), # Same transformation as train
-                transforms.ToTensor(),
-            ])
-
-            # Create dataset instances
-            train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
-            test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)
-
-            # Create DataLoader instances
-            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # num_workers=0 for Windows
-            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
-
-            # Train the model, passing the progress callback
-            trained_ocr_model, training_history = train_ocr_model(
-                ocr_model, # Pass the initialized model instance
-                train_loader,
-                test_loader,
-                char_indexer, # Pass char_indexer for CER calculation
-                epochs=NUM_EPOCHS,
-                device=device,
-                progress_callback=training_progress_bar_instance.progress # Pass the instance's progress method
-            )
-
-            # Ensure the directory for saving the model exists
-            os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
-            save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
-            status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
-
-            # Display training history chart
-            st.sidebar.subheader("Training History Plots")
-
-            history_df = pd.DataFrame({
-                'Epoch': range(1, len(training_history['train_loss']) + 1),
-                'Train Loss': training_history['train_loss'],
-                'Test Loss': training_history['test_loss'],
-                'Test CER (%)': [cer * 100 for cer in training_history['test_cer']], # Convert CER to percentage for display
-                'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']] # Convert to percentage
-            })
-
-            # Plot 1: Training and Test Loss
-            st.sidebar.markdown("**Loss over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
-            )
-            st.sidebar.caption("Lower loss indicates better model performance.")
-
-            # Plot 2: Character Error Rate (CER)
-            st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Test CER (%)']]
-            )
-            st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
-
-            # Plot 3: Exact Match Accuracy
-            st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
-            st.sidebar.line_chart(
-                history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
-            )
-            st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")
-
-            # Update the global model instance to the newly trained one for immediate inference
-            ocr_model = trained_ocr_model
-            ocr_model.eval()
-
-        except Exception as e:
-            status_text.error(f"An error occurred during training: {e}")
-            st.sidebar.text(traceback.format_exc()) # Show full traceback for debugging
-
-# --- Main Content: Name Prediction ---
-st.header("Predict Your Handwritten Name")
-st.markdown("Upload a clear image of a single handwritten name or word.")
-
-uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])
-
-if uploaded_file is not None:
-    try:
-        # Open the uploaded image
-        image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
-        st.image(image_pil, caption="Uploaded Image", use_column_width=True)
-        st.write("---")
-        st.write("Processing and Recognizing...")
-
-        # Preprocess the image for the model using utils_ocr function
-        processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
-
-        # Make prediction
-        ocr_model.eval() # Ensure model is in evaluation mode
-        with torch.no_grad(): # Disable gradient calculation for inference
-            output = ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
-
-            # ctc_greedy_decode expects (sequence_length, batch_size, num_classes)
-            # It returns a list of strings, so get the first element for single image inference.
-            predicted_texts = ctc_greedy_decode(output, char_indexer)
-            predicted_text = predicted_texts[0] # Get the first (and only) prediction
-
-        st.success(f"Recognized Text: **{predicted_text}**")
-
-    except Exception as e:
-        st.error(f"Error processing image or recognizing text: {e}")
-        st.info("💡 **Tips for best results:**\n"
-                "- Ensure the handwritten text is clear and on a clean background.\n"
-                "- Only include one name/word per image.\n"
-                "- The model is trained on specific characters. Unusual symbols might not be recognized.")
-        st.text(traceback.format_exc())
-
-st.markdown("""
----
-*Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
->>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
 """)
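A note on the file-watcher workaround in the first hunk: the environment variable name is taken from the commit as-is. Streamlit's documented configuration key for this behavior is server.fileWatcherType, so if the variable above is ignored by a given Streamlit version, an equivalent setting (an assumption based on Streamlit's STREAMLIT_<SECTION>_<KEY> environment-variable convention, not something this commit uses) would be:

import os

# Hypothetical alternative: maps to the documented config key
# server.fileWatcherType; must be set before streamlit is imported.
os.environ["STREAMLIT_SERVER_FILE_WATCHER_TYPE"] = "none"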
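The CharIndexer fix in the third hunk switches to keyword arguments (vocabulary_string, blank_token_symbol) to clear the ValueError quoted in the comment. The class itself lives in data_handler_ocr.py and is not shown in this diff; a minimal sketch of the usual shape of such an indexer, assuming index 0 is reserved for the CTC blank (the nn.CTCLoss default), might look like this. All names here are illustrative, not the repo's actual API:

class CharIndexerSketch:
    """Hypothetical char/index mapping for CTC; not the repo's implementation."""

    def __init__(self, vocabulary_string, blank_token_symbol):
        # Reserve index 0 for the CTC blank, matching nn.CTCLoss(blank=0)
        self.idx_to_char = [blank_token_symbol] + [
            c for c in vocabulary_string if c != blank_token_symbol
        ]
        self.char_to_idx = {c: i for i, c in enumerate(self.idx_to_char)}
        self.num_classes = len(self.idx_to_char)  # vocabulary plus blank

    def encode(self, text):
        # Label string -> class indices, as CTC training targets expect
        return [self.char_to_idx[c] for c in text]

    def decode(self, indices):
        # Class indices -> text, skipping the blank at index 0
        return "".join(self.idx_to_char[i] for i in indices if i != 0)

Deriving num_classes from the indexer like this is why the app passes char_indexer.num_classes to get_and_load_ocr_model_cached rather than trusting the NUM_CLASSES constant to stay in sync with the vocabulary.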
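The prediction path calls preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT) from utils_ocr.py, which the diff imports but never shows. Below is a plausible minimal sketch, assuming the common CRNN convention of a fixed-height, aspect-preserving resize and a (batch, channel, height, width) input tensor; the real helper may well differ (the diff also imports binarize_image and normalize_image_for_model alongside it):

import torch
import torchvision.transforms.functional as TF
from PIL import Image

def preprocess_user_image_for_ocr_sketch(image_pil, img_height):
    """Hypothetical: fixed-height resize keeping aspect ratio, (1, 1, H, W) output."""
    image_pil = image_pil.convert("L")  # grayscale, matching the app's convert('L')
    w, h = image_pil.size
    new_w = max(1, round(w * img_height / h))  # scale width with the height change
    image_pil = image_pil.resize((new_w, img_height), Image.BILINEAR)
    tensor = TF.to_tensor(image_pil)  # (1, H, W), pixel values scaled to [0, 1]
    return tensor.unsqueeze(0)  # add a batch dimension: (1, 1, H, W)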
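Finally, the inference block relies on ctc_greedy_decode(output, char_indexer) from model_ocr.py, also not part of this diff. The standard greedy CTC decoder it presumably implements takes the per-timestep argmax, collapses consecutive repeats, and drops blanks; a self-contained sketch (the function name, idx_to_char lookup, and blank index 0 are assumptions):

import torch

def ctc_greedy_decode_sketch(output, idx_to_char, blank_idx=0):
    """Greedy CTC decode for output of shape (sequence_length, batch_size, num_classes)."""
    best_path = output.argmax(dim=2)  # best class per time step: (seq_len, batch)
    texts = []
    for b in range(best_path.size(1)):
        prev = None
        chars = []
        for idx in best_path[:, b].tolist():
            # CTC collapse rule: emit a label only if it is not the blank
            # and differs from the label at the previous time step
            if idx != blank_idx and idx != prev:
                chars.append(idx_to_char[idx])
            prev = idx
        texts.append("".join(chars))
    return texts

For a single uploaded image this returns a one-element list, which is why the app indexes predicted_texts[0].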