marianeft committed
Commit 8900f0a · verified · 1 Parent(s): 2c31f12

Initial update of files

Files changed (7)
  1. LICENSE +201 -0
  2. app.py +426 -0
  3. config.py +112 -0
  4. data_handler_ocr.py +270 -0
  5. model_ocr.py +584 -0
  6. requirements.txt +32 -0
  7. utils_ocr.py +184 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,426 @@
+ # app.py
+
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F  # F is used for log_softmax during inference
+ import torchvision.transforms as transforms
+ import os
+ import traceback  # For detailed error logging
+
+ # Import custom modules
+ from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
+     TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
+ from data_handler_ocr import CharIndexer, OCRDataset
+ from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
+ from utils_ocr import preprocess_user_image_for_ocr
+
+ # --- Streamlit App Setup ---
+ st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")
+
+ st.title("📝 Handwritten Name Recognition (OCR)")
+ st.markdown("""
+ This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
+ Optical Character Recognition (OCR) on handwritten names. You can upload an image
+ of a handwritten name for prediction or train a new model using the provided dataset.
+
+ **Note:** Training a robust OCR model can be time-consuming.
+ """)
+
+ # --- Initialize CharIndexer ---
+ # The CHARS variable should contain all possible characters your model can recognize.
+ # Make sure it is comprehensive for your dataset.
+ char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
+ # For robustness, prefer char_indexer.num_classes everywhere.
+ # If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes.
+
+ # --- Model Loading / Initialization ---
+ @st.cache_resource  # Cache the model to prevent reloading on every rerun
+ def get_and_load_ocr_model_cached(num_classes, model_path):
+     """
+     Initializes the OCR model and attempts to load a pre-trained model.
+     If no pre-trained model exists, a new model instance is returned.
+     """
+     model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
+
+     if os.path.exists(model_path):
+         st.sidebar.info("Loading pre-trained OCR model...")
+         try:
+             # Load model to CPU first, then move to device
+             model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+             st.sidebar.success("OCR model loaded successfully!")
+         except Exception as e:
+             st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
+             # If loading fails, re-initialize an untrained model
+             model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
+     else:
+         st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")
+
+     return model_instance
+
+ # Get the model instance
+ ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
+ # Determine the device (GPU if available, else CPU)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ ocr_model.to(device)
+ ocr_model.eval()  # Set model to evaluation mode for inference by default
+
+ # --- Sidebar for Model Training ---
+ st.sidebar.header("Model Training (Optional)")
+ st.sidebar.markdown("If you want to train a new model or no model is found:")
+
+ # Initialize Streamlit widgets outside the button block
+ training_progress_bar = st.sidebar.empty()  # Placeholder for progress bar
+ status_text = st.sidebar.empty()  # Placeholder for status messages
+
+ if st.sidebar.button("📊 Train New OCR Model"):
+     # Clear previous messages/widgets if the button is clicked again
+     training_progress_bar.empty()
+     status_text.empty()
+
+     # Check for existence of CSVs and image directories
+     if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
+        not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
+         status_text.error(f"""Dataset files or image directories not found.
+             Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
+             and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
+     else:
+         status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")
+
+         training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")
+
+         try:
+             train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
+             test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
+
+             # Define standard image transforms for consistency
+             train_transform = transforms.Compose([
+                 transforms.Resize((IMG_HEIGHT, 100)),  # Resize to fixed height; width fixed at 100 (adjust as needed for variable width)
+                 transforms.ToTensor(),  # Converts PIL Image (H, W) to tensor (C, H, W), normalized to [0, 1]
+             ])
+             test_transform = transforms.Compose([
+                 transforms.Resize((IMG_HEIGHT, 100)),  # Same transformation as train
+                 transforms.ToTensor(),
+             ])
+
+             # Create dataset instances
+             train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
+             test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)
+
+             # Create DataLoader instances
+             train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)  # num_workers=0 for Windows
+             test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
+
+             # Train the model, passing the progress callback
+             trained_ocr_model, training_history = train_ocr_model(
+                 ocr_model,  # Pass the initialized model instance
+                 train_loader,
+                 test_loader,
+                 char_indexer,  # Pass char_indexer for CER calculation
+                 epochs=NUM_EPOCHS,
+                 device=device,
+                 progress_callback=training_progress_bar_instance.progress  # Pass the instance's progress method
+             )
+
+             # Ensure the directory for saving the model exists
+             os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
+             save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
+             status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")
+
+             # Display training history charts
+             st.sidebar.subheader("Training History Plots")
+
+             history_df = pd.DataFrame({
+                 'Epoch': range(1, len(training_history['train_loss']) + 1),
+                 'Train Loss': training_history['train_loss'],
+                 'Test Loss': training_history['test_loss'],
+                 'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],  # Convert CER to percentage for display
+                 'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]  # Convert to percentage
+             })
+
+             # Plot 1: Training and Test Loss
+             st.sidebar.markdown("**Loss over Epochs**")
+             st.sidebar.line_chart(
+                 history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
+             )
+             st.sidebar.caption("Lower loss indicates better model performance.")
+
+             # Plot 2: Character Error Rate (CER)
+             st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
+             st.sidebar.line_chart(
+                 history_df.set_index('Epoch')[['Test CER (%)']]
+             )
+             st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")
+
+             # Plot 3: Exact Match Accuracy
+             st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
+             st.sidebar.line_chart(
+                 history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
+             )
+             st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")
+
+             # Update the global model instance to the newly trained one for immediate inference
+             ocr_model = trained_ocr_model
+             ocr_model.eval()
+
+         except Exception as e:
+             status_text.error(f"An error occurred during training: {e}")
+             st.sidebar.text(traceback.format_exc())  # Show full traceback for debugging
+
+ # --- Main Content: Name Prediction ---
+ st.header("Predict Your Handwritten Name")
+ st.markdown("Upload a clear image of a single handwritten name or word.")
+
+ uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])
+
+ if uploaded_file is not None:
+     try:
+         # Open the uploaded image
+         image_pil = Image.open(uploaded_file).convert('L')  # Ensure grayscale
+         st.image(image_pil, caption="Uploaded Image", use_column_width=True)
+         st.write("---")
+         st.write("Processing and Recognizing...")
+
+         # Preprocess the image for the model using the utils_ocr helper
+         processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)
+
+         # Make prediction
+         ocr_model.eval()  # Ensure model is in evaluation mode
+         with torch.no_grad():  # Disable gradient calculation for inference
+             output = ocr_model(processed_image_tensor)  # (sequence_length, batch_size, num_classes)
+
+         # ctc_greedy_decode expects (sequence_length, batch_size, num_classes).
+         # It returns a list of strings, so take the first element for single-image inference.
+         predicted_texts = ctc_greedy_decode(output, char_indexer)
+         predicted_text = predicted_texts[0]  # Get the first (and only) prediction
+
+         st.success(f"Recognized Text: **{predicted_text}**")
+
+     except Exception as e:
+         st.error(f"Error processing image or recognizing text: {e}")
+         st.info("💡 **Tips for best results:**\n"
+                 "- Ensure the handwritten text is clear and on a clean background.\n"
+                 "- Only include one name/word per image.\n"
+                 "- The model is trained on specific characters. Unusual symbols might not be recognized.")
+         st.text(traceback.format_exc())
+
+ st.markdown("""
+ ---
+ *Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
+ """)
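As a quick way to exercise the same inference path outside Streamlit, here is a minimal sketch. It is illustrative only (not a file in this commit): it reuses the repo's own `config`, `data_handler_ocr`, `model_ocr`, and `utils_ocr` modules, assumes a trained checkpoint already exists at `MODEL_SAVE_PATH`, and the script name and sample image path are hypothetical.

```python
# inference_check.py -- Streamlit-free sketch of the app's prediction path (paths are placeholders)
import torch
from PIL import Image

from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, MODEL_SAVE_PATH
from data_handler_ocr import CharIndexer
from model_ocr import CRNN, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
char_indexer = CharIndexer(CHARS, BLANK_TOKEN)  # mirrors how app.py builds the indexer

model = CRNN(num_classes=char_indexer.num_classes, cnn_output_channels=512,
             rnn_hidden_size=256, rnn_num_layers=2)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location="cpu"))
model.to(device).eval()

image = Image.open("sample_name.png").convert("L")  # hypothetical input image
# preprocess_user_image_for_ocr is expected to return a batched tensor (1, 1, IMG_HEIGHT, W), per app.py
batch = preprocess_user_image_for_ocr(image, IMG_HEIGHT).to(device)

with torch.no_grad():
    logits = model(batch)  # (sequence_length, 1, num_classes)
print(ctc_greedy_decode(logits, char_indexer)[0])
```

Run from the repo root, this should print the decoded string for the sample image, using the same greedy CTC decoding as the app.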
config.py ADDED
@@ -0,0 +1,112 @@
+ # config.py
+
+ import os
+
+ # --- Paths ---
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ DATA_DIR = os.path.join(BASE_DIR, 'data')
+ MODELS_DIR = os.path.join(BASE_DIR, 'models')
+
+ TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
+ TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
+
+ TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
+ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
+
+ MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
+
+ # --- Character Set and OCR Configuration ---
+ # This character set MUST cover all characters present in your dataset.
+ # Add any special characters if needed.
+ # The order here is crucial as it defines the indices for your characters.
+ CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
+
+ # Define the character for the blank token. It MUST NOT be in CHARS.
+ BLANK_TOKEN_SYMBOL = 'Þ'
+
+ # Construct the full vocabulary string. It's conventional to put the blank token last.
+ # This VOCABULARY string is what you pass to CharIndexer.
+ VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
+
+ # NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
+ NUM_CLASSES = len(VOCABULARY)
+
+ # BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
+ # Since we appended it last, its index will be len(CHARS).
+ BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
+
+ # --- Sanity Checks (Highly Recommended) ---
+ if BLANK_TOKEN == -1:
+     raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
+ if BLANK_TOKEN >= NUM_CLASSES:
+     raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
+
+ print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
+ print(f"Vocabulary Length: {len(VOCABULARY)}")
+ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
+
+
+ # --- Image Preprocessing Parameters ---
+ IMG_HEIGHT = 32
+
+ # --- Training Parameters ---
+ BATCH_SIZE = 64
+ LEARNING_RATE = 0.001
+ =======
57
+ # config.py
58
+
59
+ import os
60
+
61
+ # --- Paths ---
62
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
63
+ DATA_DIR = os.path.join(BASE_DIR, 'data')
64
+ MODELS_DIR = os.path.join(BASE_DIR, 'models')
65
+
66
+ TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
67
+ TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')
68
+
69
+ TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
70
+ TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')
71
+
72
+ MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')
73
+
74
+ # --- Character Set and OCR Configuration ---
75
+ # This character set MUST cover all characters present in your dataset.
76
+ # Add any special characters if needed.
77
+ # The order here is crucial as it defines the indices for your characters.
78
+ CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
79
+
80
+ # Define the character for the blank token. It MUST NOT be in CHARS.
81
+ BLANK_TOKEN_SYMBOL = 'Þ'
82
+
83
+ # Construct the full vocabulary string. It's conventional to put the blank token last.
84
+ # This VOCABULARY string is what you pass to CharIndexer.
85
+ VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL
86
+
87
+ # NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
88
+ NUM_CLASSES = len(VOCABULARY)
89
+
90
+ # BLANK_TOKEN is the actual index of the blank symbol within the VOCABULARY.
91
+ # Since we appended it last, its index will be len(CHARS).
92
+ BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)
93
+
94
+ # --- Sanity Checks (Highly Recommended) ---
95
+ if BLANK_TOKEN == -1:
96
+ raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
97
+ if BLANK_TOKEN >= NUM_CLASSES:
98
+ raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")
99
+
100
+ print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
101
+ print(f"Vocabulary Length: {len(VOCABULARY)}")
102
+ print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")
103
+
104
+
105
+ # --- Image Preprocessing Parameters ---
106
+ IMG_HEIGHT = 32
107
+
108
+ # --- Training Parameters ---
109
+ BATCH_SIZE = 64
110
+ LEARNING_RATE = 0.001
111
+ >>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
112
+ NUM_EPOCHS = 3
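A small sanity-check sketch for the invariants this config encodes (hypothetical file name, illustrative only, not a file in this commit). It assumes only the definitions above plus the `CharIndexer` class from `data_handler_ocr.py`. Note that `app.py` passes the integer `BLANK_TOKEN` to `CharIndexer` rather than the symbol; the indexer only stores that value for the blank slot, so the derived sizes come out the same either way.

```python
# config_check.py -- sketch verifying the vocabulary/blank-token invariants the model depends on
from config import CHARS, NUM_CLASSES, BLANK_TOKEN, BLANK_TOKEN_SYMBOL
from data_handler_ocr import CharIndexer

assert BLANK_TOKEN == len(CHARS)        # blank symbol was appended last, so its index is len(CHARS)
assert NUM_CLASSES == len(CHARS) + 1    # every character class plus the CTC blank
assert BLANK_TOKEN_SYMBOL not in CHARS  # blank symbol must not collide with a real character

# CharIndexer derives the same numbers independently; the CRNN's final linear layer
# must be sized to this value for CTC loss and decoding to line up.
indexer = CharIndexer(CHARS, BLANK_TOKEN_SYMBOL)
assert indexer.num_classes == NUM_CLASSES
assert indexer.blank_token_idx == BLANK_TOKEN
print("config invariants hold:", NUM_CLASSES, "classes, blank index", BLANK_TOKEN)
```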
data_handler_ocr.py ADDED
@@ -0,0 +1,270 @@
+ # data_handler_ocr.py
+
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ import os
+ from PIL import Image
+ import numpy as np
+ import torch.nn.functional as F
+
+ # Import utility functions and config
+ from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
+ from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
+
+ class CharIndexer:
+     """Manages character-to-index and index-to-character mappings."""
+     def __init__(self, chars: str, blank_token: str):
+         self.char_to_idx = {char: i for i, char in enumerate(chars)}
+         self.idx_to_char = {i: char for i, char in enumerate(chars)}
+         self.blank_token_idx = len(chars)  # Index for the blank token
+         self.idx_to_char[self.blank_token_idx] = blank_token  # Add blank token to idx_to_char
+         self.num_classes = len(chars) + 1  # Total classes including blank
+
+     def encode(self, text: str) -> list[int]:
+         """Converts a text string to a list of integer indices."""
+         return [self.char_to_idx[char] for char in text]
+
+     def decode(self, indices: list[int]) -> str:
+         """Converts a list of integer indices back to a text string."""
+         # CTC decoding often produces repeated characters and blank tokens.
+         # This simple decoder removes blanks and collapses consecutive duplicates.
+         decoded_text = []
+         for i, idx in enumerate(indices):
+             if idx == self.blank_token_idx:
+                 continue
+             # Remove consecutive duplicates
+             if i > 0 and indices[i-1] == idx:
+                 continue
+             decoded_text.append(self.idx_to_char[idx])
+         return "".join(decoded_text)
+
+ class OCRDataset(Dataset):
+     """
+     Custom PyTorch Dataset for the Handwritten Name Recognition task.
+     Loads images and their corresponding text labels.
+     """
+     def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
+         """
+         Initializes the OCR Dataset.
+         Args:
+             dataframe (pd.DataFrame): A DataFrame containing 'FILENAME' and 'IDENTITY' columns.
+             char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
+             image_dir (str): Directory containing the image files.
+             transform (callable, optional): Optional transform to be applied on an image.
+         """
+         self.data = dataframe
+         self.char_indexer = char_indexer
+         self.image_dir = image_dir
+         self.transform = transform
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         raw_filename_entry = self.data.iloc[idx]['FILENAME']
+         ground_truth_text = self.data.iloc[idx]['IDENTITY']
+
+         filename = raw_filename_entry.split(',')[0].strip()  # .strip() removes any whitespace
+         # Construct the full image path
+         img_path = os.path.join(self.image_dir, filename)
+         # Ensure ground_truth_text is a string
+         ground_truth_text = str(ground_truth_text)
+
+         # Load and transform image
+         try:
+             image = Image.open(img_path).convert('L')  # Convert to grayscale
+         except FileNotFoundError:
+             print(f"Error: Image file not found at {img_path}. Skipping this item.")
+             raise  # Re-raise to let the main traceback be seen.
+
+         if self.transform:
+             image = self.transform(image)
+
+         image_width = image.size(2)  # Assuming image is a tensor (C, H, W) after transform
+
+         text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
+         text_length = len(text_encoded)
+
+         return image, text_encoded, image_width, text_length
+
+ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Custom collate function for the DataLoader to handle variable-width images
+     and variable-length text sequences for CTC loss.
+     """
+     images, texts, image_widths, text_lengths = zip(*batch)
+
+     # Pad images to the maximum width in the current batch
+     max_batch_width = max(image_widths)
+     padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
+     images_batch = torch.stack(padded_images, 0)  # Stack to (N, C, H, max_W)
+
+     # Concatenate all text sequences and get their lengths
+     texts_batch = torch.cat(texts, 0)
+     text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
+     image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)  # Actual widths
+
+     return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
+
+
+ def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
+     """
+     Loads training and testing dataframes.
+     Assumes CSVs have 'FILENAME' and 'IDENTITY' columns.
+     """
+     train_df = pd.read_csv(train_csv_path)
+     test_df = pd.read_csv(test_csv_path)
+     return train_df, test_df
+
+ def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
+                            char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
+     """
+     Creates PyTorch DataLoader objects for OCR training and testing datasets,
+     using specific image directories for train/test.
+     """
+     train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR)
+     test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
+                               num_workers=0, collate_fn=ocr_collate_fn)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
+                              num_workers=0, collate_fn=ocr_collate_fn)
+     return train_loader, test_loader
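To make the CTC-oriented behaviour above concrete, here is a toy-sized sketch (illustrative only, not a file in this commit) of how `CharIndexer` round-trips text and how `ocr_collate_fn` pads a batch. The two-letter alphabet, blank symbol, and tensor sizes are made up for illustration.

```python
# data_handler_check.py -- sketch of CharIndexer round-tripping and batch collation (toy values)
import torch
from data_handler_ocr import CharIndexer, ocr_collate_fn

indexer = CharIndexer("AB", blank_token="~")  # toy alphabet; the real code uses CHARS from config
assert indexer.encode("ABBA") == [0, 1, 1, 0]
# Greedy CTC output: blanks (index 2 here) are dropped and consecutive repeats collapsed,
# so [A, A, blank, B, B, blank, A] decodes back to "ABA".
assert indexer.decode([0, 0, 2, 1, 1, 2, 0]) == "ABA"

# Collation pads every image in the batch to the widest one and concatenates the label tensors.
item1 = (torch.zeros(1, 32, 40), torch.tensor([0, 1]), 40, 2)
item2 = (torch.zeros(1, 32, 55), torch.tensor([1, 0, 1]), 55, 3)
images, texts, widths, lengths = ocr_collate_fn([item1, item2])
assert images.shape == (2, 1, 32, 55)            # both images padded to width 55
assert texts.tolist() == [0, 1, 1, 0, 1]         # flat 1-D targets, as CTCLoss accepts
assert widths.tolist() == [40, 55] and lengths.tolist() == [2, 3]
```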
model_ocr.py ADDED
@@ -0,0 +1,584 @@
1
+ <<<<<<< HEAD
2
+ # model_ocr.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader # Keep DataLoader for type hinting
9
+ from tqdm import tqdm
10
+ from sklearn.metrics import accuracy_score
11
+ import editdistance
12
+
13
+ # Import config and char_indexer
14
+ # Ensure these imports align with your current config.py
15
+ from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
16
+ from data_handler_ocr import CharIndexer
17
+ # You might also need to import binarize_image, resize_image_for_ocr, normalize_image_for_model
18
+ # if they are used directly in model_ocr.py for internal preprocessing (e.g., in evaluate_model if not using DataLoader)
19
+ # For now, assuming they are handled by DataLoader transforms.
20
+ from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model # Add this for completeness if needed elsewhere
21
+
22
+
23
+ class CNN_Backbone(nn.Module):
24
+ """
25
+ CNN feature extractor for OCR. Designed to produce features suitable for RNN.
26
+ Output feature map should have height 1 after the final pooling/reduction.
27
+ """
28
+ def __init__(self, input_channels=1, output_channels=512):
29
+ super(CNN_Backbone, self).__init__()
30
+ self.cnn = nn.Sequential(
31
+ # First block
32
+ nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
33
+ nn.ReLU(True),
34
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
35
+
36
+ # Second block
37
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
38
+ nn.ReLU(True),
39
+ nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
40
+
41
+ # Third block (with two conv layers)
42
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
43
+ nn.ReLU(True),
44
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
45
+ nn.ReLU(True),
46
+ # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
47
+ # The original comment (W/4 + 1) is due to padding=1 and stride=1 on width, which is fine.
48
+ nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
49
+
50
+ # Fourth block
51
+ nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
52
+ nn.ReLU(True),
53
+ # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
54
+ # while preserving the width. This is crucial for RNN input.
55
+ nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
56
+ )
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
60
+
61
+ # Pass through the CNN layers
62
+ conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
63
+
64
+ # Squeeze the height dimension (which is 1)
65
+ # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
66
+ conv_features = conv_features.squeeze(2)
67
+
68
+ # Permute for RNN input: (sequence_length, batch_size, input_size)
69
+ # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
70
+ conv_features = conv_features.permute(2, 0, 1)
71
+
72
+ # Return the CNN features, ready for the RNN layer in CRNN
73
+ return conv_features
74
+
75
+ class BidirectionalLSTM(nn.Module):
76
+ """Bidirectional LSTM layer for sequence modeling."""
77
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
78
+ super(BidirectionalLSTM, self).__init__()
79
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
80
+ bidirectional=True, dropout=dropout, batch_first=False)
81
+ # batch_first=False expects input as (sequence_length, batch_size, input_size)
82
+
83
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
84
+ output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
85
+ return output
86
+
87
+ class CRNN(nn.Module):
88
+ """
89
+ Convolutional Recurrent Neural Network for OCR.
90
+ Combines CNN for feature extraction, LSTMs for sequence modeling,
91
+ and a final linear layer for character prediction.
92
+ """
93
+ def __init__(self, num_classes: int, cnn_output_channels: int = 512,
94
+ rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
95
+ super(CRNN, self).__init__()
96
+ self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
97
+ # Input to LSTM is the number of channels from the CNN output
98
+ self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
99
+ # Output of bidirectional LSTM is hidden_size * 2
100
+ self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
101
+
102
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
103
+ # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
104
+
105
+ # 1. Pass through the CNN to extract features
106
+ conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
107
+
108
+ # 2. Pass CNN features through the RNN (LSTM)
109
+ rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
110
+
111
+ # 3. Pass RNN features through the final fully connected layer
112
+ # Apply the linear layer to each time step independently
113
+ # output will be (W_prime, N, num_classes)
114
+ output = self.fc(rnn_features)
115
+
116
+ return output
117
+
118
+
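The same kind of check on the full CRNN shows one score vector per time step, which is exactly what CTC loss and greedy decoding expect. A minimal sketch, assuming NUM_CLASSES is imported from config.py at the top of this module; batch size and width are illustrative:

    model = CRNN(num_classes=NUM_CLASSES)
    dummy = torch.randn(2, 1, 32, 160)   # (N, C, H, W)
    logits = model(dummy)                # (W_prime, N, NUM_CLASSES), raw scores; no softmax applied here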
119
+ # --- Decoding Function ---
120
+ def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
121
+ """
122
+ Performs greedy decoding on the CTC output.
123
+ output: (sequence_length, batch_size, num_classes) - raw logits
124
+ """
125
+ # Apply log_softmax to get probabilities for argmax
126
+ log_probs = F.log_softmax(output, dim=2)
127
+
128
+ # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
129
+ # This gives us the index of the most probable character at each time step for each sample in the batch.
130
+ predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
131
+
132
+ decoded_texts = []
133
+ for seq in predicted_indices:
134
+ # Use char_indexer's decode method, which handles blank removal and duplicate collapse
135
+ decoded_texts.append(char_indexer.decode(seq.tolist())) # Convert numpy array to list
136
+ return decoded_texts
137
+
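The heavy lifting in ctc_greedy_decode is delegated to char_indexer.decode, which applies the standard CTC collapse rule: merge consecutive repeats, then drop blanks. A self-contained sketch of that rule, assuming a blank index of 0, for illustration only:

    def collapse_ctc(indices: list[int], blank: int = 0) -> list[int]:
        # Merge consecutive duplicates, then remove the blank symbol.
        out, prev = [], None
        for idx in indices:
            if idx != prev and idx != blank:
                out.append(idx)
            prev = idx
        return out

    collapse_ctc([0, 7, 7, 0, 7, 2, 2])   # [7, 7, 2]: a blank between repeats keeps them distinct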
138
+ # --- Evaluation Function ---
139
+ def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
140
+ model.eval() # Set model to evaluation mode
141
+ # CTCLoss needs the blank token index, which is available from char_indexer
142
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
143
+ total_loss = 0
144
+ all_predictions = []
145
+ all_ground_truths = []
146
+
147
+ with torch.no_grad(): # Disable gradient calculation for evaluation
148
+ for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
149
+ inputs = inputs.to(device)
150
+ targets_padded = targets_padded.to(device)
151
+ target_lengths = target_lengths.to(device)
152
+
153
+ output = model(inputs) # (seq_len, batch_size, num_classes)
154
+
155
+ # Calculate input_lengths for CTCLoss. This is the sequence length produced by the CNN/RNN.
156
+ # It's the `output.shape[0]` (sequence_length) for each item in the batch.
157
+ outputs_seq_len_for_ctc = torch.full(
158
+ size=(output.shape[1],), # batch_size
159
+ fill_value=output.shape[0], # actual sequence length (T) from model output
160
+ dtype=torch.long,
161
+ device=device
162
+ )
163
+
164
+ # CTC Loss calculation requires log_softmax on the output logits
165
+ log_probs_for_loss = F.log_softmax(output, dim=2) # (T, N, C)
166
+
167
+ loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths)
168
+ total_loss += loss.item() * inputs.size(0) # Multiply by batch size for correct average
169
+
170
+ # Decode predictions for metrics
171
+ decoded_preds = ctc_greedy_decode(output, char_indexer)
172
+
173
+ # Reconstruct ground truths from encoded tensors
174
+ ground_truths = []
175
+ # Loop through each sample in the batch
176
+ for i in range(targets_padded.size(0)):
177
+ # Extract the actual target sequence for the i-th sample using its length
178
+ # Convert to list before passing to char_indexer.decode
179
+ ground_truths.append(char_indexer.decode(targets_padded[i, :target_lengths[i]].tolist()))
180
+
181
+ all_predictions.extend(decoded_preds)
182
+ all_ground_truths.extend(ground_truths)
183
+
184
+ avg_loss = total_loss / len(dataloader.dataset)
185
+
186
+ # Calculate Character Error Rate (CER)
187
+ cer_sum = 0
188
+ total_chars = 0
189
+ for pred, gt in zip(all_predictions, all_ground_truths):
190
+ cer_sum += editdistance.eval(pred, gt)
191
+ total_chars += len(gt)
192
+ char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
193
+
194
+ # Calculate Exact Match Accuracy (Word-level Accuracy)
195
+ exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
196
+
197
+ return avg_loss, char_error_rate, exact_match_accuracy
198
+
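Note that the CER computed below is corpus-level: total edit distance divided by total ground-truth characters, not a mean of per-sample rates. A small worked example with editdistance (illustrative strings):

    preds  = ["JONH", "SMITH"]
    truths = ["JOHN", "SMITH"]
    errors = sum(editdistance.eval(p, t) for p, t in zip(preds, truths))   # 2 + 0 = 2
    chars  = sum(len(t) for t in truths)                                   # 4 + 5 = 9
    cer = errors / chars                                                   # ~0.222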
199
+ # --- Training Function ---
200
+ def train_ocr_model(model: nn.Module, train_loader: DataLoader,
201
+ test_loader: DataLoader, char_indexer: CharIndexer,
202
+ epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
203
+ """
204
+ Trains the OCR model using CTC loss.
205
+ """
206
+ # CTCLoss needs the blank token index
207
+ criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
208
+ optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
209
+ # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
210
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)
211
+
212
+ model.to(device) # Ensure model is on the correct device
213
+ model.train() # Set model to training mode
214
+
215
+ training_history = {
216
+ 'train_loss': [],
217
+ 'test_loss': [],
218
+ 'test_cer': [],
219
+ 'test_exact_match_accuracy': []
220
+ }
221
+
222
+ for epoch in range(epochs):
223
+ running_loss = 0.0
224
+ pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
225
+ for images, texts_encoded, _, text_lengths in pbar_train:
226
+ images = images.to(device)
227
+ # Ensure target tensors are on the correct device for CTCLoss calculation
228
+ texts_encoded = texts_encoded.to(device)
229
+ text_lengths = text_lengths.to(device)
230
+
231
+ optimizer.zero_grad() # Clear gradients from previous step
232
+ outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
233
+
234
+ # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
235
+ # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
236
+ outputs_seq_len_for_ctc = torch.full(
237
+ size=(outputs.shape[1],), # batch_size
238
+ fill_value=outputs.shape[0], # actual sequence length (T) from model output
239
+ dtype=torch.long,
240
+ device=device
241
+ )
242
+
243
+ # CTC Loss calculation requires log_softmax on the output logits
244
+ log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
245
+
246
+ # Use outputs_seq_len_for_ctc for the input_lengths argument
247
+ loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
248
+ loss.backward() # Backpropagate
249
+ optimizer.step() # Update model weights
250
+
251
+ running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
252
+ pbar_train.set_postfix(loss=loss.item())
253
+
254
+ epoch_train_loss = running_loss / len(train_loader.dataset)
255
+ training_history['train_loss'].append(epoch_train_loss)
256
+
257
+ # Evaluate on test set using the dedicated function
258
+ # Ensure model is in eval mode before calling evaluate_model
259
+ model.eval()
260
+ test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
261
+ training_history['test_loss'].append(test_loss)
262
+ training_history['test_cer'].append(test_cer)
263
+ training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
264
+
265
+ # Adjust learning rate based on test loss (this is where scheduler.step() is called)
266
+ scheduler.step(test_loss)
267
+
268
+ print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
269
+ f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
270
+
271
+ if progress_callback:
272
+ # Update progress bar with current epoch and key metrics
273
+ progress_val = (epoch + 1) / epochs
274
+ progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
275
+
276
+ model.train() # Set model back to training mode after evaluation
277
+
278
+ return model, training_history
279
+
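A hedged sketch of how this function might be driven end to end; the loaders and char_indexer are assumed to come from data_handler_ocr and are not constructed here, and the checkpoint path is illustrative:

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CRNN(num_classes=NUM_CLASSES)
    model, history = train_ocr_model(model, train_loader, test_loader, char_indexer,
                                     epochs=20, device=device)
    save_ocr_model(model, "crnn_ocr.pth")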
280
+ def save_ocr_model(model: nn.Module, path: str):
281
+ """Saves the state dictionary of the trained OCR model."""
282
+ torch.save(model.state_dict(), path)
283
+ print(f"OCR model saved to {path}")
284
+
285
+ def load_ocr_model(model: nn.Module, path: str):
286
+ """
287
+ Loads a trained OCR model's state dictionary.
288
+ Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
289
+ """
290
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
291
+ model.eval() # Set to evaluation mode
584
+ print(f"OCR model loaded from {path}")
requirements.txt ADDED
@@ -0,0 +1,32 @@
2
+ #requirements.txt
3
+ # This file lists all the Python libraries required to run the Handwritten Name OCR application.
4
+ # Install using: pip install -r requirements.txt
5
+
6
+ streamlit>=1.33.0
7
+ pandas>=2.2.2
8
+ numpy>=1.26.4
9
+ Pillow>=10.3.0
10
+ opencv-python-headless>=4.9.0.80
11
+ torch>=2.2.2
12
+ torchvision>=0.17.2 # PyTorch companion library for vision tasks (datasets, transforms)
13
+ matplotlib>=3.8.4 # For plotting training history
14
+ tqdm>=4.66.2 # For displaying progress bars during training
15
+ editdistance>=0.8.1 # For calculating character error rate (CER)
32
+ scikit-learn>=1.4.2
utils_ocr.py ADDED
@@ -0,0 +1,184 @@
2
+ #utils_ocr.py
3
+
4
+ import cv2
5
+ import torch.nn.functional as F  # needed for F.pad in pad_image_tensor
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+ # --- Image Preprocessing for OCR ---
12
+
13
+ def load_image_as_grayscale(image_path: str) -> Image.Image:
14
+ """Loads an image from path and converts it to grayscale PIL Image."""
15
+ # Use PIL for robust image loading and conversion to grayscale 'L' mode
16
+ img = Image.open(image_path).convert('L')
17
+ return img
18
+
19
+ def binarize_image(image_pil: Image.Image) -> Image.Image:
20
+ """Binarizes a grayscale PIL Image (black and white)."""
21
+ # Convert PIL to OpenCV format (numpy array)
22
+ img_np = np.array(image_pil)
23
+ # Apply Otsu's thresholding for adaptive binarization
24
+ _, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
25
+ # THRESH_BINARY_INV has already turned dark strokes white on a black background;
26
+ # the inversion below flips that back to dark text on a white background.
27
+ # Match this polarity to whatever the training data uses.
28
+ img_bin = 255 - img_bin
29
+ return Image.fromarray(img_bin)
30
+
31
+ def resize_image_for_ocr(image_pil: Image.Image, target_height: int) -> Image.Image:
32
+ """
33
+ Resizes a PIL Image to a target height while maintaining aspect ratio.
34
+ Width padding, when needed, is applied separately (see pad_image_tensor).
35
+ """
36
+ original_width, original_height = image_pil.size
37
+ # Calculate new width based on target height and original aspect ratio
38
+ new_width = int(original_width * (target_height / original_height))
39
+ resized_img = image_pil.resize((new_width, target_height), Image.LANCZOS)
40
+ return resized_img
41
+
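The new width follows directly from the aspect ratio, e.g. a 300x75 crop at target_height=32 becomes 128x32, since int(300 * 32 / 75) = 128. A quick check:

    img = Image.new("L", (300, 75))
    resize_image_for_ocr(img, 32).size   # (128, 32)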
42
+ def normalize_image_for_model(image_pil: Image.Image) -> torch.Tensor:
43
+ """
44
+ Converts a PIL Image to a PyTorch Tensor and normalizes pixel values.
45
+ """
46
+ # Convert to tensor (scales to 0-1 automatically)
47
+ tensor_transform = transforms.ToTensor()
48
+ img_tensor = tensor_transform(image_pil)
49
+ # For grayscale images, mean and std are single values.
50
+ # Adjust normalization values if your training data uses different ones.
51
+ img_tensor = transforms.Normalize((0.5,), (0.5,))(img_tensor) # Normalize to [-1, 1]
52
+ return img_tensor
53
+
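ToTensor maps 8-bit pixels to [0, 1], and Normalize((0.5,), (0.5,)) then applies (x - 0.5) / 0.5, so 0 maps to -1, 255 maps to 1, and mid-grey lands near 0. A quick check on a three-pixel grayscale image:

    sample = Image.fromarray(np.array([[255, 128, 0]], dtype=np.uint8))
    normalize_image_for_model(sample)   # tensor([[[ 1.0000,  0.0039, -1.0000]]])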
54
+ def preprocess_user_image_for_ocr(uploaded_image_pil: Image.Image, target_height: int) -> torch.Tensor:
55
+ """
56
+ Combines all preprocessing steps for a single user-uploaded image
57
+ to prepare it for the OCR model.
58
+ """
59
+ # Ensure it's grayscale
60
+ img_gray = uploaded_image_pil.convert('L')
61
+
62
+ # Binarize
63
+ img_bin = binarize_image(img_gray)
64
+
65
+ # Resize (maintain aspect ratio)
66
+ img_resized = resize_image_for_ocr(img_bin, target_height)
67
+
68
+ # Normalize and convert to tensor
69
+ img_tensor = normalize_image_for_model(img_resized)
70
+
71
+ # Add batch dimension: (C, H, W) -> (1, C, H, W)
72
+ img_tensor = img_tensor.unsqueeze(0)
73
+
74
+ return img_tensor
75
+
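In the application code, single-image inference is then a three-step pipeline: preprocess, forward pass, greedy decode. A hedged sketch; IMG_HEIGHT and NUM_CLASSES come from config.py, the model pieces from model_ocr.py, the file paths are illustrative, and the char_indexer (from data_handler_ocr) is assumed to be constructed elsewhere:

    from config import IMG_HEIGHT, NUM_CLASSES
    from model_ocr import CRNN, ctc_greedy_decode, load_ocr_model

    image = Image.open("name_crop.png")                        # illustrative path
    batch = preprocess_user_image_for_ocr(image, IMG_HEIGHT)   # (1, 1, IMG_HEIGHT, W)
    model = CRNN(num_classes=NUM_CLASSES)
    load_ocr_model(model, "crnn_ocr.pth")                      # illustrative checkpoint path
    with torch.no_grad():
        logits = model(batch)                                  # (W_prime, 1, NUM_CLASSES)
    prediction = ctc_greedy_decode(logits, char_indexer)[0]    # char_indexer assumed available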
76
+ def pad_image_tensor(image_tensor: torch.Tensor, max_width: int) -> torch.Tensor:
77
+ """
78
+ Pads a single image tensor to a max_width with zeros.
79
+ Input tensor shape: (C, H, W)
80
+ Output tensor shape: (C, H, max_width)
81
+ """
82
+ C, H, W = image_tensor.shape
83
+ if W > max_width:
84
+ # If image is wider than max_width, you might want to crop or resize it.
85
+ # For this example, we simply warn and crop to max_width.
86
+ # A more robust solution might split text lines or use a different resizing strategy.
87
+ print(f"Warning: Image width {W} exceeds max_width {max_width}. Cropping.")
88
+ return image_tensor[:, :, :max_width] # Simple cropping
89
+ padding = max_width - W
90
+ # Pad only the last (width) dimension: (pad_left, pad_right) = (0, padding)
91
+ padded_tensor = F.pad(image_tensor, (0, padding), 'constant', 0)
184
+ return padded_tensor
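F.pad takes padding sizes for the last dimension first, so the pair (0, padding) leaves the left edge untouched and appends zero-valued columns on the right. A quick check of the semantics:

    t = torch.ones(1, 32, 100)                     # (C, H, W)
    F.pad(t, (0, 4), 'constant', 0).shape          # torch.Size([1, 32, 104])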