marianeft committed on
Commit
1ae7e6f
·
verified ·
1 Parent(s): 5b1f074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -315
app.py CHANGED
@@ -1,6 +1,11 @@
1
- <<<<<<< HEAD
2
  # app.py
3
 
 
 
 
 
 
4
  import streamlit as st
5
  import pandas as pd
6
  import numpy as np
@@ -8,20 +13,32 @@ from PIL import Image
8
  import torch
9
  import torch.nn.functional as F # Added F for log_softmax in inference
10
  import torchvision.transforms as transforms
11
- import os
12
  import traceback # For detailed error logging
13
 
14
- # Import custom modules
15
- from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
16
- TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
17
- from data_handler_ocr import CharIndexer, OCRDataset
 
 
 
 
 
18
  from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
19
- from utils_ocr import preprocess_user_image_for_ocr
 
 
 
 
 
 
 
 
20
 
21
  # --- Streamlit App Setup ---
22
- st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
23
 
24
- st.title("📝 Handwritten Name Recognition (OCR)")
25
  st.markdown("""
26
  This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
27
  Optical Character Recognition (OCR) on handwritten names. You can upload an image
@@ -31,11 +48,9 @@ st.markdown("""
31
  """)
32
 
33
  # --- Initialize CharIndexer ---
34
- # The CHARS variable should contain all possible characters your model can recognize.
35
- # Make sure it's comprehensive based on your dataset.
36
- char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
37
- # For robustness, it's best to always use char_indexer.num_classes
38
- # If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes
39
 
40
  # --- Model Loading / Initialization ---
41
  @st.cache_resource # Cache the model to prevent reloading on every rerun
@@ -64,7 +79,6 @@ def get_and_load_ocr_model_cached(num_classes, model_path):
64
  # Get the model instance
65
  ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
66
  # Determine the device (GPU if available, else CPU)
67
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
  ocr_model.to(device)
69
  ocr_model.eval() # Set model to evaluation mode for inference by default
70
 
@@ -73,102 +87,97 @@ st.sidebar.header("Model Training (Optional)")
73
  st.sidebar.markdown("If you want to train a new model or no model is found:")
74
 
75
  # Initialize Streamlit widgets outside the button block
76
- training_progress_bar = st.sidebar.empty() # Placeholder for progress bar
77
- status_text = st.sidebar.empty() # Placeholder for status messages
78
 
79
- if st.sidebar.button("📊 Train New OCR Model"):
80
  # Clear previous messages/widgets if button is clicked again
 
81
  training_progress_bar.empty()
82
- status_text.empty()
83
 
84
  # Check for existence of CSVs and image directories
85
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
86
- not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
87
- status_text.error(f"""Dataset files or image directories not found.
88
- Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
89
- and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
90
  else:
91
- status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
92
 
 
93
  training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
94
 
 
 
 
 
 
95
  try:
96
- train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
97
- test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
98
-
99
- # Define standard image transforms for consistency
100
- train_transform = transforms.Compose([
101
- transforms.Resize((IMG_HEIGHT, 100)), # Resize to fixed height, width will be 100 (adjust as needed for variable width)
102
- transforms.ToTensor(), # Converts PIL Image to PyTorch Tensor (H, W) -> (C, H, W), normalizes to [0,1]
103
- ])
104
- test_transform = transforms.Compose([
105
- transforms.Resize((IMG_HEIGHT, 100)), # Same transformation as train
106
- transforms.ToTensor(),
107
- ])
108
-
109
- # Create dataset instances
110
- train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
111
- test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)
112
-
113
- # Create DataLoader instances
114
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # num_workers=0 for Windows
115
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
116
-
117
- # Train the model, passing the progress callback
118
  trained_ocr_model, training_history = train_ocr_model(
119
- ocr_model, # Pass the initialized model instance
120
- train_loader,
121
- test_loader,
122
- char_indexer, # Pass char_indexer for CER calculation
123
  epochs=NUM_EPOCHS,
124
  device=device,
125
- progress_callback=training_progress_bar_instance.progress # Pass the instance's progress method
126
  )
 
 
