Spaces:
Build error
Build error
File size: 21,515 Bytes
8900f0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 |
# app.py
# Streamlit front-end for a CRNN handwritten-name OCR model: the sidebar can
# train a new model from the configured dataset, and the main panel recognises
# the text in an uploaded image with greedy CTC decoding.
import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F  # NOTE: not currently used in this script
import torchvision.transforms as transforms
import os
import traceback  # For detailed error logging
# Import custom modules
from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
    TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
from data_handler_ocr import CharIndexer, OCRDataset
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr

# CRNN architecture hyper-parameters. A checkpoint at MODEL_SAVE_PATH can only be
# loaded into a model built with these same values.
CRNN_KWARGS = dict(cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)

# Width (in pixels) that training/test images are resized to, alongside IMG_HEIGHT.
# NOTE(review): inference goes through preprocess_user_image_for_ocr(); confirm it
# produces inputs compatible with this training-time size.
TRAIN_IMG_WIDTH = 100

# --- Streamlit App Setup ---
st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
# NOTE(review): the leading emoji of this title was corrupted by an encoding
# error and could not be recovered; add one back here if desired.
st.title("Handwritten Name Recognition (OCR)")
st.markdown("""
This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
Optical Character Recognition (OCR) on handwritten names. You can upload an image
of a handwritten name for prediction or train a new model using the provided dataset.
**Note:** Training a robust OCR model can be time-consuming.
""")

# --- Initialize CharIndexer ---
# CHARS must contain every character the model should recognise. The model's
# output size is taken from char_indexer.num_classes (not NUM_CLASSES from
# config) so the two can never disagree.
char_indexer = CharIndexer(CHARS, BLANK_TOKEN)


# --- Model Loading / Initialization ---
def _build_crnn(num_classes):
    """Return a freshly initialised (untrained) CRNN using CRNN_KWARGS."""
    return CRNN(num_classes=num_classes, **CRNN_KWARGS)


@st.cache_resource  # Cache the model so it is not rebuilt/reloaded on every rerun
def get_and_load_ocr_model_cached(num_classes, model_path):
    """Build the OCR model and load pre-trained weights when available.

    Args:
        num_classes: Output size of the CRNN (the app passes
            ``char_indexer.num_classes``).
        model_path: Path to a checkpoint loadable with ``load_state_dict``.

    Returns:
        A CRNN mapped to the CPU: holding the checkpoint weights if loading
        succeeded, otherwise freshly initialised. Status messages are written
        to the sidebar.
    """
    model_instance = _build_crnn(num_classes)
    if not os.path.exists(model_path):
        st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
        return model_instance

    st.sidebar.info("Loading pre-trained OCR model...")
    try:
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
        # the caller moves the model to the active device afterwards.
        model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        st.sidebar.success("OCR model loaded successfully!")
    except Exception as e:
        st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
        # load_state_dict may copy some tensors before raising, so start over
        # from a clean, untrained model instead of a half-loaded one.
        model_instance = _build_crnn(num_classes)
    return model_instance


# Get the model instance and move it to the best available device.
ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model.to(device)
ocr_model.eval()  # Evaluation mode for inference by default


# --- Training helpers ---
def _dataset_available():
    """Return True when both CSV files and both image directories exist."""
    return (os.path.exists(TRAIN_CSV_PATH) and os.path.exists(TEST_CSV_PATH)
            and os.path.isdir(TRAIN_IMAGES_DIR) and os.path.isdir(TEST_IMAGES_DIR))


def _build_data_loaders():
    """Read the train/test CSVs and return ``(train_loader, test_loader)``."""
    # The CSVs are ';'-separated and read without a header row.
    train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
    test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
    # One deterministic transform for both splits: fixed-size resize, then
    # ToTensor (PIL image -> (C, H, W) float tensor scaled to [0, 1]).
    image_transform = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, TRAIN_IMG_WIDTH)),
        transforms.ToTensor(),
    ])
    train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer,
                               image_dir=TRAIN_IMAGES_DIR, transform=image_transform)
    test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer,
                              image_dir=TEST_IMAGES_DIR, transform=image_transform)
    # num_workers=0 loads batches in the main process (avoids worker-process
    # issues, e.g. on Windows).
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    return train_loader, test_loader


