# -*- coding: utf-8 -*-
# app.py
import os
import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import traceback # Ensure this is imported
# Import all necessary configuration values from config.py
# Wrap this import in a try-except
try:
    from config import (
        IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
        TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
        MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
    )
except Exception as e:
    st.error(f"FATAL ERROR: Could not load config.py. Please check your config.py file for errors. Details: {e}")
    st.stop()  # Stop the app if config fails to load
# Import classes and functions from data_handler_ocr.py and model_ocr.py
# Wrap these imports in a try-except
try:
    from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
    from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
    from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
except Exception as e:
    st.error(f"FATAL ERROR: Could not load core modules (data_handler_ocr.py, model_ocr.py, utils_ocr.py). Please check these files for errors. Details: {e}")
    st.stop()  # Stop the app if core modules fail to load
# --- Global Variables ---
# Initialize training_history in Streamlit's session state to persist across reruns
if 'training_history' not in st.session_state:
    st.session_state.training_history = None
# Initialize ocr_model and char_indexer as None; they will be populated below
ocr_model = None
char_indexer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
# Main Title and Description (Centered)
main_title_col1, main_title_col2, main_title_col3 = st.columns([1, 3, 1])
with main_title_col2:
    st.title("Handwritten Name Recognition (OCR) App")
# --- Initialize CharIndexer ---
try:
    char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
except Exception as e:
    st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}")
    st.stop()
# --- Model Loading / Initialization (Cached and Global) ---
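# Note: @st.cache_resource caches the returned model object across Streamlit
# reruns, so the CRNN below is constructed and its weights loaded only once per
# process rather than on every widget interaction.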
@st.cache_resource
def get_and_load_ocr_model_cached_internal(num_classes, model_path):
    """
    Initializes the OCR model and attempts to load a pre-trained model.
    Returns (model_instance, message_type, message_text).
    """
    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    message_type = "warning"
    message_text = "No pre-trained OCR model found. Please train a model using the 'Train & Evaluate' tab."
    if os.path.exists(model_path):
        try:
            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            message_type = "success"
            message_text = "OCR model loaded successfully!"
        except Exception as e:
            message_type = "error"
            message_text = f"Error loading model from '{model_path}' during app startup: {e}. A new model will be initialized."
            # If loading fails, re-initialize to a fresh model to avoid issues.
            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    return model_instance, message_type, message_text
# Display messages OUTSIDE the cached function
try:
    loaded_model_instance, load_msg_type, load_msg_text = get_and_load_ocr_model_cached_internal(char_indexer.num_classes, MODEL_SAVE_PATH)
    # Assign to global ocr_model
    ocr_model = loaded_model_instance
    # Display status messages as toasts
    if load_msg_type == "success":
        st.toast(load_msg_text, icon="✅")
    elif load_msg_type == "warning":
        st.toast(load_msg_text, icon="⚠️")
    elif load_msg_type == "error":
        st.toast(load_msg_text, icon="🚨")
    if ocr_model is not None:
        ocr_model.to(device)
        ocr_model.eval()  # Set model to evaluation mode for inference by default
    else:
        st.error("Model instance is None after cached load. Prediction will not be available.")
except Exception as e:
    st.error(f"FATAL ERROR: Could not initialize or load OCR model during app startup (outer block). Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
    st.stop()
# --- Define Tabs ---
tabs_col1, tabs_col2, tabs_col3 = st.columns([1, 3, 1])
with tabs_col2:
    tab1, tab2, tab3 = st.tabs(["Project Description", "Predict Name", "Train & Evaluate"])
# --- Tab 1: Project Description ---
with tab1:
    st.markdown("""
This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
Its core aim is to accurately convert handwritten text from images into a digital format, providing a user-friendly interface via Streamlit.

Here are some helpful resources related to this project:
""")
    st.markdown("""
**[Project Documentation](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
**[Demo Presentation](https://drive.google.com/file/d/1j_S8cijxy6zxIn3cWg6tuLPNWB_7nwdI/view?usp=sharing)**
**[Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
**[GitHub Repository](https://github.com/marianeft/handwritten_name_ocr_app)**
""")
# --- Tab 2: Predict Name (Main Content: Prediction Section) ---
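# Prediction flow (sketch): the uploaded image is converted to grayscale, then
# preprocess_user_image_for_ocr (utils_ocr.py) presumably binarizes it, resizes
# it to IMG_HEIGHT while preserving the aspect ratio, and normalizes it into a
# batched tensor. The CRNN output is decoded with ctc_greedy_decode, which in a
# standard CTC greedy decoder takes the argmax at each time step, collapses
# repeated characters, and drops the blank token to produce the final string.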
with tab2:
    st.markdown("Upload a clear image of a single handwritten name or word for recognition.")
    # Check the global ocr_model for prediction availability
    if ocr_model is None:
        st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
    else:
        uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
        if uploaded_file is not None:
            try:
                image_pil = Image.open(uploaded_file).convert('L')  # Ensure grayscale
                st.image(image_pil, caption="Uploaded Image", use_container_width=True)
                st.write("---")
                st.write("Processing and Recognizing...")
                processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
                ocr_model.eval()  # Ensure model is in eval mode for prediction
                with torch.no_grad():
                    output = ocr_model(processed_image_tensor)
                    predicted_texts = ctc_greedy_decode(output, char_indexer)
                    predicted_text = predicted_texts[0]
                st.success(f"Recognized Text: **{predicted_text}**")
            except Exception as e:
                st.error(f"Error processing image or recognizing text: {e}")
                st.info("**Tips for best results:**\n"
                        "- Ensure the handwritten text is clear and on a clean background.\n"
                        "- Only include one name/word per image.\n"
                        "- The model is trained on specific characters. Unusual symbols might not be recognized.")
                st.exception(e)  # Display full traceback for debugging
# --- Tab 3: Train & Evaluate ---
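# Training flow (sketch): the CRNN produces a per-time-step distribution over
# the vocabulary plus the CTC blank token, and train_ocr_model (model_ocr.py)
# presumably optimizes a CTC loss against the label sequences, reporting
# progress through the callback below and returning a per-epoch history dict
# (train_loss, test_loss, test_cer, test_exact_match_accuracy) that feeds the
# plots further down this tab.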
with tab3:
    # --- Model Training Section ---
    st.subheader("Train OCR Model")
    st.write("Click the button below to start training the OCR model.")
    # Progress bar and label for training within this tab
    progress_message_placeholder = st.empty()
    progress_bar_placeholder = st.progress(0)

    def update_progress_callback(value, text):
        progress_bar_placeholder.progress(int(value * 100))
        progress_message_placeholder.info(text)  # Use info for dynamic messages

    if st.button("Start Training"):
        progress_message_placeholder.empty()  # Clear previous messages
        progress_bar_placeholder.progress(0)  # Reset progress bar
        if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
            st.error(f"Training CSV '{TRAIN_CSV_PATH}' or images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
        elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
            st.warning(f"Test CSV '{TEST_CSV_PATH}' or images directory '{TEST_IMAGES_DIR}' not found. "
                       "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
        else:
            progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
            try:
                train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
                progress_message_placeholder.success("Training and test DataFrames loaded successfully.")
                progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
                progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
                if len(test_df) == 0:
                    progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
                if len(train_df) == 0:
                    progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
                if len(train_df) == 0 or len(test_df) == 0:  # Stop if critical data is missing
                    st.stop()
                char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
                progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
                train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
                progress_message_placeholder.success("DataLoaders created successfully.")
                ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                ocr_model_for_training.to(device)
                ocr_model_for_training.train()  # Set to train mode before passing to the training loop
                progress_message_placeholder.write("Training in progress... This may take a while.")
                ocr_model_for_training, history_result = train_ocr_model(
                    model=ocr_model_for_training,  # Pass the local ocr_model_for_training instance
                    train_loader=train_loader,
                    test_loader=test_loader,
                    char_indexer=char_indexer_for_training,
                    epochs=NUM_EPOCHS,
                    device=device,
                    progress_callback=update_progress_callback
                )
                st.session_state.training_history = history_result  # Save history to session state
                progress_message_placeholder.success("OCR model training finished!")
                update_progress_callback(1.0, "Training complete!")
                os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
                save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)  # Save the newly trained model
                progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
                # Crucial: update the global ocr_model with the newly trained one
                ocr_model = ocr_model_for_training
                ocr_model.eval()  # Set to eval mode for subsequent predictions
            except Exception as e:
                progress_message_placeholder.error(f"An error occurred during training: {e}")
                st.exception(e)  # Print a detailed traceback in the Streamlit UI
                update_progress_callback(0.0, "Training failed!")
    st.write("---")

    # --- Model Loading Section ---
    st.subheader("Load Pre-trained Model")
    st.write("If you have a saved model, you can load it here instead of training.")
    if st.button("Load Model"):
        if os.path.exists(MODEL_SAVE_PATH):
            try:
                loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
                loaded_model_instance.to(device)
                ocr_model = loaded_model_instance  # Update global model reference
                ocr_model.eval()  # Set to eval mode after loading
                st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
                # For simplicity, training history is only populated after a training run.
                # If you need to load history with the model, it would need to be saved separately.
            except Exception as e:
                st.error(f"Error loading model: {e}")
                st.exception(e)
        else:
            st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")

    st.write("---")

    # --- Training History Plots Section ---
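    # Metric notes for the charts below: CER is plotted as a percentage; as
    # generally defined it is the character-level edit distance between the
    # predicted and reference strings divided by the reference length
    # (presumably computed in model_ocr.py via the editdistance package named
    # in the footer). Exact match accuracy is the fraction of test samples
    # whose predicted string equals the label exactly.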
    st.subheader("Training History Plots")
    if st.session_state.training_history:  # Check if history exists in session state
        history_df = pd.DataFrame({
            'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
            'Train Loss': st.session_state.training_history['train_loss'],
            'Test Loss': st.session_state.training_history['test_loss'],
            'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
            'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
        })
        st.markdown("**Loss over Epochs**")
        st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
        st.caption("Lower loss indicates better model performance.")

        st.markdown("**Character Error Rate (CER) over Epochs**")
        st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
        st.caption("Lower CER indicates fewer character errors (0% is perfect).")

        st.markdown("**Exact Match Accuracy over Epochs**")
        st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
        st.caption("Higher exact match accuracy indicates more perfectly recognized names.")

        st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
        st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
        st.caption("CER should decrease, while accuracy should increase.")
    else:
        st.info("Train the model first to see training history plots here.")
# --- Final Footer (Centered) ---
footer_col1, footer_col2, footer_col3 = st.columns([1, 3, 1])
with footer_col2:
    st.markdown("""
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance © 2025 by MFT*
""")