127
 
128
- # Ensure the directory for saving the model exists
129
  os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
130
  save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
131
- status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
132
-
133
- # Display training history chart
134
- st.sidebar.subheader("Training History Plots")
135
-
136
- history_df = pd.DataFrame({
137
- 'Epoch': range(1, len(training_history['train_loss']) + 1),
138
- 'Train Loss': training_history['train_loss'],
139
- 'Test Loss': training_history['test_loss'],
140
- 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']], # Convert CER to percentage for display
141
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']] # Convert to percentage
142
- })
143
-
144
- # Plot 1: Training and Test Loss
145
- st.sidebar.markdown("**Loss over Epochs**")
146
- st.sidebar.line_chart(
147
- history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
148
- )
149
- st.sidebar.caption("Lower loss indicates better model performance.")
150
-
151
- # Plot 2: Character Error Rate (CER)
152
- st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
153
- st.sidebar.line_chart(
154
- history_df.set_index('Epoch')[['Test CER (%)']]
155
- )
156
- st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
157
-
158
- # Plot 3: Exact Match Accuracy
159
- st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
160
- st.sidebar.line_chart(
161
- history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
162
- )
163
- st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")
164
-
165
- # Update the global model instance to the newly trained one for immediate inference
166
- ocr_model = trained_ocr_model
167
- ocr_model.eval()
168
 
169
  except Exception as e:
170
  status_text.error(f"An error occurred during training: {e}")
171
- st.sidebar.text(traceback.format_exc()) # Show full traceback for debugging
 
172
 
173
  # --- Main Content: Name Prediction ---
174
  st.header("Predict Your Handwritten Name")
@@ -180,22 +189,19 @@ if uploaded_file is not None:
180
  try:
181
  # Open the uploaded image
182
  image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
183
- st.image(image_pil, caption="Uploaded Image", use_column_width=True)
 
184
  st.write("---")
185
  st.write("Processing and Recognizing...")
186
 
187
  # Preprocess the image for the model using utils_ocr function
188
  processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
189
-
190
- # Make prediction
191
- ocr_model.eval() # Ensure model is in evaluation mode
192
  with torch.no_grad(): # Disable gradient calculation for inference
193
- output = ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
194
-
195
- # ctc_greedy_decode expects (sequence_length, batch_size, num_classes)
196
- # It returns a list of strings, so get the first element for single image inference.
197
- predicted_texts = ctc_greedy_decode(output, char_indexer)
198
- predicted_text = predicted_texts[0] # Get the first (and only) prediction
199
 
200
  st.success(f"Recognized Text: **{predicted_text}**")
201
 
@@ -205,222 +211,9 @@ if uploaded_file is not None:
205
  "- Ensure the handwritten text is clear and on a clean background.\n"
206
  "- Only include one name/word per image.\n"
207
  "- The model is trained on specific characters. Unusual symbols might not be recognized.")
208
- st.text(traceback.format_exc())
209
 
