marianeft's picture
Update UI
af3f1e7
raw
history blame
16.4 kB
# -*- coding: utf-8 -*-
# app.py
import os
# CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
# This MUST be the first thing, before any other imports or Streamlit calls
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import traceback # Ensure this is imported
# Pull the shared configuration constants from config.py.
# The import is guarded so that a broken config.py is reported inside the
# Streamlit page and the script halts, instead of failing with a raw traceback.
try:
    from config import (
        BATCH_SIZE,
        BLANK_TOKEN,
        BLANK_TOKEN_SYMBOL,
        IMG_HEIGHT,
        MODEL_SAVE_PATH,
        NUM_CLASSES,
        NUM_EPOCHS,
        TEST_CSV_PATH,
        TEST_IMAGES_DIR,
        TRAIN_CSV_PATH,
        TRAIN_IMAGES_DIR,
        VOCABULARY,
    )
except Exception as e:
    config_error_message = f"FATAL ERROR: Could not load config.py. Please check your config.py file for errors. Details: {e}"
    st.error(config_error_message)
    # Nothing below can work without the configuration values.
    st.stop()
# Bring in the project's own building blocks: data handling, the CRNN model
# with its train/save/load helpers, and the image pre-processing utilities.
# As with config.py, an import failure is shown in the page and stops the script.
try:
    from data_handler_ocr import (
        CharIndexer,
        OCRDataset,
        create_ocr_dataloaders,
        load_ocr_dataframes,
        ocr_collate_fn,
    )
    from model_ocr import (
        CRNN,
        ctc_greedy_decode,
        load_ocr_model,
        save_ocr_model,
        train_ocr_model,
    )
    from utils_ocr import (
        binarize_image,
        normalize_image_for_model,
        preprocess_user_image_for_ocr,
        resize_image_for_ocr,
    )
except Exception as e:
    modules_error_message = f"FATAL ERROR: Could not load core modules (data_handler_ocr.py, model_ocr.py, utils_ocr.py). Please check these files for errors. Details: {e}"
    st.error(modules_error_message)
    # The rest of the app depends on these modules.
    st.stop()
# --- Global Variables ---
# Create the session-state slot for the training history exactly once, so the
# value stored after a training run persists across script reruns.
if "training_history" not in st.session_state:
    st.session_state["training_history"] = None

# Placeholders; both names are bound to real objects further down the script.
ocr_model = None
char_indexer = None

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")

# A 1:3:1 column split keeps the title in the middle of the wide layout.
col1, col2, col3 = st.columns([1, 3, 1])
with col2:
    # The leading emoji is stored as a real UTF-8 character (the file declares
    # utf-8 coding) so it renders as an icon rather than as garbled text.
    st.title("📝 Handwritten Name Recognition (OCR) App")
# --- Initialize CharIndexer ---
# Build the character indexer from the configured vocabulary and blank symbol.
# A failure here is reported in the page and halts the script, since both the
# model construction and the decoding further down use this object.
try:
    char_indexer = CharIndexer(
        vocabulary_string=VOCABULARY,
        blank_token_symbol=BLANK_TOKEN_SYMBOL,
    )
except Exception as e:
    indexer_error_message = f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}"
    st.error(indexer_error_message)
    st.stop()
# --- Define Tabs ---
# The tab bar is created inside the wide middle column of a 1:3:1 split, so
# everything later written to tab1/tab2/tab3 is rendered in that column.
col1, col2, col3 = st.columns([1, 3, 1])
tab_labels = ["Project Description", "Predict Name", "Train & Evaluate"]
tab1, tab2, tab3 = col2.tabs(tab_labels)
# --- Tab 1: Project Description ---
with tab1:
    # Short summary of what the application does.
    st.markdown("""
This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
Its core aim is to accurately convert handwritten text from images into digital format, providing a user-friendly interface via Streamlit.
Here are some helpful resources related to this project:
""")
    # External resources. The emoji prefixes are real UTF-8 characters so the
    # link labels render as icons instead of mis-decoded byte sequences.
    st.markdown("""
**[📃 Project Documentation ](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
**[🎞️ Demo Presentation ](https://drive.google.com/drive/folders/1rOmwyTJkDCsU-Wuh-_CzvQ9sdb_ci_kX?usp=sharing)**
**[📚 Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
**[📂 Github Repository ](https://github.com/marianeft/handwritten_name_ocr_app)**
""")
# --- Model Loading / Initialization (Cached) ---
@st.cache_resource  # Cache the model to prevent reloading on every rerun
def get_and_load_ocr_model_cached(num_classes, model_path):
    """
    Build a CRNN and, when a weights file exists, load those weights into it.

    Status messages (info / success / warning / error) are written to whichever
    Streamlit container is active when the function is called.

    Args:
        num_classes: Number of output classes passed to the CRNN constructor.
        model_path: Path of the saved state dict to look for.

    Returns:
        The CRNN instance. It carries the saved weights when `model_path`
        exists and loads cleanly; otherwise it is a freshly constructed model.
    """
    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    if os.path.exists(model_path):
        st.info("Loading pre-trained OCR model...")
        try:
            # Load the tensors onto the CPU first; the caller moves the model
            # to the target device afterwards.
            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            st.success("OCR model loaded successfully!")
        except Exception as e:
            st.error(f"Error loading model from '{model_path}': {e}. A new model will be initialized.")
            # A failed load may leave the instance partially populated, so
            # start again from a clean, untrained model.
            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    else:
        st.warning("No pre-trained OCR model found. Please train a model.")
    return model_instance


