# -*- coding: utf-8 -*-
# app.py

import os
import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

# Import all necessary configuration values from config.py
# Wrap this import in a try-except
try:
    from config import (
        IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
        TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
        MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
    )
except Exception as e:
    st.error(f"FATAL ERROR: Could not load config.py. Please check your config.py file for errors. Details: {e}")
    st.stop() # Stop the app if config fails to load

# Import classes and functions from data_handler_ocr.py and model_ocr.py
# Wrap these imports in a try-except
try:
    from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
    from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
    from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
except Exception as e:
    st.error(f"FATAL ERROR: Could not load core modules (data_handler_ocr.py, model_ocr.py, utils_ocr.py). Please check these files for errors. Details: {e}")
    st.stop() # Stop the app if core modules fail to load


# --- Global Variables ---
# Initialize training_history in Streamlit's session state to persist across reruns
if 'training_history' not in st.session_state:
    st.session_state.training_history = None

# Initialize ocr_model and char_indexer as None; they will be populated below
ocr_model = None 
char_indexer = None 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")

# Main Title and Description (Centered)
main_title_col1, main_title_col2, main_title_col3 = st.columns([1, 3, 1])
with main_title_col2:
    st.title("πŸ“ Handwritten Name Recognition (OCR) App")

# --- Initialize CharIndexer ---
try:
    char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
except Exception as e:
    st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}")
    st.stop()


# --- Model Loading / Initialization (Cached and Global) ---
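# @st.cache_resource keeps the constructed model object in memory across Streamlit reruns,
# so the checkpoint is only read from disk once per session rather than on every interaction.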
@st.cache_resource
def get_and_load_ocr_model_cached_internal(num_classes, model_path):
    """
    Initializes the OCR model and attempts to load a pre-trained model.
    Returns (model_instance, message_type, message_text)
    """
    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
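    # NOTE: these CRNN hyperparameters must match the ones the checkpoint at model_path was
    # trained with; otherwise load_state_dict() below will fail with a size-mismatch error.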
    message_type = "warning"
    message_text = "No pre-trained OCR model found. Please train a model using the 'Train & Evaluate' tab."
    
    if os.path.exists(model_path):
        try:
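            # map_location='cpu' lets a checkpoint saved on a GPU machine load on CPU-only hosts;
            # the model is moved to the active device further below.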
            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            message_type = "success"
            message_text = "OCR model loaded successfully!"
        except Exception as e:
            message_type = "error"
            message_text = f"Error loading model from '{model_path}' during app startup: {e}. A new model will be initialized."
            # If loading fails, re-initialize to a fresh model to avoid issues.
            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    
    return model_instance, message_type, message_text

# Display status messages OUTSIDE the cached function: on a cache hit Streamlit skips the function
# body entirely, so any st.toast/st.error calls placed inside it would only appear on the first run.
try:
    loaded_model_instance, load_msg_type, load_msg_text = get_and_load_ocr_model_cached_internal(char_indexer.num_classes, MODEL_SAVE_PATH)
    
    # Assign to global ocr_model
    ocr_model = loaded_model_instance 

    # Display status messages as toasts
    if load_msg_type == "success":
        st.toast(load_msg_text, icon="✅")
    elif load_msg_type == "warning":
        st.toast(load_msg_text, icon="⚠️")
    elif load_msg_type == "error":
        st.toast(load_msg_text, icon="🚨") 
 
    if ocr_model is not None:
        ocr_model.to(device)
        ocr_model.eval() # Set model to evaluation mode for inference by default
    else:
        st.error("Model instance is None after cached load. Prediction will not be available.")

except Exception as e:
    st.error(f"FATAL ERROR: Could not initialize or load OCR model during app startup (outer block). Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
    st.stop()


# --- Define Tabs ---
tabs_col1, tabs_col2, tabs_col3 = st.columns([1, 3, 1])
with tabs_col2:
    tab1, tab2, tab3 = st.tabs([" 🗨️ Project Description", " 🔎 Predict Name", " 📈 Train & Evaluate"])

# --- Tab 1: Project Description ---
with tab1:
    st.markdown("""
        This application implements a handwritten name recognition (OCR) system built on a Convolutional Recurrent Neural Network (CRNN) in PyTorch.
        It converts images of handwritten text into digital text and wraps the whole workflow in a user-friendly Streamlit interface.
        
        Here are some helpful resources related to this project:
    """)

    st.markdown("""
        **[📃 Project Documentation](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**

        **[🎞️ Demo Presentation](https://drive.google.com/file/d/1j_S8cijxy6zxIn3cWg6tuLPNWB_7nwdI/view?usp=sharing)**

        **[📚 Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**

        **[📂 GitHub Repository](https://github.com/marianeft/handwritten_name_ocr_app)**
    """)

# --- Tab 2: Predict Name (Main Content: Prediction Section) ---
with tab2:
        st.markdown("Upload a clear image of a single handwritten name or word for recognition.")

        # Check the global ocr_model for prediction availability
        if ocr_model is None:
            st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
        else:
            uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])

            if uploaded_file is not None:
                try:
                    image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
                    st.image(image_pil, caption="Uploaded Image", use_container_width=True) 
                    st.write("---")
                    st.write("Processing and Recognizing...")

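                    # preprocess_user_image_for_ocr (utils_ocr.py) is expected to binarize the image,
                    # resize it to IMG_HEIGHT while keeping the aspect ratio, normalize it, and add a
                    # batch dimension so the tensor matches the CRNN's expected input shape.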
                    processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
                    
                    ocr_model.eval() # Ensure model is in eval mode for prediction
                    with torch.no_grad():
                        output = ocr_model(processed_image_tensor)
                    
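                    # ctc_greedy_decode (model_ocr.py) is expected to perform standard greedy CTC
                    # decoding: take the argmax at each time step, collapse consecutive repeats,
                    # and strip the blank token to recover the predicted string.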
                    predicted_texts = ctc_greedy_decode(output, char_indexer)
                    predicted_text = predicted_texts[0]

                    st.success(f"Recognized Text: **{predicted_text}**")

                except Exception as e:
                    st.error(f"Error processing image or recognizing text: {e}")
                    st.info("πŸ’‘ **Tips for best results:**\n"
                            "- Ensure the handwritten text is clear and on a clean background.\n"
                            "- Only include one name/word per image.\n"
                            "- The model is trained on specific characters. Unusual symbols might not be recognized.")
                    st.exception(e) # Display full traceback for debugging

# --- Tab 3: Train & Evaluate ---
with tab3:
        
        # --- Model Training Section ---
        st.subheader("Train OCR Model")
        st.write("Click the button below to start training the OCR model.")

        # Progress bar and label for training within this tab
        progress_message_placeholder = st.empty()
        progress_bar_placeholder = st.progress(0)
        
        def update_progress_callback(value, text):
            progress_bar_placeholder.progress(int(value * 100))
            progress_message_placeholder.info(text) # Use info for dynamic messages
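        # The callback above is handed to train_ocr_model so the training loop can push per-epoch
        # progress (a 0-1 fraction plus a status string) back into the Streamlit UI.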

        if st.button("📊 Start Training"):
            progress_message_placeholder.empty() # Clear previous messages
            progress_bar_placeholder.progress(0) # Reset progress bar

            if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
                st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
            elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
                st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
                        "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
            else:
                progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
                
                try:
                    train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
                    progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
                    progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
                    progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
                    if len(test_df) == 0:
                        progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
                    if len(train_df) == 0:
                        progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
                    
                    if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
                        st.stop() # Added st.stop for critical data missing scenario

                    char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
                    progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")

                    train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
                    progress_message_placeholder.success("DataLoaders created successfully.")
                    
                    ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                    ocr_model_for_training.to(device)
                    ocr_model_for_training.train() # Set to train mode before passing

                    progress_message_placeholder.write("Training in progress... This may take a while.")
                    
                    ocr_model_for_training, history_result = train_ocr_model(
                        model=ocr_model_for_training, # Pass the local ocr_model_for_training instance
                        train_loader=train_loader,
                        test_loader=test_loader,
                        char_indexer=char_indexer_for_training,
                        epochs=NUM_EPOCHS,
                        device=device,
                        progress_callback=update_progress_callback
                    )
                    
                    st.session_state.training_history = history_result # Save history to session state
                    
                    progress_message_placeholder.success("OCR model training finished!")
                    update_progress_callback(1.0, "Training complete!")

                    # Guard against MODEL_SAVE_PATH having no directory component (os.makedirs('') would raise).
                    os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)
                    save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH) # Save the now trained ocr_model_for_training
                    progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")

                    # Crucial: update the global ocr_model with the newly trained one, and clear the
                    # cached loader so the next rerun reloads the freshly saved checkpoint instead of
                    # returning the stale cached instance.
                    get_and_load_ocr_model_cached_internal.clear()
                    ocr_model = ocr_model_for_training
                    ocr_model.eval() # Set to eval mode for subsequent predictions

                except Exception as e:
                    progress_message_placeholder.error(f"An error occurred during training: {e}")
                    st.exception(e) # This will print a detailed traceback in the Streamlit UI
                    update_progress_callback(0.0, "Training failed!")
            
        st.write("---")

        # --- Model Loading Section ---
        st.subheader("Load Pre-trained Model")
        st.write("If you have a saved model, you can load it here instead of training.")

        if st.button("💾 Load Model"):
            if os.path.exists(MODEL_SAVE_PATH):
                try:
                    loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
                    load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
                    loaded_model_instance.to(device)
                    ocr_model = loaded_model_instance # Update global model reference
                    ocr_model.eval() # Set to eval mode after loading
                    st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
                    
                    # For simplicity, training history is only populated after a training run.
                    # If you need to load history with the model, it would need to be saved separately.
                    
                except Exception as e:
                    st.error(f"Error loading model: {e}")
                    st.exception(e)
            else:
                st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")

        st.write("---")

        # --- Training History Plots Section ---
        st.subheader("Training History Plots")
        if st.session_state.training_history: # Check if history exists in session state
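            # train_ocr_model (model_ocr.py) is assumed to return per-epoch lists under the keys
            # 'train_loss', 'test_loss', 'test_cer', and 'test_exact_match_accuracy'.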
            history_df = pd.DataFrame({
                'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
                'Train Loss': st.session_state.training_history['train_loss'],
                'Test Loss': st.session_state.training_history['test_loss'],
                'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
                'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
            })

            st.markdown("**Loss over Epochs**")
            st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
            st.caption("Lower loss indicates better model performance.")

            st.markdown("**Character Error Rate (CER) over Epochs**")
            st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
            st.caption("Lower CER indicates fewer character errors (0% is perfect).")

            st.markdown("**Exact Match Accuracy over Epochs**")
            st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
            st.caption("Higher exact match accuracy indicates more perfectly recognized names.")

            st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
            st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
            st.caption("CER should decrease, Accuracy should increase.")
        else:
            st.info("Train the model first to see training history plots here.")

# --- Final Footer (Centered) ---
footer_col1, footer_col2, footer_col3 = st.columns([1, 3, 1])
with footer_col2:
    st.markdown("""
        ---
        *Built using Streamlit, PyTorch, OpenCV, and EditDistance. © 2025 by MFT*
        """)