210
  st.markdown("""
211
  ---
212
  *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
213
- =======
214
- # app.py
215
-
216
- import streamlit as st
217
- import pandas as pd
218
- import numpy as np
219
- from PIL import Image
220
- import torch
221
- import torch.nn.functional as F # Added F for log_softmax in inference
222
- import torchvision.transforms as transforms
223
- import os
224
- import traceback # For detailed error logging
225
-
226
- # Import custom modules
227
- from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
228
- TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
229
- from data_handler_ocr import CharIndexer, OCRDataset
230
- from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
231
- from utils_ocr import preprocess_user_image_for_ocr
232
-
233
- # --- Streamlit App Setup ---
234
- st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
235
-
236
- st.title("📝 Handwritten Name Recognition (OCR)")
237
- st.markdown("""
238
- This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
239
- Optical Character Recognition (OCR) on handwritten names. You can upload an image
240
- of a handwritten name for prediction or train a new model using the provided dataset.
241
-
242
- **Note:** Training a robust OCR model can be time-consuming.
243
- """)
244
-
245
- # --- Initialize CharIndexer ---
246
- # The CHARS variable should contain all possible characters your model can recognize.
247
- # Make sure it's comprehensive based on your dataset.
248
- char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
249
- # For robustness, it's best to always use char_indexer.num_classes
250
- # If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes
251
-
252
- # --- Model Loading / Initialization ---
253
- @st.cache_resource # Cache the model to prevent reloading on every rerun
254
- def get_and_load_ocr_model_cached(num_classes, model_path):
255
- """
256
- Initializes the OCR model and attempts to load a pre-trained model.
257
- If no pre-trained model exists, a new model instance is returned.
258
- """
259
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
260
-
261
- if os.path.exists(model_path):
262
- st.sidebar.info("Loading pre-trained OCR model...")
263
- try:
264
- # Load model to CPU first, then move to device
265
- model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
266
- st.sidebar.success("OCR model loaded successfully!")
267
- except Exception as e:
268
- st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
269
- # If loading fails, re-initialize an untrained model
270
- model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
271
- else:
272
- st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
273
-
274
- return model_instance
275
-
276
- # Get the model instance
277
- ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
278
- # Determine the device (GPU if available, else CPU)
279
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
- ocr_model.to(device)
281
- ocr_model.eval() # Set model to evaluation mode for inference by default
282
-
283
- # --- Sidebar for Model Training ---
284
- st.sidebar.header("Model Training (Optional)")
285
- st.sidebar.markdown("If you want to train a new model or no model is found:")
286
-
287
- # Initialize Streamlit widgets outside the button block
288
- training_progress_bar = st.sidebar.empty() # Placeholder for progress bar
289
- status_text = st.sidebar.empty() # Placeholder for status messages
290
-
291
- if st.sidebar.button("📊 Train New OCR Model"):
292
- # Clear previous messages/widgets if button is clicked again
293
- training_progress_bar.empty()
294
- status_text.empty()
295
-
296
- # Check for existence of CSVs and image directories
297
- if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
298
- not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
299
- status_text.error(f"""Dataset files or image directories not found.
300
- Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
301
- and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
302
- else:
303
- status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
304
-
305
- training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
306
-
307
- try:
308
- train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
309
- test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
310
-
311
- # Define standard image transforms for consistency
312
- train_transform = transforms.Compose([
313
- transforms.Resize((IMG_HEIGHT, 100)), # Resize to fixed height, width will be 100 (adjust as needed for variable width)
314
- transforms.ToTensor(), # Converts PIL Image to PyTorch Tensor (H, W) -> (C, H, W), normalizes to [0,1]
315
- ])
316
- test_transform = transforms.Compose([
317
- transforms.Resize((IMG_HEIGHT, 100)), # Same transformation as train
318
- transforms.ToTensor(),
319
- ])
320
-
321
- # Create dataset instances
322
- train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
323
- test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)
324
-
325
- # Create DataLoader instances
326
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # num_workers=0 for Windows
327
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
328
-
329
- # Train the model, passing the progress callback
330
- trained_ocr_model, training_history = train_ocr_model(
331
- ocr_model, # Pass the initialized model instance
332
- train_loader,
333
- test_loader,
334
- char_indexer, # Pass char_indexer for CER calculation
335
- epochs=NUM_EPOCHS,
336
- device=device,
337
- progress_callback=training_progress_bar_instance.progress # Pass the instance's progress method
338
- )
339
-
340
- # Ensure the directory for saving the model exists
341
- os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
342
- save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
343
- status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
344
-
345
- # Display training history chart
346
- st.sidebar.subheader("Training History Plots")
347
-
348
- history_df = pd.DataFrame({
349
- 'Epoch': range(1, len(training_history['train_loss']) + 1),
350
- 'Train Loss': training_history['train_loss'],
351
- 'Test Loss': training_history['test_loss'],
352
- 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']], # Convert CER to percentage for display
353
- 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']] # Convert to percentage
354
- })
355
-
356
- # Plot 1: Training and Test Loss
357
- st.sidebar.markdown("**Loss over Epochs**")
358
- st.sidebar.line_chart(
359
- history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
360
- )
361
- st.sidebar.caption("Lower loss indicates better model performance.")
362
-
363
- # Plot 2: Character Error Rate (CER)
364
- st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
365
- st.sidebar.line_chart(
366
- history_df.set_index('Epoch')[['Test CER (%)']]
367
- )
368
- st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
369
-
370
- # Plot 3: Exact Match Accuracy
371
- st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
372
- st.sidebar.line_chart(
373
- history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
374
- )
375
- st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")
376
-
377
- # Update the global model instance to the newly trained one for immediate inference
378
- ocr_model = trained_ocr_model
379
- ocr_model.eval()
380
-
381
- except Exception as e:
382
- status_text.error(f"An error occurred during training: {e}")
383
- st.sidebar.text(traceback.format_exc()) # Show full traceback for debugging
384
-
385
- # --- Main Content: Name Prediction ---
386
- st.header("Predict Your Handwritten Name")
387
- st.markdown("Upload a clear image of a single handwritten name or word.")
388
-
389
- uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])
390
-
391
- if uploaded_file is not None:
392
- try:
393
- # Open the uploaded image
394
- image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
395
- st.image(image_pil, caption="Uploaded Image", use_column_width=True)
396
- st.write("---")
397
- st.write("Processing and Recognizing...")
398
-
399
- # Preprocess the image for the model using utils_ocr function
400
- processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
401
-
402
- # Make prediction
403
- ocr_model.eval() # Ensure model is in evaluation mode
404
- with torch.no_grad(): # Disable gradient calculation for inference
405
- output = ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
406
-
407
- # ctc_greedy_decode expects (sequence_length, batch_size, num_classes)
408
- # It returns a list of strings, so get the first element for single image inference.
409
- predicted_texts = ctc_greedy_decode(output, char_indexer)
410
- predicted_text = predicted_texts[0] # Get the first (and only) prediction
411
-
412
- st.success(f"Recognized Text: **{predicted_text}**")
413
-
414
- except Exception as e:
415
- st.error(f"Error processing image or recognizing text: {e}")
416
- st.info("πŸ’‘ **Tips for best results:**\n"
417
- "- Ensure the handwritten text is clear and on a clean background.\n"
418
- "- Only include one name/word per image.\n"
419
- "- The model is trained on specific characters. Unusual symbols might not be recognized.")
420
- st.text(traceback.format_exc())
421
-
422
- st.markdown("""
423
- ---
424
- *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
425
- >>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
426
  """)
 
