File size: 10,135 Bytes
15dba6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# -*- coding: utf-8 -*-
# app.py

import os
# Disable Streamlit file watcher to prevent conflicts with PyTorch
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"

import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import traceback

# Import all necessary configuration values from config.py
from config import (
    IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
    TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
    MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
)

# Import classes and functions from data_handler_ocr.py and model_ocr.py
from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model


# --- Global Variables ---
# Module-level state shared by the sections below. The whole script is re-executed
# on every Streamlit rerun, so these start over each time (the model itself is
# kept across reruns by the st.cache_resource loader further down).
ocr_model = None          # replaced by the cached CRNN once it has been loaded
training_history = None   # only populated in a run where training completes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Streamlit App Setup ---
st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")

# Fixed: the emoji below was stored as mis-decoded UTF-8 bytes ("πŸ“").
st.title("📝 Handwritten Name Recognition (OCR) App")
# Consecutive lines form one Markdown paragraph; the previous blank lines between
# them split the first sentence into several disjoint paragraphs.
st.markdown("""
    This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
    Optical Character Recognition (OCR) on handwritten names. You can upload an image
    of a handwritten name for prediction or train a new model using the provided dataset.

    **Note:** Training a robust OCR model can be time-consuming.
""")

# --- Initialize CharIndexer ---
# Built once per script run from the vocabulary in config.py; its num_classes
# sizes the model's output layer.
char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)

# --- Model Loading / Initialization ---
@st.cache_resource  # keep one model instance across reruns instead of rebuilding it each time
def get_and_load_ocr_model_cached(num_classes, model_path):
    """
    Build a CRNN with ``num_classes`` outputs and try to restore saved weights.

    When ``model_path`` exists and its state dict loads cleanly, the returned
    model carries those weights; otherwise an untrained CRNN is returned.
    Progress and error messages are written to the Streamlit sidebar.
    """
    def build_fresh_model():
        # Single place for the architecture hyper-parameters used by this loader.
        return CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)

    model_instance = build_fresh_model()

    if not os.path.exists(model_path):
        st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
        return model_instance

    st.sidebar.info("Loading pre-trained OCR model...")
    try:
        # Deserialize onto the CPU; the caller moves the model to its target device.
        saved_state = torch.load(model_path, map_location=torch.device('cpu'))
        model_instance.load_state_dict(saved_state)
        st.sidebar.success("OCR model loaded successfully!")
    except Exception as e:
        st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
        # Start over from a clean instance rather than keep one that a failed
        # load may have partially modified.
        model_instance = build_fresh_model()

    return model_instance

# Fetch the model from the st.cache_resource loader, place it on the selected
# device and default it to inference mode. nn.Module.to() and nn.Module.eval()
# both return the module itself, which is what makes the chained form valid.
ocr_model = (
    get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
    .to(device)
    .eval()
)


# --- Sidebar for Model Training ---
st.sidebar.header("Train OCR Model")
st.sidebar.write("Click the button below to start training the OCR model.")

# Sidebar placeholders that the training progress callback writes into
# (bar first, then the text label beneath it).
progress_bar_sidebar, progress_label_sidebar = st.sidebar.progress(0), st.sidebar.empty()

def update_progress_callback_sidebar(value, text):
    """
    Mirror training progress in the sidebar progress bar and label.

    Args:
        value: Fraction of work completed, expected to lie in [0.0, 1.0].
        text: Status message shown below the progress bar.
    """
    # Round instead of truncating (int(0.29 * 100) == 28) and clamp to 0-100,
    # the range st.progress accepts for integers, so a slightly overshooting
    # value from the training loop cannot raise mid-training.
    percent = min(100, max(0, int(round(value * 100))))
    progress_bar_sidebar.progress(percent)
    progress_label_sidebar.text(text)

if st.sidebar.button("📊 Start Training"):
    # Clear progress left over from an earlier training attempt.
    progress_bar_sidebar.progress(0)
    progress_label_sidebar.empty()

    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
        st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
    elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
        # Training does not start in this branch (the loaders below need the test
        # split), so report that instead of implying only evaluation is affected.
        st.sidebar.error(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
                         "Training was not started because the test split is required for evaluation. "
                         "Please ensure all data paths are correct.")
    else:
        st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")

        try:
            train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
            st.sidebar.success("Training and Test DataFrames loaded successfully.")

            st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")

            train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
            st.sidebar.success("DataLoaders created successfully.")

            ocr_model.train()

            st.sidebar.write("Training in progress... This may take a while.")
            ocr_model, training_history = train_ocr_model(
                model=ocr_model,
                train_loader=train_loader,
                test_loader=test_loader,
                char_indexer=char_indexer,
                epochs=NUM_EPOCHS,
                device=device,
                progress_callback=update_progress_callback_sidebar
            )
            # Return to inference mode, matching how the model is used everywhere else.
            ocr_model.eval()
            st.sidebar.success("OCR model training finished!")
            update_progress_callback_sidebar(1.0, "Training complete!")

            # os.path.dirname() is "" for a bare filename and os.makedirs("") raises
            # FileNotFoundError, so only create a directory when the path has one.
            model_dir = os.path.dirname(MODEL_SAVE_PATH)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            save_ocr_model(ocr_model, MODEL_SAVE_PATH)
            st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")

            # The cached loader would otherwise keep handing out the instance it built
            # before training; clear it so the next rerun reloads the saved weights.
            get_and_load_ocr_model_cached.clear()

        except Exception as e:
            st.sidebar.error(f"An error occurred during training: {e}")
            st.exception(e)
            update_progress_callback_sidebar(0.0, "Training failed!")

