marianeft committed on
Commit
6388999
·
1 Parent(s): 04f0235

Updated UI, debug errors

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +237 -229
src/streamlit_app.py CHANGED
@@ -2,10 +2,6 @@
2
  # app.py
3
 
4
  import os
5
- # CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
6
- # This MUST be the first thing, before any other imports or Streamlit calls
7
- os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
8
-
9
  import streamlit as st
10
  import pandas as pd
11
  import numpy as np
@@ -43,264 +39,276 @@ except Exception as e:
43
  if 'training_history' not in st.session_state:
44
  st.session_state.training_history = None
45
 
46
- ocr_model = None # Will be initialized by @st.cache_resource
47
- char_indexer = None # Will be initialized below
 
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
 
50
  # --- Streamlit App Setup ---
51
  st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
52
 
53
- col1, col2, col3 = st.columns([1, 3, 1])
54
- with col2:
 
55
  st.title("📝 Handwritten Name Recognition (OCR) App")
56
 
57
  # --- Initialize CharIndexer ---
58
- # Wrap this in a try-except
59
  try:
60
  char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
61
  except Exception as e:
62
  st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}")
63
  st.stop()
64
 
65
- # --- Define Tabs ---
66
- col1, col2, col3 = st.columns([1, 3, 1])
67
- with col2:
68
- tab1, tab2, tab3 = st.tabs(["Project Description", "Predict Name", "Train & Evaluate"])
69
 
70
- # --- Tab 1: Project Description ---
71
- with tab1:
72
- # Use columns for centering content within the tab
73
- st.markdown("""
74
- This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
75
- Its core aim is to accurately convert handwritten text from images into digital format, providing a user-friendly interface via Streamlit.
76
-
77
- Here are some helpful resources related to this project:
78
- """)
79
- st.markdown("""
80
- **[📃 Project Documentation ](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
81
-
82
- **[🎞️ Demo Presentation ](https://drive.google.com/file/d/1j_S8cijxy6zxIn3cWg6tuLPNWB_7nwdI/view?usp=sharing)**
83
-
84
- **[📚 Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
85
-
86
- **[📂 Github Repository ](https://github.com/marianeft/handwritten_name_ocr_app)**
87
- """)
88
-
89
- # --- Tab 2: Predict Name (Main Content: Prediction Section) ---
90
- with tab2:
91
- st.markdown("Upload a clear image of a single handwritten name or word for recognition.")
92
-
93
- uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
94
- if uploaded_file is not None:
95
  try:
96
- image_pil = Image.open(uploaded_file).convert('L')
97
- st.image(image_pil, caption="Uploaded Image", use_container_width=True)
98
- st.write("---")
99
- st.write("Processing and Recognizing...")
100
-
101
- processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
102
-
103
- ocr_model.eval() # Ensure model is in eval mode for prediction
104
- with torch.no_grad():
105
- output = ocr_model(processed_image_tensor)
106
-
107
- predicted_texts = ctc_greedy_decode(output, char_indexer)
108
- predicted_text = predicted_texts[0]
109
-
110
- st.success(f"Recognized Text: **{predicted_text}**")
111
-
112
  except Exception as e:
113
- st.error(f"Error processing image or recognizing text: {e}")
114
- st.info("💡 **Tips for best results:**\n"
115
- "- Ensure the handwritten text is clear and on a clean background.\n"
116
- "- Only include one name/word per image.\n"
117
- "- The model is trained on specific characters. Unusual symbols might not be recognized.")
118
- st.exception(e)
119
-
120
- else:
121
- st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
122
-
123
- # --- Tab 3: Train & Evaluate ---
124
- with tab3:
125
- st.subheader("Model Training and Evaluation")
126
- st.markdown("Here you can train a new OCR model or load a pre-trained one.")
127
-
128
- # --- Model Loading / Initialization (Cached) ---
129
- @st.cache_resource # Cache the model to prevent reloading on every rerun
130
- def get_and_load_ocr_model_cached(num_classes, model_path):
131
- """
132
- Initializes the OCR model and attempts to load a pre-trained model.
133
- If no pre-trained model exists, a new model instance is returned.
134
- """
135
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
136
-
137
- if os.path.exists(model_path):
138
- st.info("Loading pre-trained OCR model...")
139
- try:
140
- # Load model to CPU first, then move to device
141
- model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
142
- st.success("OCR model loaded successfully!")
143
- except Exception as e:
144
- st.error(f"Error loading model from '{model_path}': {e}. A new model will be initialized.")
145
- # If loading fails, re-initialize an untrained model
146
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
147
- else:
148
- st.warning("No pre-trained OCR model found. Please train a model.")
149
-
150
- return model_instance
151
 