1
+ # -*- coding: utf-8 -*-
2
  # app.py
3
 
4
+ import os
5
+ # CRITICAL FIX: Disable Streamlit's file watcher to prevent conflicts with PyTorch
6
+ # This MUST be the first thing, before any other imports or Streamlit calls
7
+ os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
8
+
9
  import streamlit as st
10
  import pandas as pd
11
  import numpy as np
 
13
  import torch
14
  import torch.nn.functional as F # Added F for log_softmax in inference
15
  import torchvision.transforms as transforms
 
16
  import traceback # For detailed error logging
17
 
18
+ # Import all necessary configuration values from config.py
19
+ from config import (
20
+ IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN, VOCABULARY, BLANK_TOKEN_SYMBOL,
21
+ TRAIN_CSV_PATH, TEST_CSV_PATH, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR,
22
+ MODEL_SAVE_PATH, BATCH_SIZE, NUM_EPOCHS
23
+ )
24
+
25
+ # Import classes and functions from data_handler_ocr.py and model_ocr.py
26
+ from data_handler_ocr import CharIndexer, OCRDataset, ocr_collate_fn, load_ocr_dataframes, create_ocr_dataloaders
27
  from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
28
+ from utils_ocr import preprocess_user_image_for_ocr, binarize_image, resize_image_for_ocr, normalize_image_for_model # Ensure these are imported if needed
29
+
30
+ # --- Global Variables ---
31
+ # These will hold the model and char_indexer instance after training or loading
32
+ trained_ocr_model = None
33
+ char_indexer = None
34
+ training_history = None
35
+ # Determine the device (GPU if available, else CPU)
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
  # --- Streamlit App Setup ---