# --- Sidebar for Model Loading ---
st.sidebar.header("Load Pre-trained Model")
st.sidebar.write("If you have a saved model, you can load it here instead of training.")

if st.sidebar.button("💾 Load Model"):
    if os.path.exists(MODEL_SAVE_PATH):
        try:
            # Build the same architecture that training uses (see
            # get_and_load_ocr_model_cached) so the checkpoint's parameter shapes match.
            loaded_model = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512,
                                rnn_hidden_size=256, rnn_num_layers=2)
            load_result = load_ocr_model(loaded_model, MODEL_SAVE_PATH)
            # NOTE(review): load_ocr_model appears to load weights in place (its return
            # value was previously ignored); if it returns a module, prefer that one.
            if isinstance(load_result, torch.nn.Module):
                loaded_model = load_result
            loaded_model.to(device)
            loaded_model.eval()

            # Fix: the loaded model used to be discarded; make it the prediction model.
            ocr_model = loaded_model
            # Drop the cached instance so later reruns re-read the weights from disk
            # instead of reverting to the model that existed before this load.
            get_and_load_ocr_model_cached.clear()

            st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
        except Exception as e:
            st.sidebar.error(f"Error loading model: {e}")
            st.exception(e)
    else:
        st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")

# --- Main Content: Prediction Section and Training History  ---

# Training-history charts; only shown in a script run where training just finished.
if training_history:
    st.subheader("Training History Plots")

    # One Epoch-indexed frame feeds every chart below.
    epoch_indexed = pd.DataFrame({
        'Epoch': range(1, len(training_history['train_loss']) + 1),
        'Train Loss': training_history['train_loss'],
        'Test Loss': training_history['test_loss'],
        'Test CER (%)': [rate * 100 for rate in training_history['test_cer']],
        'Test Exact Match Accuracy (%)': [score * 100 for score in training_history['test_exact_match_accuracy']],
    }).set_index('Epoch')

    # (heading, plotted columns, caption) for each chart, in display order.
    history_chart_specs = (
        ("**Loss over Epochs**",
         ['Train Loss', 'Test Loss'],
         "Lower loss indicates better model performance."),
        ("**Character Error Rate (CER) over Epochs**",
         ['Test CER (%)'],
         "Lower CER indicates fewer character errors (0% is perfect)."),
        ("**Exact Match Accuracy over Epochs**",
         ['Test Exact Match Accuracy (%)'],
         "Higher exact match accuracy indicates more perfectly recognized names."),
        ("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**",
         ['Test CER (%)', 'Test Exact Match Accuracy (%)'],
         "CER should decrease, Accuracy should increase."),
    )
    for chart_heading, chart_columns, chart_caption in history_chart_specs:
        st.markdown(chart_heading)
        st.line_chart(epoch_indexed[chart_columns])
        st.caption(chart_caption)

    st.write("---")  # visual separator below the charts


# Predict on a New Image

if ocr_model is None:
    st.warning("Please train or load a model before attempting prediction.")
else:
    # Fixed: the label emoji was stored as mis-decoded UTF-8 bytes.
    uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])

    if uploaded_file is not None:
        try:
            # Convert to single-channel grayscale ('L') before OCR preprocessing.
            image_pil = Image.open(uploaded_file).convert('L')
            st.image(image_pil, caption="Uploaded Image", use_container_width=True)
            st.write("---")
            st.write("Processing and Recognizing...")

            processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)

            ocr_model.eval()
            with torch.no_grad():
                output = ocr_model(processed_image_tensor)

            # ctc_greedy_decode yields a sequence of decoded strings; the first entry
            # belongs to the single uploaded image.
            predicted_text = ctc_greedy_decode(output, char_indexer)[0]

            if predicted_text:
                st.success(f"Recognized Text: **{predicted_text}**")
            else:
                # An empty decode used to render as a bare "****"; say so explicitly.
                st.warning("No text could be recognized in the uploaded image.")

        except Exception as e:
            st.error(f"Error processing image or recognizing text: {e}")
            st.info("💡 **Tips for best results:**\n"
                    "- Ensure the handwritten text is clear and on a clean background.\n"
                    "- Only include one name/word per image.\n"
                    "- The model is trained on specific characters. Unusual symbols might not be recognized.")
            st.exception(e)

# Page footer. Fixed: the copyright sign was stored as mis-decoded UTF-8 ("Β©").
st.markdown("""

    ---

    *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*

    """)