152
- # Wrap model loading in a try-except
153
- try:
154
- ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  ocr_model.to(device)
156
  ocr_model.eval() # Set model to evaluation mode for inference by default
157
- except Exception as e:
158
- st.error(f"FATAL ERROR: Could not initialize or load OCR model. Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
159
- st.stop()
160
-
161
 
 
 
 
162
 
163
- # --- Model Training Section ---
164
- st.subheader("Train OCR Model")
165
- st.write("Click the button below to start training the OCR model.")
166
 
167
- # Progress bar and label for training within this tab
168
- progress_container = st.empty() # Container for dynamic messages and progress
169
- progress_message_placeholder = st.empty()
170
- progress_bar_placeholder = st.progress(0)
171
-
172
- def update_progress_callback(value, text):
173
- progress_bar_placeholder.progress(int(value * 100))
174
- progress_message_placeholder.info(text) # Use info for dynamic messages
175
-
176
- if st.button("📊 Start Training"):
177
- progress_message_placeholder.empty() # Clear previous messages
178
- progress_bar_placeholder.progress(0) # Reset progress bar
179
-
180
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
181
- st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
182
- elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
183
- st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
184
- "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
185
- else:
186
- progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
187
-
188
- try:
189
- train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
190
- progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
191
- progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
192
- progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
193
- if len(test_df) == 0:
194
- progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
195
- if len(train_df) == 0:
196
- progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
197
-
198
- if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
199
- st.stop() # Added st.stop for critical data missing scenario
200
 
201
- char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
202
- progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
 
 
 
 
 
 
203
 
204
- train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
205
- progress_message_placeholder.success("DataLoaders created successfully.")
206
-
207
- # Re-initialize the model to train from scratch if the button is pressed
208
- # This ensures we don't continue training a potentially already trained model if it was loaded.
209
- ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
210
- ocr_model_for_training.to(device)
211
- ocr_model_for_training.train()
 
212
 
213
- progress_message_placeholder.write("Training in progress... This may take a while.")
214
-
215
- # Capture the model and history
216
- ocr_model_for_training, history_result = train_ocr_model(
217
- model=ocr_model_for_training,
218
- train_loader=train_loader,
219
- test_loader=test_loader,
220
- char_indexer=char_indexer_for_training,
221
- epochs=NUM_EPOCHS,
222
- device=device,
223
- progress_callback=update_progress_callback
224
- )
225
-
226
- st.session_state.training_history = history_result # Save history to session state
227
-
228
- progress_message_placeholder.success("OCR model training finished!")
229
- update_progress_callback(1.0, "Training complete!")
230
 
231
- os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
232
- save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
233
- progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- ocr_model = ocr_model_for_training
236
- ocr_model.eval() # Set to eval mode for subsequent predictions
 
 
 
 
237
 
238
- except Exception as e:
239
- progress_message_placeholder.error(f"An error occurred during training: {e}")
240
- st.exception(e) # This will print a detailed traceback in the Streamlit UI
241
- update_progress_callback(0.0, "Training failed!")
242
 
243
- st.write("---")
244
-
245
- # --- Model Loading Section ---
246
- st.subheader("Load Pre-trained Model")
247
- st.write("If you have a saved model, you can load it here instead of training.")
248
-
249
- if st.button("💾 Load Model"):
250
- if os.path.exists(MODEL_SAVE_PATH):
251
- try:
252
- loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
253
- load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
254
- loaded_model_instance.to(device)
255
- ocr_model = loaded_model_instance
256
- ocr_model.eval()
257
- st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
258
-
259
- # If a model is loaded, we can try to re-evaluate it to get history,
260
- # but typically history is stored from a training run.
261
- # For simplicity, we'll assume training history is only stored after a training run.
262
 
263
- except Exception as e:
264
- st.error(f"Error loading model: {e}")
265
- st.exception(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  else:
267
- st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
268
-
269
- st.write("---")
270
-
271
- # --- Training History Plots Section ---
272
- st.subheader("Training History Plots")
273
- if st.session_state.training_history:
274
- history_df = pd.DataFrame({
275
- 'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
276
- 'Train Loss': st.session_state.training_history['train_loss'],
277
- 'Test Loss': st.session_state.training_history['test_loss'],
278
- 'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
279
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
280
- })
281
-
282
- st.markdown("**Loss over Epochs**")
283
- st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
284
- st.caption("Lower loss indicates better model performance.")
285
-
286
- st.markdown("**Character Error Rate (CER) over Epochs**")
287
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
288
- st.caption("Lower CER indicates fewer character errors (0% is perfect).")
289
-
290
- st.markdown("**Exact Match Accuracy over Epochs**")
291
- st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
292
- st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
293
-
294
- st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
295
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
296
- st.caption("CER should decrease, Accuracy should increase.")
297
- else:
298
- st.info("Train the model first to see training history plots here.")
299
-
300
 
301
- # --- Final Footer ---
302
- col1, col2, col3 = st.columns([1, 3, 1])
303
- with col2:
304
  st.markdown("""
305
  ---
306
  *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
 
2
  # app.py
3
 
4
  import os
 
 
 
 
5
  import streamlit as st
6
  import pandas as pd
7
  import numpy as np
 
39
  if 'training_history' not in st.session_state:
40
  st.session_state.training_history = None
41
 
42
+ # Initialize ocr_model and char_indexer as None; they will be populated below
43
+ ocr_model = None
44
+ char_indexer = None
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
 
47
  # --- Streamlit App Setup ---
48
  st.set_page_config(layout="wide", page_title="Handwritten Name OCR App")
49
 
50
+ # Main Title and Description (Centered)
51
+ main_title_col1, main_title_col2, main_title_col3 = st.columns([1, 3, 1])
52
+ with main_title_col2:
53
  st.title("📝 Handwritten Name Recognition (OCR) App")
54
 
55
  # --- Initialize CharIndexer ---
 
56
  try:
57
  char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
58
  except Exception as e:
59
  st.error(f"FATAL ERROR: Could not initialize CharIndexer. Check config.py (VOCABULARY, BLANK_TOKEN_SYMBOL) and data_handler_ocr.py (CharIndexer class). Details: {e}")
60
  st.stop()
61
 
 
 
 
 
62
 
63
+ # --- Model Loading / Initialization (Cached and Global) ---
64
+ @st.cache_resource
65
+ def get_and_load_ocr_model_cached_internal(num_classes, model_path):
66
+ """
67
+ Initializes the OCR model and attempts to load a pre-trained model.
68
+ Returns (model_instance, message_type, message_text)
69
+ """
70
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
71
+ message_type = "warning"
72
+ message_text = "No pre-trained OCR model found. Please train a model using the 'Train & Evaluate' tab."
73
+
74
+ if os.path.exists(model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  try:
76
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
77
+ message_type = "success"
78
+ message_text = "OCR model loaded successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  except Exception as e:
80
+ message_type = "error"
81
+ message_text = f"Error loading model from '{model_path}' during app startup: {e}. A new model will be initialized."
82
+ # If loading fails, re-initialize to a fresh model to avoid issues.
83
+ model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
84
+
85
+ return model_instance, message_type, message_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # Display messages OUTSIDE the cached function
88
+ try:
89
+ loaded_model_instance, load_msg_type, load_msg_text = get_and_load_ocr_model_cached_internal(char_indexer.num_classes, MODEL_SAVE_PATH)
90
+
91
+ # Assign to global ocr_model
92
+ ocr_model = loaded_model_instance
93
+
94
+ # Display status messages as toasts
95
+ if load_msg_type == "success":
96
+ st.toast(load_msg_text, icon="✅")
97
+ elif load_msg_type == "warning":
98
+ st.toast(load_msg_text, icon="⚠️")
99
+ elif load_msg_type == "error":
100
+ st.toast(load_msg_text, icon="🚨")
101
+
102
+ if ocr_model is not None:
103
  ocr_model.to(device)
104
  ocr_model.eval() # Set model to evaluation mode for inference by default
105
+ else:
106
+ st.error("Model instance is None after cached load. Prediction will not be available.")
 
 
107
 
108
+ except Exception as e:
109
+ st.error(f"FATAL ERROR: Could not initialize or load OCR model during app startup (outer block). Check model_ocr.py (CRNN class) or your saved model file. Details: {e}")
110
+ st.stop()
111
 
 
 
 
112
 
113
+ # --- Define Tabs ---
114
+ tabs_col1, tabs_col2, tabs_col3 = st.columns([1, 3, 1])
115
+ with tabs_col2:
116
+ tab1, tab2, tab3 = st.tabs([" 🗨️ Project Description", " 🔎 Predict Name", " 📈 Train & Evaluate"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # --- Tab 1: Project Description ---
119
+ with tab1:
120
+ st.markdown("""
121
+ This application implements a Handwritten Name Recognition (OCR) system using a Convolutional Recurrent Neural Network (CRNN) built with PyTorch.
122
+ Its core aim is to accurately convert handwritten text from images into digital format, providing a user-friendly interface via Streamlit.
123
+
124
+ Here are some helpful resources related to this project:
125
+ """)
126
 
127
+ st.markdown("""
128
+ **[📃 Project Documentation ](https://drive.google.com/file/d/1HBrQT_UnzNLdEsouW9wMk4alAeCsQxZb/view?usp=sharing)**
129
+
130
+ **[🎞️ Demo Presentation ](https://drive.google.com/file/d/1j_S8cijxy6zxIn3cWg6tuLPNWB_7nwdI/view?usp=sharing)**
131
+
132
+ **[📚 Dataset (from Kaggle)](https://www.kaggle.com/datasets/landlord/handwriting-recognition)**
133
+
134
+ **[📂 Github Repository ](https://github.com/marianeft/handwritten_name_ocr_app)**
135
+ """)
136
 
137
+ # --- Tab 2: Predict Name (Main Content: Prediction Section) ---
138
+ with tab2:
139
+ st.markdown("Upload a clear image of a single handwritten name or word for recognition.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # Check the global ocr_model for prediction availability
142
+ if ocr_model is None:
143
+ st.warning("Model not loaded. Please train or load a model in the 'Train & Evaluate' tab before attempting prediction.")
144
+ else:
145
+ uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg", "jfif"])
146
+
147
+ if uploaded_file is not None:
148
+ try:
149
+ image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
150
+ st.image(image_pil, caption="Uploaded Image", use_container_width=True)
151
+ st.write("---")
152
+ st.write("Processing and Recognizing...")
153
+
154
+ processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
155
+
156
+ ocr_model.eval() # Ensure model is in eval mode for prediction
157
+ with torch.no_grad():
158
+ output = ocr_model(processed_image_tensor)
159
+
160
+ predicted_texts = ctc_greedy_decode(output, char_indexer)
161
+ predicted_text = predicted_texts[0]
162
+
163
+ st.success(f"Recognized Text: **{predicted_text}**")
164
+
165
+ except Exception as e:
166
+ st.error(f"Error processing image or recognizing text: {e}")
167
+ st.info("💡 **Tips for best results:**\n"
168
+ "- Ensure the handwritten text is clear and on a clean background.\n"
169
+ "- Only include one name/word per image.\n"
170
+ "- The model is trained on specific characters. Unusual symbols might not be recognized.")
171
+ st.exception(e) # Display full traceback for debugging
172
 
173
+ # --- Tab 3: Train & Evaluate ---
174
+ with tab3:
175
+
176
+ # --- Model Training Section ---
177
+ st.subheader("Train OCR Model")
178
+ st.write("Click the button below to start training the OCR model.")
179
 
180
+ # Progress bar and label for training within this tab
181
+ progress_message_placeholder = st.empty()
182
+ progress_bar_placeholder = st.progress(0)
 
183
 
184
+ def update_progress_callback(value, text):
185
+ progress_bar_placeholder.progress(int(value * 100))
186
+ progress_message_placeholder.info(text) # Use info for dynamic messages
187
+
188
+ if st.button("📊 Start Training"):
189
+ progress_message_placeholder.empty() # Clear previous messages
190
+ progress_bar_placeholder.progress(0) # Reset progress bar
191
+
192
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
193
+ st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
194
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
195
+ st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
196
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
197
+ else:
198
+ progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
 
 
 
 
199
 
200
+ try:
201
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
202
+ progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
203
+ progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
204
+ progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
205
+ if len(test_df) == 0:
206
+ progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
207
+ if len(train_df) == 0:
208
+ progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
209
+
210
+ if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
211
+ st.stop() # Added st.stop for critical data missing scenario
212
+
213
+ char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
214
+ progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
215
+
216
+ train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
217
+ progress_message_placeholder.success("DataLoaders created successfully.")
218
+
219
+ ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
220
+ ocr_model_for_training.to(device)
221
+ ocr_model_for_training.train() # Set to train mode before passing
222
+
223
+ progress_message_placeholder.write("Training in progress... This may take a while.")
224
+
225
+ ocr_model_for_training, history_result = train_ocr_model(
226
+ model=ocr_model_for_training, # Pass the local ocr_model_for_training instance
227
+ train_loader=train_loader,
228
+ test_loader=test_loader,
229
+ char_indexer=char_indexer_for_training,
230
+ epochs=NUM_EPOCHS,
231
+ device=device,
232
+ progress_callback=update_progress_callback
233
+ )
234
+
235
+ st.session_state.training_history = history_result # Save history to session state
236
+
237
+ progress_message_placeholder.success("OCR model training finished!")
238
+ update_progress_callback(1.0, "Training complete!")
239
+
240
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
241
+ save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH) # Save the now trained ocr_model_for_training
242
+ progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
243
+
244
+ # Crucial: Update the global ocr_model with the newly trained one
245
+ ocr_model = ocr_model_for_training
246
+ ocr_model.eval() # Set to eval mode for subsequent predictions
247
+
248
+ except Exception as e:
249
+ progress_message_placeholder.error(f"An error occurred during training: {e}")
250
+ st.exception(e) # This will print a detailed traceback in the Streamlit UI
251
+ update_progress_callback(0.0, "Training failed!")
252
+
253
+ st.write("---")
254
+
255
+ # --- Model Loading Section ---
256
+ st.subheader("Load Pre-trained Model")
257
+ st.write("If you have a saved model, you can load it here instead of training.")
258
+
259
+ if st.button("💾 Load Model"):
260
+ if os.path.exists(MODEL_SAVE_PATH):
261
+ try:
262
+ loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
263
+ load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
264
+ loaded_model_instance.to(device)
265
+ ocr_model = loaded_model_instance # Update global model reference
266
+ ocr_model.eval() # Set to eval mode after loading
267
+ st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
268
+
269
+ # For simplicity, training history is only populated after a training run.
270
+ # If you need to load history with the model, it would need to be saved separately.
271
+
272
+ except Exception as e:
273
+ st.error(f"Error loading model: {e}")
274
+ st.exception(e)
275
+ else:
276
+ st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
277
+
278
+ st.write("---")
279
+
280
+ # --- Training History Plots Section ---
281
+ st.subheader("Training History Plots")
282
+ if st.session_state.training_history: # Check if history exists in session state
283
+ history_df = pd.DataFrame({
284
+ 'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
285
+ 'Train Loss': st.session_state.training_history['train_loss'],
286
+ 'Test Loss': st.session_state.training_history['test_loss'],
287
+ 'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
288
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
289
+ })
290
+
291
+ st.markdown("**Loss over Epochs**")
292
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
293
+ st.caption("Lower loss indicates better model performance.")
294
+
295
+ st.markdown("**Character Error Rate (CER) over Epochs**")
296
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
297
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
298
+
299
+ st.markdown("**Exact Match Accuracy over Epochs**")
300
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
301
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
302
+
303
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
304
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
305
+ st.caption("CER should decrease, Accuracy should increase.")
306
  else:
307
+ st.info("Train the model first to see training history plots here.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
+ # --- Final Footer (Centered) ---
310
+ footer_col1, footer_col2, footer_col3 = st.columns([1, 3, 1])
311
+ with footer_col2:
312
  st.markdown("""
313
  ---
314
  *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*