marianeft commited on
Commit
af3f1e7
Β·
1 Parent(s): ab2bdda
Files changed (1) hide show
  1. src/streamlit_app.py +279 -197
src/streamlit_app.py CHANGED
@@ -2,7 +2,8 @@
2
  # app.py
3
 
4
  import os
5
- # Disable Streamlit file watcher to prevent conflicts with PyTorch
 
6
  os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
7
 
8
  import streamlit as st
@@ -12,216 +13,297 @@ from PIL import Image
12
  import torch
13
  import torch.nn.functional as F
14
  import torchvision.transforms as transforms
15
- import traceback
16
 
17
  # Import all necessary configuration values from config.py
18
- from config import (
19
- IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
20
- TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
21
- MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
22
- )
 
 
 
 
 
23
 
24
  # Import classes and functions from data_handler_ocr.py and model_ocr.py
25
- from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
26
- from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
27
- from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
 
 
 
 
 
28
 
29
 
30
  # --- Global Variables ---
31
- ocr_model = None
32
- char_indexer = None
33
- training_history = None
 
 
 
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
 
36
  # --- Streamlit App Setup ---
37
- st.set_page_config(layout="wide", page_title="Handwritten Name OCR App",)
38
-
39
 
40
- st.title("πŸ“ Handwritten Name Recognition (OCR) App")
41
- st.markdown("""
42
- This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
43
- Optical Character Recognition (OCR) on handwritten names. You can upload an image
44
- of a handwritten name for prediction or train a new model using the provided dataset.
45
-
46
- **Note:** Training a robust OCR model can be time-consuming.
47
- """)
48
 
49
  # --- Initialize CharIndexer ---
50
- # This initializes char_indexer once when the script starts
51
- char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
52
-
53
- # --- Model Loading / Initialization ---
54
- @st.cache_resource # Cache the model to prevent reloading on every rerun
55
- def get_and_load_ocr_model_cached(num_classes, model_path):
56
- """
57
- Initializes the OCR model and attempts to load a pre-trained model.
58
- If no pre-trained model exists, a new model instance is returned.
59
- """
60
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
61
-
62
- if os.path.exists(model_path):
63
- st.sidebar.info("Loading pre-trained OCR model...")
64
- try:
65
- # Load model to CPU first, then move to device
66
- model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
67
- st.sidebar.success("OCR model loaded successfully!")
68
- except Exception as e:
69
- st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
70
- # If loading fails, re-initialize an untrained model
71
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
72
- else:
73
- st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
74
-
75
- return model_instance
76
-
77
- # Get the model instance and assign it to the global 'ocr_model'
78
- ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
79
- # Ensure the model is on the correct device for inference
80
- ocr_model.to(device)
81
- ocr_model.eval() # Set model to evaluation mode for inference by default
82
-
83
-
84
- # --- Sidebar for Model Training ---
85
- st.sidebar.header("Train OCR Model")
86
- st.sidebar.write("Click the button below to start training the OCR model.")
87
-
88
- # Progress bar and label for training in the sidebar
89
- progress_bar_sidebar = st.sidebar.progress(0)
90
- progress_label_sidebar = st.sidebar.empty()
91
-
92
- def update_progress_callback_sidebar(value, text):
93
- progress_bar_sidebar.progress(int(value * 100))
94
- progress_label_sidebar.text(text)
95
-
96
- if st.sidebar.button("πŸ“Š Start Training"):
97
- progress_bar_sidebar.progress(0)
98
- progress_label_sidebar.empty()
99
- st.empty()
100
-
101
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
102
- st.sidebar.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
103
- elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
104
- st.sidebar.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
105
- "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
106
- else:
107
- st.sidebar.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
108
 
