marianeft committed on
Commit
0385397
·
1 Parent(s): af3f1e7
Files changed (1) hide show
  1. src/streamlit_app.py +126 -126
src/streamlit_app.py CHANGED
@@ -162,142 +162,142 @@ with tab3:
162
 
163
 
164
 
165
- # --- Model Training Section ---
166
- st.subheader("1. Train OCR Model")
167
- st.write("Click the button below to start training the OCR model.")
168
-
169
- # Progress bar and label for training within this tab
170
- progress_container = st.empty() # Container for dynamic messages and progress
171
- progress_message_placeholder = st.empty()
172
- progress_bar_placeholder = st.progress(0)
173
-
174
- def update_progress_callback(value, text):
175
- progress_bar_placeholder.progress(int(value * 100))
176
- progress_message_placeholder.info(text) # Use info for dynamic messages
177
-
178
- if st.button("📊 Start Training"):
179
- progress_message_placeholder.empty() # Clear previous messages
180
- progress_bar_placeholder.progress(0) # Reset progress bar
181
-
182
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
183
- st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
184
- elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
185
- st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
186
- "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
187
- else:
188
- progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
 
 
 
 
 
 
 
 
 
 
189
 
190
- try:
191
- train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
192
- progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
193
- progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
194
- progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
195
- if len(test_df) == 0:
196
- progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
197
- if len(train_df) == 0:
198
- progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
199
-
200
- if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
201
- st.stop() # Added st.stop for critical data missing scenario
202
 
203
- char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
204
- progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
205
 
206
- train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
207
- progress_message_placeholder.success("DataLoaders created successfully.")
208
-
209
- # Re-initialize the model to train from scratch if the button is pressed
210
- # This ensures we don't continue training a potentially already trained model if it was loaded.
211
- ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
212
- ocr_model_for_training.to(device)
213
- ocr_model_for_training.train()
214
 
215
- progress_message_placeholder.write("Training in progress... This may take a while.")
216
-
217
- # Capture the model and history
218
- ocr_model_for_training, history_result = train_ocr_model(
219
- model=ocr_model_for_training,
220
- train_loader=train_loader,
221
- test_loader=test_loader,
222
- char_indexer=char_indexer_for_training,
223
- epochs=NUM_EPOCHS,
224
- device=device,
225
- progress_callback=update_progress_callback
226
- )
227
-
228
- st.session_state.training_history = history_result # Save history to session state
229
-
230
- progress_message_placeholder.success("OCR model training finished!")
231
- update_progress_callback(1.0, "Training complete!")
232
 
233
- os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
234
- save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
235
- progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
236
 
237
- ocr_model = ocr_model_for_training
238
- ocr_model.eval() # Set to eval mode for subsequent predictions
239
 
240
- except Exception as e:
241
- progress_message_placeholder.error(f"An error occurred during training: {e}")
242
- st.exception(e) # This will print a detailed traceback in the Streamlit UI
243
- update_progress_callback(0.0, "Training failed!")
244
-
245
- st.write("---")
246
 
247
- # --- Model Loading Section ---
248
- st.subheader("2. Load Pre-trained Model")
249
- st.write("If you have a saved model, you can load it here instead of training.")
250
 
251
- if st.button("💾 Load Model"):
252
- if os.path.exists(MODEL_SAVE_PATH):
253
- try:
254
- loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
255
- load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
256
- loaded_model_instance.to(device)
257
- ocr_model = loaded_model_instance
258
- ocr_model.eval()
259
- st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
260
-
261
- # If a model is loaded, we can try to re-evaluate it to get history,
262
- # but typically history is stored from a training run.
263
- # For simplicity, we'll assume training history is only stored after a training run.
264
-
265
- except Exception as e:
266
- st.error(f"Error loading model: {e}")
267
- st.exception(e)
268
- else:
269
- st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
270
-
271
- st.write("---")
272
-
273
- # --- Training History Plots Section ---
274
- st.subheader("3. Training History Plots")
275
- if st.session_state.training_history: # Check if history exists in session state
276
- history_df = pd.DataFrame({
277
- 'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
278
- 'Train Loss': st.session_state.training_history['train_loss'],
279
- 'Test Loss': st.session_state.training_history['test_loss'],
280
- 'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
281
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
282
- })
283
-
284
- st.markdown("**Loss over Epochs**")
285
- st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
286
- st.caption("Lower loss indicates better model performance.")
287
-
288
- st.markdown("**Character Error Rate (CER) over Epochs**")
289
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
290
- st.caption("Lower CER indicates fewer character errors (0% is perfect).")
291
-
292
- st.markdown("**Exact Match Accuracy over Epochs**")
293
- st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
294
- st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
295
-
296
- st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
297
- st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
298
- st.caption("CER should decrease, Accuracy should increase.")
299
  else:
300
- st.info("Train the model first to see training history plots here.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
 
303
  # --- Final Footer ---
 
162
 
163
 
164
 
165
+ # --- Model Training Section ---
166
+ st.subheader("Train OCR Model")
167
+ st.write("Click the button below to start training the OCR model.")
168
+
169
+ # Progress bar and label for training within this tab
170
+ progress_container = st.empty() # Container for dynamic messages and progress
171
+ progress_message_placeholder = st.empty()
172
+ progress_bar_placeholder = st.progress(0)
173
+
174
+ def update_progress_callback(value, text):
175
+ progress_bar_placeholder.progress(int(value * 100))
176
+ progress_message_placeholder.info(text) # Use info for dynamic messages
177
+
178
+ if st.button("📊 Start Training"):
179
+ progress_message_placeholder.empty() # Clear previous messages
180
+ progress_bar_placeholder.progress(0) # Reset progress bar
181
+
182
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
183
+ st.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found! Please check file paths and ensure data is uploaded correctly.")
184
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
185
+ st.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
186
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct and data is uploaded.")
187
+ else:
188
+ progress_message_placeholder.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
189
+
190
+ try:
191
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
192
+ progress_message_placeholder.success("Training and Test DataFrames loaded successfully.")
193
+ progress_message_placeholder.info(f"Train DataFrame size: {len(train_df)} samples")
194
+ progress_message_placeholder.info(f"Test DataFrame size: {len(test_df)} samples")
195
+ if len(test_df) == 0:
196
+ progress_message_placeholder.error("ERROR: Test DataFrame is empty! Evaluation cannot proceed. Check TEST_CSV_PATH and TEST_IMAGES_DIR.")
197
+ if len(train_df) == 0:
198
+ progress_message_placeholder.error("ERROR: Train DataFrame is empty! Training cannot proceed. Check TRAIN_CSV_PATH and TRAIN_IMAGES_DIR.")
199
 
200
+ if len(train_df) == 0 or len(test_df) == 0: # Stop if critical data is missing
201
+ st.stop() # Added st.stop for critical data missing scenario
 
 
 
 
 
 
 
 
 
 
202
 
203
+ char_indexer_for_training = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
204
+ progress_message_placeholder.success(f"CharIndexer initialized with {char_indexer_for_training.num_classes} classes.")
205
 
206
+ train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer_for_training, BATCH_SIZE)
207
+ progress_message_placeholder.success("DataLoaders created successfully.")
208
+
209
+ # Re-initialize the model to train from scratch if the button is pressed
210
+ # This ensures we don't continue training a potentially already trained model if it was loaded.
211
+ ocr_model_for_training = CRNN(num_classes=char_indexer_for_training.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
212
+ ocr_model_for_training.to(device)
213
+ ocr_model_for_training.train()
214
 
215
+ progress_message_placeholder.write("Training in progress... This may take a while.")
216
+
217
+ # Capture the model and history
218
+ ocr_model_for_training, history_result = train_ocr_model(
219
+ model=ocr_model_for_training,
220
+ train_loader=train_loader,
221
+ test_loader=test_loader,
222
+ char_indexer=char_indexer_for_training,
223
+ epochs=NUM_EPOCHS,
224
+ device=device,
225
+ progress_callback=update_progress_callback
226
+ )
227
+
228
+ st.session_state.training_history = history_result # Save history to session state
229
+
230
+ progress_message_placeholder.success("OCR model training finished!")
231
+ update_progress_callback(1.0, "Training complete!")
232
 
233
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
234
+ save_ocr_model(ocr_model_for_training, MODEL_SAVE_PATH)
235
+ progress_message_placeholder.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
236
 
237
+ ocr_model = ocr_model_for_training
238
+ ocr_model.eval() # Set to eval mode for subsequent predictions
239
 
240
+ except Exception as e:
241
+ progress_message_placeholder.error(f"An error occurred during training: {e}")
242
+ st.exception(e) # This will print a detailed traceback in the Streamlit UI
243
+ update_progress_callback(0.0, "Training failed!")
244
+
245
+ st.write("---")
246
 
247
+ # --- Model Loading Section ---
248
+ st.subheader("Load Pre-trained Model")
249
+ st.write("If you have a saved model, you can load it here instead of training.")
250
 
251
+ if st.button("💾 Load Model"):
252
+ if os.path.exists(MODEL_SAVE_PATH):
253
+ try:
254
+ loaded_model_instance = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
255
+ load_ocr_model(loaded_model_instance, MODEL_SAVE_PATH)
256
+ loaded_model_instance.to(device)
257
+ ocr_model = loaded_model_instance
258
+ ocr_model.eval()
259
+ st.success(f"Model loaded successfully from `{MODEL_SAVE_PATH}`")
260
+
261
+ # If a model is loaded, we can try to re-evaluate it to get history,
262
+ # but typically history is stored from a training run.
263
+ # For simplicity, we'll assume training history is only stored after a training run.
264
+
265
+ except Exception as e:
266
+ st.error(f"Error loading model: {e}")
267
+ st.exception(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  else:
269
+ st.warning(f"No model found at `{MODEL_SAVE_PATH}`. Please train a model first or check the path.")
270
+
271
+ st.write("---")
272
+
273
+ # --- Training History Plots Section ---
274
+ st.subheader("Training History Plots")
275
+ if st.session_state.training_history: # Check if history exists in session state
276
+ history_df = pd.DataFrame({
277
+ 'Epoch': range(1, len(st.session_state.training_history['train_loss']) + 1),
278
+ 'Train Loss': st.session_state.training_history['train_loss'],
279
+ 'Test Loss': st.session_state.training_history['test_loss'],
280
+ 'Test CER (%)': [cer * 100 for cer in st.session_state.training_history['test_cer']],
281
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in st.session_state.training_history['test_exact_match_accuracy']]
282
+ })
283
+
284
+ st.markdown("**Loss over Epochs**")
285
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
286
+ st.caption("Lower loss indicates better model performance.")
287
+
288
+ st.markdown("**Character Error Rate (CER) over Epochs**")
289
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
290
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
291
+
292
+ st.markdown("**Exact Match Accuracy over Epochs**")
293
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
294
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
295
+
296
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
297
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
298
+ st.caption("CER should decrease, Accuracy should increase.")
299
+ else:
300
+ st.info("Train the model first to see training history plots here.")
301
 
302
 
303
  # --- Final Footer ---