39
+ st.set_page_config(layout="wide", page_title="Handwritten Name OCR App") # Changed to wide layout for better display
40
 
41
+ st.title("📝 Handwritten Name Recognition (OCR) App") # Updated title for consistency
42
  st.markdown("""
43
  This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
44
  Optical Character Recognition (OCR) on handwritten names. You can upload an image
 
48
  """)
49
 
50
  # --- Initialize CharIndexer ---
51
+ # CRITICAL FIX: Initialize CharIndexer with VOCABULARY and BLANK_TOKEN_SYMBOL
52
+ # This resolves the ValueError: "Blank token symbol '95' not found..."
53
+ char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
 
 
54
 
55
  # --- Model Loading / Initialization ---
56
  @st.cache_resource # Cache the model to prevent reloading on every rerun
 
79
  # Get the model instance
80
  ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
81
  # Determine the device (GPU if available, else CPU)
 
82
  ocr_model.to(device)
83
  ocr_model.eval() # Set model to evaluation mode for inference by default
84
 
 
87
  st.sidebar.markdown("If you want to train a new model or no model is found:")
88
 
89
  # Initialize Streamlit widgets outside the button block
90
+ training_progress_bar = st.sidebar.empty() # Placeholder for progress bar in sidebar
91
+ status_text = st.sidebar.empty() # Placeholder for status messages in sidebar
92
 
93
+ if st.sidebar.button("📊 Train New OCR Model"): # Keep button in sidebar as per user's last provided code
94
  # Clear previous messages/widgets if button is clicked again
95
+ training_progress_bar.progress(0) # Reset progress bar
96
  training_progress_bar.empty()
97
+ status_text.empty() # Clear status text
98
 
99
  # Check for existence of CSVs and image directories
100
+ if not os.path.exists(TRAIN_CSV_PATH) or not os.path.isdir(TRAIN_IMAGES_DIR):
101
+ status_text.error(f"Training CSV '{TRAIN_CSV_PATH}' or Images directory '{TRAIN_IMAGES_DIR}' not found!")
102
+ elif not os.path.exists(TEST_CSV_PATH) or not os.path.isdir(TEST_IMAGES_DIR):
103
+ status_text.warning(f"Test CSV '{TEST_CSV_PATH}' or Images directory '{TEST_IMAGES_DIR}' not found. "
104
+ "Evaluation might be affected or skipped. Please ensure all data paths are correct.")
105
  else:
106
+ status_text.info(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
107
 
108
+ # Define the progress bar instance here for the callback
109
  training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
110
 
111
+ def update_progress_callback_sidebar(value, text):
112
+ """Callback function to update Streamlit progress bar in sidebar."""
113
+ training_progress_bar_instance.progress(int(value * 100))
114
+ status_text.text(text) # Update status text in sidebar
115
+
116
  try:
117
+ train_df, test_df = load_ocr_dataframes(TRAIN_CSV_PATH, TEST_CSV_PATH)
118
+ status_text.success("Training and Test DataFrames loaded successfully.")
119
+
120
+ char_indexer = CharIndexer(vocabulary_string=VOCABULARY, blank_token_symbol=BLANK_TOKEN_SYMBOL)
121
+ status_text.success(f"CharIndexer initialized with {char_indexer.num_classes} classes.")
122
+
123
+ # Pass the limits to create_ocr_dataloaders
124
+ train_loader, test_loader = create_ocr_dataloaders(
125
+ train_df, test_df, char_indexer, BATCH_SIZE
126
+ )
127
+ status_text.success("DataLoaders created successfully.")
128
+
129
+ ocr_model_for_training = CRNN(num_classes=NUM_CLASSES) # Create a new instance for training
130
+ ocr_model_for_training.to(device)
131
+ status_text.info(f"CRNN model initialized and moved to {device}.")
132
+
133
+ status_text.write("Training in progress... This may take a while.")
 
 
 
 
 
134
  trained_ocr_model, training_history = train_ocr_model(
135
+ model=ocr_model_for_training, # Pass the new instance
136
+ train_loader=train_loader,
137
+ test_loader=test_loader,
138
+ char_indexer=char_indexer, # Pass char_indexer for CER calculation
139
  epochs=NUM_EPOCHS,
140
  device=device,
141
+ progress_callback=update_progress_callback_sidebar # Pass the sidebar callback
142
  )
143
+ status_text.success("OCR model training finished!")
144
+ update_progress_callback_sidebar(1.0, "Training complete!")
145
 
 
146
  os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
147
  save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
148
+ status_text.success(f"Trained model saved to `{MODEL_SAVE_PATH}`")
149
+
150
+ # Display training history chart in the main section, not sidebar
151
+ if training_history:
152
+ st.subheader("Training History Plots")
153
+ history_df = pd.DataFrame({
154
+ 'Epoch': range(1, len(training_history['train_loss']) + 1),
155
+ 'Train Loss': training_history['train_loss'],
156
+ 'Test Loss': training_history['test_loss'],
157
+ 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
158
+ 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]
159
+ })
160
+
161
+ st.markdown("**Loss over Epochs**")
162
+ st.line_chart(history_df.set_index('Epoch')[['Train Loss', 'Test Loss']])
163
+ st.caption("Lower loss indicates better model performance.")
164
+
165
+ st.markdown("**Character Error Rate (CER) over Epochs**")
166
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)']])
167
+ st.caption("Lower CER indicates fewer character errors (0% is perfect).")
168
+
169
+ st.markdown("**Exact Match Accuracy over Epochs**")
170
+ st.line_chart(history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']])
171
+ st.caption("Higher exact match accuracy indicates more perfectly recognized names.")
172
+
173
+ st.markdown("**Performance Metrics over Epochs (CER vs. Exact Match Accuracy)**")
174
+ st.line_chart(history_df.set_index('Epoch')[['Test CER (%)', 'Test Exact Match Accuracy (%)']])
175
+ st.caption("CER should decrease, Accuracy should increase.")
 
 
 
 
 
 
 
 
 
176
 
177
  except Exception as e:
178
  status_text.error(f"An error occurred during training: {e}")
179
+ status_text.exception(e) # Display full traceback in Streamlit
180
+ update_progress_callback_sidebar(0.0, "Training failed!")
181
 
182
  # --- Main Content: Name Prediction ---
183
  st.header("Predict Your Handwritten Name")
 
189
  try:
190
  # Open the uploaded image
191
  image_pil = Image.open(uploaded_file).convert('L') # Ensure grayscale
192
+ # Use use_container_width for deprecation warning fix
193
+ st.image(image_pil, caption="Uploaded Image", use_container_width=True)
194
  st.write("---")
195
  st.write("Processing and Recognizing...")
196
 
197
  # Preprocess the image for the model using utils_ocr function
198
  processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
199
+
200
+ trained_ocr_model.eval() # Ensure model is in evaluation mode
 
201
  with torch.no_grad(): # Disable gradient calculation for inference
202
+ output = trained_ocr_model(processed_image_tensor) # (sequence_length, batch_size, num_classes)
203
+ predicted_texts = ctc_greedy_decode(output, char_indexer)
204
+ predicted_text = predicted_texts[0] # Get the first (and only) prediction
 
 
 
205
 
206
  st.success(f"Recognized Text: **{predicted_text}**")
207
 
 
211
  "- Ensure the handwritten text is clear and on a clean background.\n"
212
  "- Only include one name/word per image.\n"
213
  "- The model is trained on specific characters. Unusual symbols might not be recognized.")
214
+ st.exception(e) # Display full traceback in Streamlit
215
 
216
  st.markdown("""
217
  ---
218
  *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  """)