109
- try:
110
- train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
111
- st.sidebar.success("Training and Test DataFrames loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- st.sidebar.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
 
 
 
 
 
 
 
114
 
115
- train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, BATCH_SIZE)
116
- st.sidebar.success("DataLoaders created successfully.")
117
-
118
- ocr_model.train()
119
-
120
- st.sidebar.write("Training in progress... This may take a while.")
121
- ocr_model, training_history = train_ocr_model(
122
- model=ocr_model,
123
- train_loader=train_loader,
124
- test_loader=test_loader,
125
- char_indexer=char_indexer,
126
- epochs=NUM_EPOCHS,
127
- device=device,
128
- progress_callback=update_progress_callback_sidebar
129
- )
130
- st.sidebar.success("OCR model training finished!")
131
- update_progress_callback_sidebar(1.0, "Training complete!")
132
-
133
- os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
134
- save_ocr_model(ocr_model, MODEL_SAVE_PATH)
135
- st.sidebar.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
136
-
137
- except Exception as e:
138
- st.sidebar.error(f"An error occurred during training: {e}")
139
- st.exception(e)
140
- update_progress_callback_sidebar(0.0, "Training failed!")
141
-
142
- # --- Sidebar for Model Loading ---
143
- st.sidebar.header("Load Pre-trained Model")
144
- st.sidebar.write("If you have a saved model, you can load it here instead of training.")
145
-
146
- if st.sidebar.button("πŸ’Ύ Load Model"):
147
- if os.path.exists(MODEL_SAVE_PATH):
148
- try:
149
- loaded_model = CRNN(num_classes=char_indexer.num_classes)
150
- load_ocr_model(loaded_model, MODEL_SAVE_PATH)
151
- loaded_model.to(device)
152
-
153
- st.sidebar.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
154
- except Exception as e:
155
- st.sidebar.error(f"Error loading model: {e}")
156
- st.exception(e)
157
- else:
158
- st.sidebar.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
159
-
160
- # --- Main Content: Prediction Section and Training History ---
161
-
162
- # Display training history chart
163
- if training_history:
164
- st.subheader("Training History Plots")
165
- history_df = pd.DataFrame({
166
- 'Epoch': range(1, len(training_history['train_loss']) + 1),
167
- 'Train Loss': training_history['train_loss'],
168
- 'Test Loss': training_history['test_loss'],
169
- 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
170
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
171
- })
172
-
173
- st.markdown("**Loss over Epochs**")
174
- st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
175
- st.caption("Lower loss indicates better model performance.")
176
-
177
- st.markdown("**Character Error Rate (CER) over Epochs**")
178
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
179
- st.caption("Lower CER indicates fewer character errors (0% is perfect).")
180
-
181
- st.markdown("**Exact Match Accuracy over Epochs**")
182
- st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
183
- st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
184
-
185
- st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
186
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
187
- st.caption("CER should decrease, Accuracy should increase.")
188
- st.write("---") # Separator after charts
189
-
190
-
191
- # Predict on a New Image
192
-
193
- if ocr_model is None:
194
- st.warning("Please train or load a model before attempting prediction.")
195
- else:
196
- uploaded_file = st.file_uploader("πŸ–ΌοΈ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
197
-
198
- if uploaded_file is not None:
199
- try:
200
- image_pil = Image.open(uploaded_file).convert('L')
201
- st.image(image_pil, caption="Uploaded Image", use_container_width=True)
202
- st.write("---")
203
- st.write("Processing and Recognizing...")
204
-
205
- processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
206
-
207
- ocr_model.eval()
208
- with torch.no_grad():
209
- output = ocr_model(processed_image_tensor)
210
 
211
- predicted_texts = ctc_greedy_decode(output, char_indexer)
212
- predicted_text = predicted_texts[0]
213
-
214
- st.success(f"Recognized Text: **{predicted_text}**")
215
-
216
- except Exception as e:
217
- st.error(f"Error processing image or recognizing text: {e}")
218
- st.info("πŸ’‘ **Tips for best results:**\n"
219
- "- Ensure the handwritten text is clear and on a clean background.\n"
220
- "- Only include one name/word per image.\n"
221
- "- The model is trained on specific characters. Unusual symbols might not be recognized.")
222
- st.exception(e)
223
-
224
- st.markdown("""
225
- ---
226
- *Built using Streamlit, PyTorch, OpenCV, and EditDistance Β©2025 by MFT*
227
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  # app.py
3
 
4
  import os
5
+ # CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
6
+ # This MUST be the first thing, before any other imports or Streamlit calls
7
  os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
8
 
9
  import streamlit as st
 
13
  import torch
14
  import torch.nn.functional as F
15
  import torchvision.transforms as transforms
16
+ import traceback # Ensure this is imported
17
 
18
  # Import all necessary configuration values from config.py
19
+ # Wrap this import in a try-except
20
+ try:
21
+ from config import (
22
+ IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
23
+ TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
24
+ MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
25
+ )
26
+ except Exception as e:
27
+ st.error(f"FATAL ERROR: Could not load config.py. Please check your config.py file for errors. Details: {e}")
28
+ st.stop() # Stop the app if config fails to load
29
 
30
  # Import classes and functions from data_handler_ocr.py and model_ocr.py
31
+ # Wrap these imports in a try-except
32
+ try:
33
+ from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
34
+ from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
35
+ from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model
36
+ except Exception as e:
37
+ st.error(f"FATAL ERROR: Could not load core modules (data_handler_ocr.py, model_ocr.py, utils_ocr.py). Please check these files for errors. Details: {e}")
38
+ st.stop() # Stop the app if core modules fail to load
39
 
40
 
41
  # --- Global Variables ---
42
+ # Initialize training_history in Streamlit's session state to persist across reruns
43
+ if 'training_history' not in st.session_state:
44
+ st.session_state.training_history = None
45
+
46
+ ocr_model = None # Will be initialized by @st.cache_resource
47
+ char_indexer = None # Will be initialized below
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
50
  # --- Streamlit App Setup ---
51
+ st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
 
52
 
53
+ col1, col2, col3 = st.columns([1, 3, 1])
54
+ with col2:
55
+ st.title("πŸ“ Handwritten Name Recognition (OCR) App")
 
 
 
 
 
56
 
57
  # --- Initialize CharIndexer ---
58
+ # Wrap this in a try-except
59
+ try:
60
+ char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
61
+ except Exception as e:
62
+ st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}")
63
+ st.stop()
64
+
65
+ # --- Define Tabs ---
66
+ col1, col2, col3 = st.columns([1, 3, 1])
67
+ with col2:
68
+ tab1, tab2, tab3 = st.tabs(["Project Description", "Predict Name", "Train & Evaluate"])
69
+
70
+ # --- Tab 1: Project Description ---
71
+ with tab1:
72
+ # Use columns for centering content within the tab
73
+ st.markdown("""
74
+ This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
75
+ Its core aim is to accurately convert handwritten text from images into digital format, providing a user-friendly interface via Streamlit.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ Here are some helpful resources related to this project:
78
+ """)
79
+ st.markdown("""
80
+ **[πŸ“ƒ Project Documentation ](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
81
+
82
+ **[🎞️ Demo Presentation ](https://drive.google.com/drive/folders/1rOmwyTJkDCsU-Wuh-_CzvQ9sdb_ci_kX?usp=sharing)**
83
+
84
+ **[πŸ“š Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
85
+
86
+ **[πŸ“‚ Github Repository ](https://github.com/marianeft/handwritten_name_ocr_app)**
87
+ """)
88
+
89
+ # --- Tab 2: Predict Name (Main Content: Prediction Section) ---
90
+ with tab2:
91
+ st.header("Predict on a New Image")
92
+ st.markdown("Upload a clear image of a single handwritten name or word for recognition.")
93
+
94
+ if ocr_model is None:
95
+ st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
96
+ else:
97
+ uploaded_file = st.file_uploader("πŸ–ΌοΈ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
98
+
99
+ if uploaded_file is not None:
100
+ try:
101
+ image_pil = Image.open(uploaded_file).convert('L')
102
+ st.image(image_pil, caption="Uploaded Image", use_container_width=True)
103
+ st.write("---")
104
+ st.write("Processing and Recognizing...")
105
+
106
+ processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
107
+
108
+ ocr_model.eval() # Ensure model is in eval mode for prediction
109
+ with torch.no_grad():
110
+ output = ocr_model(processed_image_tensor)
111
+
112
+ predicted_texts = ctc_greedy_decode(output, char_indexer)
113
+ predicted_text = predicted_texts[0]
114
+
115
+ st.success(f"Recognized Text: **{predicted_text}**")
116
+
117
+ except Exception as e:
118
+ st.error(f"Error processing image or recognizing text: {e}")
119
+ st.info("πŸ’‘ **Tips for best results:**\n"
120
+ "- Ensure the handwritten text is clear and on a clean background.\n"
121
+ "- Only include one name/word per image.\n"
122
+ "- The model is trained on specific characters. Unusual symbols might not be recognized.")
123
+ st.exception(e)
124
+
125
+ # --- Tab 3: Train & Evaluate ---
126
+ with tab3:
127
+ st.header("Model Training and Evaluation")
128
+ st.markdown("Here you can train a new OCR model or load a pre-trained one.")
129
+
130
+ # --- Model Loading / Initialization (Cached) ---
131
+ @st.cache_resource # Cache the model to prevent reloading on every rerun
132
+ def get_and_load_ocr_model_cached(num_classes, model_path):
133
+ """
134
+ Initializes the OCR model and attempts to load a pre-trained model.
135
+ If no pre-trained model exists, a new model instance is returned.
136
+ """
137
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
138
+
139
+ if os.path.exists(model_path):
140
+ st.info("Loading pre-trained OCR model...")
141
+ try:
142
+ # Load model to CPU first, then move to device
143
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
144
+ st.success("OCR model loaded successfully!")
145
+ except Exception as e:
146
+ st.error(f"Error loading model from '{model_path}': {e}. A new model will be initialized.")
147
+ # If loading fails, re-initialize an untrained model
148
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
149
+ else:
150
+ st.warning("No pre-trained OCR model found. Please train a model.")
151
+
152
+ return model_instance
153
 
154
+ # Wrap model loading in a try-except
155
+ try:
156
+ ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
157
+ ocr_model.to(device)
158
+ ocr_model.eval() # Set model to evaluation mode for inference by default
159
+ except Exception as e:
160
+ st.error(f"FATAL ERROR: Could not initialize or load OCR model. Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
161
+ st.stop()
162
 
163
+
164
+
165
+ # --- Model Training Section ---
166
+ st.subheader("1. Train OCR Model")
167
+ st.write("Click the button below to start training the OCR model.")
168
+
169
+ # Progress bar and label for training within this tab
170
+ progress_container = st.empty() # Container for dynamic messages and progress
171
+ progress_message_placeholder = st.empty()
172
+ progress_bar_placeholder = st.progress(0)
173
+
174
+ def update_progress_callback(value, text):
175
+ progress_bar_placeholder.progress(int(value * 100))
176
+ progress_message_placeholder.info(text) # Use info for dynamic messages
177
+
178
+ if st.button("πŸ“Š Start Training"):
179
+ progress_message_placeholder.empty() # Clear previous messages
180
+ progress_bar_placeholder.progress(0) # Reset progress bar
181
+
182
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
183
+ st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
184
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
185
+ st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
186
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
187
+ else:
188
+ progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
189
+
190
+ try:
191
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
192
+ progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
193
+ progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
194
+ progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
195
+ if len(test_df) == 0:
196
+ progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
197
+ if len(train_df) == 0:
198
+ progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
199
+
200
+ if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
201
+ st.stop() # Added st.stop for critical data missing scenario
202
+
203
+ char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
204
+ progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
205
+
206
+ train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
207
+ progress_message_placeholder.success("DataLoaders created successfully.")
208
+
209
+ # Re-initialize the model to train from scratch if the button is pressed
210
+ # This ensures we don't continue training a potentially already trained model if it was loaded.
211
+ ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
212
+ ocr_model_for_training.to(device)
213
+ ocr_model_for_training.train()
214
+
215
+ progress_message_placeholder.write("Training in progress... This may take a while.")
216
+
217
+ # Capture the model and history
218
+ ocr_model_for_training, history_result = train_ocr_model(
219
+ model=ocr_model_for_training,
220
+ train_loader=train_loader,
221
+ test_loader=test_loader,
222
+ char_indexer=char_indexer_for_training,
223
+ epochs=NUM_EPOCHS,
224
+ device=device,
225
+ progress_callback=update_progress_callback
226
+ )
227
+
228
+ st.session_state.training_history = history_result # Save history to session state
229
+
230
+ progress_message_placeholder.success("OCR model training finished!")
231
+ update_progress_callback(1.0, "Training complete!")
232
+
233
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
234
+ save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
235
+ progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
236
+
237
+ ocr_model = ocr_model_for_training
238
+ ocr_model.eval() # Set to eval mode for subsequent predictions
239
+
240
+ except Exception as e:
241
+ progress_message_placeholder.error(f"An error occurred during training: {e}")
242
+ st.exception(e) # This will print a detailed traceback in the Streamlit UI
243
+ update_progress_callback(0.0, "Training failed!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ st.write("---")
246
+
247
+ # --- Model Loading Section ---
248
+ st.subheader("2. Load Pre-trained Model")
249
+ st.write("If you have a saved model, you can load it here instead of training.")
250
+
251
+ if st.button("πŸ’Ύ Load Model"):
252
+ if os.path.exists(MODEL_SAVE_PATH):
253
+ try:
254
+ loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
255
+ load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
256
+ loaded_model_instance.to(device)
257
+ ocr_model = loaded_model_instance
258
+ ocr_model.eval()
259
+ st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
260
+
261
+ # If a model is loaded, we can try to re-evaluate it to get history,
262
+ # but typically history is stored from a training run.
263
+ # For simplicity, we'll assume training history is only stored after a training run.
264
+
265
+ except Exception as e:
266
+ st.error(f"Error loading model: {e}")
267
+ st.exception(e)
268
+ else:
269
+ st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
270
+
271
+ st.write("---")
272
+
273
+ # --- Training History Plots Section ---
274
+ st.subheader("3. Training History Plots")
275
+ if st.session_state.training_history: # Check if history exists in session state
276
+ history_df = pd.DataFrame({
277
+ 'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
278
+ 'Train Loss': st.session_state.training_history['train_loss'],
279
+ 'Test Loss': st.session_state.training_history['test_loss'],
280
+ 'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
281
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
282
+ })
283
+
284
+ st.markdown("**Loss over Epochs**")
285
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
286
+ st.caption("Lower loss indicates better model performance.")
287
+
288
+ st.markdown("**Character Error Rate (CER) over Epochs**")
289
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
290
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
291
+
292
+ st.markdown("**Exact Match Accuracy over Epochs**")
293
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
294
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
295
+
296
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
297
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
298
+ st.caption("CER should decrease, Accuracy should increase.")
299
+ else:
300
+ st.info("Train the model first to see training history plots here.")
301
+
302
+
303
+ # --- Final Footer ---
304
+ col1, col2, col3 = st.columns([1, 3, 1])
305
+ with col2:
306
+ st.markdown("""
307
+ ---
308
+ *Built using Streamlit, PyTorch, OpenCV, and EditDistance Β©2025 by MFT*
309
+ """)