def _render_training_history(training_history):
    """Plot per-epoch loss, CER and exact-match accuracy in the sidebar.

    ``training_history`` must provide equal-length lists under 'train_loss',
    'test_loss', 'test_cer' and 'test_exact_match_accuracy'; the last two are
    fractions and are shown as percentages.
    """
    st.sidebar.subheader("Training History Plots")
    history_df = pd.DataFrame({
        'Epoch': range(1, len(training_history['train_loss']) + 1),
        'Train Loss': training_history['train_loss'],
        'Test Loss': training_history['test_loss'],
        'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],
        'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']],
    }).set_index('Epoch')

    # Plot 1: Training and Test Loss
    st.sidebar.markdown("**Loss over Epochs**")
    st.sidebar.line_chart(history_df[['Train Loss', 'Test Loss']])
    st.sidebar.caption("Lower loss indicates better model performance.")
    # Plot 2: Character Error Rate (CER)
    st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
    st.sidebar.line_chart(history_df[['Test CER (%)']])
    st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
    # Plot 3: Exact Match Accuracy
    st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
    st.sidebar.line_chart(history_df[['Test Exact Match Accuracy (%)']])
    st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")


# --- Sidebar for Model Training ---
st.sidebar.header("Model Training (Optional)")
st.sidebar.markdown("If you want to train a new model or no model is found:")
# Placeholders created outside the button block so they keep a stable position.
training_progress_bar = st.sidebar.empty()
status_text = st.sidebar.empty()
# NOTE(review): this button label's leading emoji was lost to an encoding error.
if st.sidebar.button("Train New OCR Model"):
    # Clear previous messages/widgets if the button is clicked again.
    training_progress_bar.empty()
    status_text.empty()
    if not _dataset_available():
        status_text.error(f"""Dataset files or image directories not found.
Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
    else:
        status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
        training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
        try:
            train_loader, test_loader = _build_data_loaders()
            trained_ocr_model, training_history = train_ocr_model(
                ocr_model,  # The (cached) model instance to train
                train_loader,
                test_loader,
                char_indexer,  # Needed for CER calculation
                epochs=NUM_EPOCHS,
                device=device,
                progress_callback=training_progress_bar_instance.progress,
            )
            # os.path.dirname() is "" for a bare filename, and os.makedirs("")
            # raises, so only create the directory when there is one.
            model_dir = os.path.dirname(MODEL_SAVE_PATH)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
            # Streamlit reruns this whole script on the next interaction, and the
            # cached loader would return the instance it cached before training
            # (which need not be the object train_ocr_model returned). Drop that
            # entry so the freshly saved checkpoint is loaded instead.
            get_and_load_ocr_model_cached.clear()
            status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
            _render_training_history(training_history)
            # Use the newly trained model for any inference later in this run.
            ocr_model = trained_ocr_model
            ocr_model.eval()
        except Exception as e:
            status_text.error(f"An error occurred during training: {e}")
            st.sidebar.text(traceback.format_exc())  # Full traceback for debugging

# --- Main Content: Name Prediction ---
st.header("Predict Your Handwritten Name")
st.markdown("Upload a clear image of a single handwritten name or word.")
uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    try:
        image_pil = Image.open(uploaded_file).convert('L')  # Force grayscale
        # NOTE(review): use_column_width is deprecated in newer Streamlit releases
        # in favour of use_container_width; confirm against the pinned version.
        st.image(image_pil, caption="Uploaded Image", use_column_width=True)
        st.write("---")
        st.write("Processing and Recognizing...")
        processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
        ocr_model.eval()  # Ensure evaluation mode (training leaves it in train mode)
        with torch.no_grad():  # No gradients needed for inference
            output = ocr_model(processed_image_tensor)
        # ctc_greedy_decode returns one string per batch item; a single image
        # was submitted, so take the first.
        predicted_text = ctc_greedy_decode(output, char_indexer)[0]
        st.success(f"Recognized Text: **{predicted_text}**")
    except Exception as e:
        st.error(f"Error processing image or recognizing text: {e}")
        st.info("💡 **Tips for best results:**\n"
                "- Ensure the handwritten text is clear and on a clean background.\n"
                "- Only include one name/word per image.\n"
                "- The model is trained on specific characters. Unusual symbols might not be recognized.")
        st.text(traceback.format_exc())

st.markdown("""
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
""")