# --- Tab 3 (header + model loading) ---
# This part of tab 3 intentionally runs BEFORE the body of tab 2. The script is
# executed top to bottom, so the model has to be bound to `ocr_model` before the
# prediction tab checks it; otherwise that tab always sees `None` and never
# offers the uploader. Where the elements appear on screen is unaffected,
# because they are written to the `tab3` container.
with tab3:
    st.header("Model Training and Evaluation")
    st.markdown("Here you can train a new OCR model or load a pre-trained one.")
    try:
        ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
        ocr_model.to(device)
        ocr_model.eval()  # Set model to evaluation mode for inference by default
    except Exception as e:
        st.error(f"FATAL ERROR: Could not initialize or load OCR model. Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
        st.stop()

# --- Tab 2: Predict Name (Main Content: Prediction Section) ---
with tab2:
    st.header("Predict on a New Image")
    st.markdown("Upload a clear image of a single handwritten name or word for recognition.")
    # The loader above returns an untrained network when no weights file is
    # present, so the saved file is also checked here: predicting with random
    # weights would only produce noise.
    if ocr_model is None or not os.path.exists(MODEL_SAVE_PATH):
        st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
    else:
        uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
        if uploaded_file is not None:
            try:
                # Convert to single-channel greyscale before pre-processing.
                image_pil = Image.open(uploaded_file).convert('L')
                st.image(image_pil, caption="Uploaded Image", use_container_width=True)
                st.write("---")
                st.write("Processing and Recognizing...")
                processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
                ocr_model.eval()  # Ensure model is in eval mode for prediction
                with torch.no_grad():
                    output = ocr_model(processed_image_tensor)
                predicted_texts = ctc_greedy_decode(output, char_indexer)
                # A single image was submitted, so only the first decoded
                # string is relevant.
                predicted_text = predicted_texts[0]
                st.success(f"Recognized Text: **{predicted_text}**")
            except Exception as e:
                st.error(f"Error processing image or recognizing text: {e}")
                st.info("💡 **Tips for best results:**\n"
                        "- Ensure the handwritten text is clear and on a clean background.\n"
                        "- Only include one name/word per image.\n"
                        "- The model is trained on specific characters. Unusual symbols might not be recognized.")
                st.exception(e)
# --- Model Training Section ---
with tab3:
    st.subheader("1. Train OCR Model")
    st.write("Click the button below to start training the OCR model.")

    # Placeholders that are rewritten in place while training runs: one slot
    # for the current status message and one progress bar.
    progress_message_placeholder = st.empty()
    progress_bar_placeholder = st.progress(0)

    def update_progress_callback(value, text):
        """
        Show training progress in the bar and status slot above.

        Args:
            value: Completed fraction; it is multiplied by 100 and clamped to
                the 0-100 range so an out-of-range value cannot make the
                progress bar raise.
            text: Status message displayed in the message slot.
        """
        percent = max(0, min(100, int(value * 100)))
        progress_bar_placeholder.progress(percent)
        progress_message_placeholder.info(text)  # Use info for dynamic messages

    if st.button("📊 Start Training"):
        progress_message_placeholder.empty()  # Clear previous messages
        progress_bar_placeholder.progress(0)  # Reset progress bar

        train_data_ready = os.path.exists(TRAIN_CSV_PATH) and os.path.isdir(TRAIN_IMAGES_DIR)
        test_data_ready = os.path.exists(TEST_CSV_PATH) and os.path.isdir(TEST_IMAGES_DIR)

        if not train_data_ready:
            st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
        elif not test_data_ready:
            # Training is not attempted without the test set, so say so
            # explicitly instead of hinting that only evaluation is affected.
            st.error(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
                     "Training was not started because the test set is required for evaluation. "
                     "Please ensure all data paths are correct and data is uploaded.")
        else:
            progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
            try:
                train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
                progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
                progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
                progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")

                if len(test_df) == 0:
                    progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
                if len(train_df) == 0:
                    progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
                if len(train_df) == 0 or len(test_df) == 0:
                    # Neither training nor evaluation is possible without data.
                    st.stop()

                char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
                progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")

                train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
                progress_message_placeholder.success("DataLoaders created successfully.")

                # Always train a fresh network so that pressing the button
                # never continues from previously loaded weights.
                ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                ocr_model_for_training.to(device)
                ocr_model_for_training.train()
                progress_message_placeholder.write("Training in progress... This may take a while.")

                # Capture the trained model and the per-epoch history.
                ocr_model_for_training, history_result = train_ocr_model(
                    model=ocr_model_for_training,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    char_indexer=char_indexer_for_training,
                    epochs=NUM_EPOCHS,
                    device=device,
                    progress_callback=update_progress_callback
                )
                st.session_state.training_history = history_result  # Save history to session state
                progress_message_placeholder.success("OCR model training finished!")
                update_progress_callback(1.0, "Training complete!")

                # os.path.dirname() returns "" for a bare file name and
                # os.makedirs("") raises, so only create a directory when the
                # save path actually has one.
                model_dir = os.path.dirname(MODEL_SAVE_PATH)
                if model_dir:
                    os.makedirs(model_dir, exist_ok=True)
                save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
                progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")

                # Drop the cached model so the next rerun reloads the freshly
                # saved weights instead of reusing the pre-training instance.
                get_and_load_ocr_model_cached.clear()

                ocr_model = ocr_model_for_training
                ocr_model.eval()  # Set to eval mode for subsequent predictions
            except Exception as e:
                # Reset the bar directly: going through update_progress_callback
                # here would replace the error text below with an info message.
                progress_bar_placeholder.progress(0)
                progress_message_placeholder.error(f"An error occurred during training: {e}")
                st.exception(e)  # This will print a detailed traceback in the Streamlit UI
# --- Model Loading Section ---
with tab3:
    st.write("---")
    st.subheader("2. Load Pre-trained Model")
    st.write("If you have a saved model, you can load it here instead of training.")
    if st.button("💾 Load Model"):
        if os.path.exists(MODEL_SAVE_PATH):
            try:
                loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
                loaded_model_instance.to(device)
                ocr_model = loaded_model_instance
                ocr_model.eval()
                # `ocr_model` is an ordinary script variable and is rebuilt
                # from the cached loader on the next rerun. Clearing that cache
                # makes the rerun read the file from disk again, so the model
                # loaded here is the one that stays in use.
                get_and_load_ocr_model_cached.clear()
                st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
                # Training history is only recorded by a training run, so
                # loading a model leaves the history plots unchanged.
            except Exception as e:
                st.error(f"Error loading model: {e}")
                st.exception(e)
        else:
            st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
# --- Training History Plots Section ---
with tab3:
    st.write("---")
    st.subheader("3. Training History Plots")

    recorded_history = st.session_state.training_history
    if not recorded_history:
        # Nothing has been stored by a training run in this session yet.
        st.info("Train the model first to see training history plots here.")
    else:
        epoch_total = len(recorded_history['train_loss'])
        # One row per epoch; CER and accuracy are scaled to percentages.
        history_df = pd.DataFrame({
            'Epoch': range(1, epoch_total + 1),
            'Train Loss': recorded_history['train_loss'],
            'Test Loss': recorded_history['test_loss'],
            'Test CER (%)': [cer * 100 for cer in recorded_history['test_cer']],
            'Test Exact Match Accuracy (%)': [acc * 100 for acc in recorded_history['test_exact_match_accuracy']]
        })
        per_epoch = history_df.set_index('Epoch')

        # (heading, plotted columns, caption) for each chart, in display order.
        chart_specs = [
            ("**Loss over Epochs**",
             ['Train Loss', 'Test Loss'],
             "Lower loss indicates better model performance."),
            ("**Character Error Rate (CER) over Epochs**",
             ['Test CER (%)'],
             "Lower CER indicates fewer character errors (0% is perfect)."),
            ("**Exact Match Accuracy over Epochs**",
             ['Test Exact Match Accuracy (%)'],
             "Higher exact match accuracy indicates more perfectly recognized names."),
            ("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**",
             ['Test CER (%)', 'Test Exact Match Accuracy (%)'],
             "CER should decrease, Accuracy should increase."),
        ]
        for heading, plotted_columns, caption_text in chart_specs:
            st.markdown(heading)
            st.line_chart(per_epoch[plotted_columns])
            st.caption(caption_text)
# --- Final Footer ---
# Rendered in the middle column of a 1:3:1 split, below the tabs.
col1, col2, col3 = st.columns([1, 3, 1])
with col2:
    # The copyright sign is a real UTF-8 character rather than the two-character
    # sequence left behind by a wrong decode.
    st.markdown("""
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
""")