Spaces:
Build error
Initial update of files
- LICENSE +201 -0
- app.py +426 -0
- config.py +112 -0
- data_handler_ocr.py +270 -0
- model_ocr.py +584 -0
- requirements.txt +32 -0
- utils_ocr.py +184 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
app.py
ADDED
@@ -0,0 +1,426 @@
# app.py

import streamlit as st
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F  # F for log_softmax in inference
import torchvision.transforms as transforms
import os
import traceback  # For detailed error logging

# Import custom modules
from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_CSV_PATH, TEST_CSV_PATH, \
    TRAIN_IMAGES_DIR, TEST_IMAGES_DIR, MODEL_SAVE_PATH, NUM_CLASSES, NUM_EPOCHS, BATCH_SIZE
from data_handler_ocr import CharIndexer, OCRDataset
from model_ocr import CRNN, train_ocr_model, save_ocr_model, load_ocr_model, ctc_greedy_decode
from utils_ocr import preprocess_user_image_for_ocr

# --- Streamlit App Setup ---
st.set_page_config(page_title="Handwritten Name Recognizer", layout="centered")

st.title("📝 Handwritten Name Recognition (OCR)")
st.markdown("""
This application uses a Convolutional Recurrent Neural Network (CRNN) to perform
Optical Character Recognition (OCR) on handwritten names. You can upload an image
of a handwritten name for prediction or train a new model using the provided dataset.

**Note:** Training a robust OCR model can be time-consuming.
""")

# --- Initialize CharIndexer ---
# The CHARS variable should contain all possible characters the model can recognize.
# Make sure it is comprehensive for your dataset.
char_indexer = CharIndexer(CHARS, BLANK_TOKEN)
# For robustness, always use char_indexer.num_classes.
# If NUM_CLASSES from config is used to initialize CRNN, ensure it matches char_indexer.num_classes.

# --- Model Loading / Initialization ---
@st.cache_resource  # Cache the model to prevent reloading on every rerun
def get_and_load_ocr_model_cached(num_classes, model_path):
    """
    Initializes the OCR model and attempts to load a pre-trained checkpoint.
    If no pre-trained model exists, a new model instance is returned.
    """
    model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)

    if os.path.exists(model_path):
        st.sidebar.info("Loading pre-trained OCR model...")
        try:
            # Load model to CPU first, then move to device
            model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            st.sidebar.success("OCR model loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Error loading model: {e}. A new model will be initialized.")
            # If loading fails, re-initialize an untrained model
            model_instance = CRNN(num_classes=num_classes, cnn_output_channels=512, rnn_hidden_size=256, rnn_num_layers=2)
    else:
        st.sidebar.warning("No pre-trained OCR model found. Please train a model using the sidebar option.")

    return model_instance

# Get the model instance
ocr_model = get_and_load_ocr_model_cached(char_indexer.num_classes, MODEL_SAVE_PATH)
# Determine the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model.to(device)
ocr_model.eval()  # Set model to evaluation mode for inference by default

# --- Sidebar for Model Training ---
st.sidebar.header("Model Training (Optional)")
st.sidebar.markdown("If you want to train a new model or no model is found:")

# Initialize Streamlit widgets outside the button block
training_progress_bar = st.sidebar.empty()  # Placeholder for progress bar
status_text = st.sidebar.empty()  # Placeholder for status messages

if st.sidebar.button("📊 Train New OCR Model"):
    # Clear previous messages/widgets if the button is clicked again
    training_progress_bar.empty()
    status_text.empty()

    # Check for existence of CSVs and image directories
    if not os.path.exists(TRAIN_CSV_PATH) or not os.path.exists(TEST_CSV_PATH) or \
       not os.path.isdir(TRAIN_IMAGES_DIR) or not os.path.isdir(TEST_IMAGES_DIR):
        status_text.error(f"""Dataset files or image directories not found.
        Please ensure '{TRAIN_CSV_PATH}', '{TEST_CSV_PATH}', and directories '{TRAIN_IMAGES_DIR}'
        and '{TEST_IMAGES_DIR}' exist. Refer to your project structure.""")
    else:
        status_text.write(f"Training a new CRNN model for {NUM_EPOCHS} epochs. This will take significant time...")

        training_progress_bar_instance = training_progress_bar.progress(0.0, text="Training in progress. Please wait.")

        try:
            train_df = pd.read_csv(TRAIN_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)
            test_df = pd.read_csv(TEST_CSV_PATH, delimiter=';', names=['FILENAME', 'IDENTITY'], header=None)

            # Define standard image transforms for consistency
            train_transform = transforms.Compose([
                transforms.Resize((IMG_HEIGHT, 100)),  # Resize to fixed height; width fixed at 100 (adjust as needed for variable width)
                transforms.ToTensor(),  # Converts PIL Image (H, W) to tensor (C, H, W), normalized to [0, 1]
            ])
            test_transform = transforms.Compose([
                transforms.Resize((IMG_HEIGHT, 100)),  # Same transformation as train
                transforms.ToTensor(),
            ])

            # Create dataset instances
            train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR, transform=train_transform)
            test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR, transform=test_transform)

            # Create DataLoader instances
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)  # num_workers=0 for Windows
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

            # Train the model, passing the progress callback
            trained_ocr_model, training_history = train_ocr_model(
                ocr_model,  # Pass the initialized model instance
                train_loader,
                test_loader,
                char_indexer,  # Pass char_indexer for CER calculation
                epochs=NUM_EPOCHS,
                device=device,
                progress_callback=training_progress_bar_instance.progress  # Pass the instance's progress method
            )

            # Ensure the directory for saving the model exists
            os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
            save_ocr_model(trained_ocr_model, MODEL_SAVE_PATH)
            status_text.success(f"Model training complete and saved to `{MODEL_SAVE_PATH}`!")

            # Display training history charts
            st.sidebar.subheader("Training History Plots")

            history_df = pd.DataFrame({
                'Epoch': range(1, len(training_history['train_loss']) + 1),
                'Train Loss': training_history['train_loss'],
                'Test Loss': training_history['test_loss'],
                'Test CER (%)': [cer * 100 for cer in training_history['test_cer']],  # Convert CER to percentage for display
                'Test Exact Match Accuracy (%)': [acc * 100 for acc in training_history['test_exact_match_accuracy']]  # Convert to percentage
            })

            # Plot 1: Training and Test Loss
            st.sidebar.markdown("**Loss over Epochs**")
            st.sidebar.line_chart(
                history_df.set_index('Epoch')[['Train Loss', 'Test Loss']]
            )
            st.sidebar.caption("Lower loss indicates better model performance.")

            # Plot 2: Character Error Rate (CER)
            st.sidebar.markdown("**Character Error Rate (CER) over Epochs**")
            st.sidebar.line_chart(
                history_df.set_index('Epoch')[['Test CER (%)']]
            )
            st.sidebar.caption("Lower CER indicates fewer character errors (0% is perfect).")

            # Plot 3: Exact Match Accuracy
            st.sidebar.markdown("**Exact Match Accuracy over Epochs**")
            st.sidebar.line_chart(
                history_df.set_index('Epoch')[['Test Exact Match Accuracy (%)']]
            )
            st.sidebar.caption("Higher exact match accuracy indicates more perfectly recognized names.")

            # Update the global model instance to the newly trained one for immediate inference
            ocr_model = trained_ocr_model
            ocr_model.eval()

        except Exception as e:
            status_text.error(f"An error occurred during training: {e}")
            st.sidebar.text(traceback.format_exc())  # Show full traceback for debugging

# --- Main Content: Name Prediction ---
st.header("Predict Your Handwritten Name")
st.markdown("Upload a clear image of a single handwritten name or word.")

uploaded_file = st.file_uploader("🖼️ Choose an image...", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    try:
        # Open the uploaded image
        image_pil = Image.open(uploaded_file).convert('L')  # Ensure grayscale
        st.image(image_pil, caption="Uploaded Image", use_column_width=True)
        st.write("---")
        st.write("Processing and Recognizing...")

        # Preprocess the image for the model using the utils_ocr helper
        processed_image_tensor = preprocess_user_image_for_ocr(image_pil, IMG_HEIGHT).to(device)

        # Make prediction
        ocr_model.eval()  # Ensure model is in evaluation mode
        with torch.no_grad():  # Disable gradient calculation for inference
            output = ocr_model(processed_image_tensor)  # (sequence_length, batch_size, num_classes)

        # ctc_greedy_decode expects (sequence_length, batch_size, num_classes).
        # It returns a list of strings, so take the first element for single-image inference.
        predicted_texts = ctc_greedy_decode(output, char_indexer)
        predicted_text = predicted_texts[0]  # Get the first (and only) prediction

        st.success(f"Recognized Text: **{predicted_text}**")

    except Exception as e:
        st.error(f"Error processing image or recognizing text: {e}")
        st.info("💡 **Tips for best results:**\n"
                "- Ensure the handwritten text is clear and on a clean background.\n"
                "- Only include one name/word per image.\n"
                "- The model is trained on specific characters. Unusual symbols might not be recognized.")
        st.text(traceback.format_exc())

st.markdown("""
---
*Built using Streamlit, PyTorch, OpenCV, and EditDistance ©2025 by MFT*
""")
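app.py leans on ctc_greedy_decode from model_ocr.py, whose body is not visible in this diff. As a rough, illustrative sketch only (not the repository's actual function), greedy CTC decoding of a (sequence_length, batch_size, num_classes) tensor is an argmax per time step followed by the CTC collapse rule; the char_indexer below is assumed to expose blank_token_idx and idx_to_char the way CharIndexer in data_handler_ocr.py does.

import torch

# Illustrative greedy CTC decoder; the real ctc_greedy_decode in model_ocr.py may differ.
def ctc_greedy_decode_sketch(output: torch.Tensor, char_indexer) -> list[str]:
    best = output.argmax(dim=2).transpose(0, 1)  # (T, N) -> (N, T): best class per time step
    texts = []
    for seq in best.tolist():
        chars, prev = [], None
        for idx in seq:
            # CTC collapse rule: drop blanks and repeated consecutive indices
            if idx != char_indexer.blank_token_idx and idx != prev:
                chars.append(char_indexer.idx_to_char[idx])
            prev = idx
        texts.append("".join(chars))
    return texts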
config.py
ADDED
@@ -0,0 +1,112 @@
# config.py

import os

# --- Paths ---
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, 'data')
MODELS_DIR = os.path.join(BASE_DIR, 'models')

TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'train')
TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'images', 'test')

TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train.csv')
TEST_CSV_PATH = os.path.join(DATA_DIR, 'test.csv')

MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'handwritten_name_ocr_model.pth')

# --- Character Set and OCR Configuration ---
# This character set MUST cover all characters present in your dataset.
# Add any special characters if needed.
# The order here is crucial, as it defines the indices of the characters.
CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"

# Define the character for the blank token. It MUST NOT be in CHARS.
BLANK_TOKEN_SYMBOL = 'Þ'

# Construct the full vocabulary string. It is conventional to put the blank token last.
# This VOCABULARY string is what you pass to CharIndexer.
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL

# NUM_CLASSES is the total number of unique symbols in the vocabulary, including the blank.
NUM_CLASSES = len(VOCABULARY)

# BLANK_TOKEN is the actual index of the blank symbol within VOCABULARY.
# Since we appended it last, its index equals len(CHARS).
BLANK_TOKEN = VOCABULARY.find(BLANK_TOKEN_SYMBOL)

# --- Sanity Checks (Highly Recommended) ---
if BLANK_TOKEN == -1:
    raise ValueError(f"Error: BLANK_TOKEN_SYMBOL '{BLANK_TOKEN_SYMBOL}' not found in VOCABULARY. Check config.py definitions.")
if BLANK_TOKEN >= NUM_CLASSES:
    raise ValueError(f"Error: BLANK_TOKEN index ({BLANK_TOKEN}) must be less than NUM_CLASSES ({NUM_CLASSES}).")

print(f"Config Loaded: NUM_CLASSES={NUM_CLASSES}, BLANK_TOKEN_INDEX={BLANK_TOKEN}")
print(f"Vocabulary Length: {len(VOCABULARY)}")
print(f"Blank Symbol: '{BLANK_TOKEN_SYMBOL}' at index {BLANK_TOKEN}")


# --- Image Preprocessing Parameters ---
IMG_HEIGHT = 32

# --- Training Parameters ---
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 3
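Assuming the CHARS string above (the 95 printable ASCII characters from space through '~'), the derived constants work out to NUM_CLASSES = 96 with the blank at index 95, which also matches the len(chars) + 1 that CharIndexer computes in data_handler_ocr.py. A quick stand-alone check, not part of the repository:

CHARS = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
BLANK_TOKEN_SYMBOL = 'Þ'
VOCABULARY = CHARS + BLANK_TOKEN_SYMBOL

assert len(CHARS) == 95                           # printable ASCII, space through '~'
assert len(VOCABULARY) == 96                      # NUM_CLASSES, blank included
assert VOCABULARY.find(BLANK_TOKEN_SYMBOL) == 95  # BLANK_TOKEN index, last position
assert BLANK_TOKEN_SYMBOL not in CHARS            # blank must not collide with real characters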
data_handler_ocr.py
ADDED
@@ -0,0 +1,270 @@
1 |
+
<<<<<<< HEAD
|
2 |
+
#data_handler_ocr.py
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
from torchvision import transforms
|
8 |
+
import os
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
# Import utility functions and config
|
14 |
+
from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
|
15 |
+
from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
16 |
+
|
17 |
+
class CharIndexer:
|
18 |
+
"""Manages character-to-index and index-to-character mappings."""
|
19 |
+
def __init__(self, chars: str, blank_token: str):
|
20 |
+
self.char_to_idx = {char: i for i, char in enumerate(chars)}
|
21 |
+
self.idx_to_char = {i: char for i, char in enumerate(chars)}
|
22 |
+
self.blank_token_idx = len(chars) # Index for the blank token
|
23 |
+
self.idx_to_char[self.blank_token_idx] = blank_token # Add blank token to idx_to_char
|
24 |
+
self.num_classes = len(chars) + 1 # Total classes including blank
|
25 |
+
|
26 |
+
def encode(self, text: str) -> list[int]:
|
27 |
+
"""Converts a text string to a list of integer indices."""
|
28 |
+
return [self.char_to_idx[char] for char in text]
|
29 |
+
|
30 |
+
def decode(self, indices: list[int]) -> str:
|
31 |
+
"""Converts a list of integer indices back to a text string."""
|
32 |
+
# CTC decoding often produces repeated characters and blank tokens.
|
33 |
+
# This simple decoder removes blanks and duplicates.
|
34 |
+
decoded_text = []
|
35 |
+
for i, idx in enumerate(indices):
|
36 |
+
if idx == self.blank_token_idx:
|
37 |
+
continue
|
38 |
+
# Remove consecutive duplicates
|
39 |
+
if i > 0 and indices[i-1] == idx:
|
40 |
+
continue
|
41 |
+
decoded_text.append(self.idx_to_char[idx])
|
42 |
+
return "".join(decoded_text)
|
43 |
+
|
44 |
+
class OCRDataset(Dataset):
|
45 |
+
"""
|
46 |
+
Custom PyTorch Dataset for the Handwritten Name Recognition task.
|
47 |
+
Loads images and their corresponding text labels.
|
48 |
+
"""
|
49 |
+
def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
|
50 |
+
"""
|
51 |
+
Initializes the OCR Dataset.
|
52 |
+
Args:
|
53 |
+
dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
|
54 |
+
char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
|
55 |
+
transform (callable, optional): Optional transform to be applied on an image.
|
56 |
+
"""
|
57 |
+
self.data = dataframe
|
58 |
+
self.char_indexer = char_indexer
|
59 |
+
self.image_dir = image_dir
|
60 |
+
self.transform = transform
|
61 |
+
|
62 |
+
|
63 |
+
def __len__(self) -> int:
|
64 |
+
return len(self.data)
|
65 |
+
|
66 |
+
def __getitem__(self, idx):
|
67 |
+
raw_filename_entry = self.data.iloc[idx]['FILENAME']
|
68 |
+
ground_truth_text = self.data.iloc[idx]['IDENTITY']
|
69 |
+
|
70 |
+
filename = raw_filename_entry.split(',')[0].strip() # .strip() removes any whitespace
|
71 |
+
# Construct the full image path
|
72 |
+
img_path = os.path.join(self.image_dir, filename)
|
73 |
+
# Ensure ground_truth_text is a string
|
74 |
+
ground_truth_text = str(ground_truth_text)
|
75 |
+
|
76 |
+
# Load and transform image
|
77 |
+
try:
|
78 |
+
image = Image.open(img_path).convert('L') # Convert to grayscale
|
79 |
+
except FileNotFoundError:
|
80 |
+
print(f"Error: Image file not found at {img_path}. Skipping this item.")
|
81 |
+
raise # Re-raise to let the main traceback be seen.
|
82 |
+
|
83 |
+
if self.transform:
|
84 |
+
image = self.transform(image)
|
85 |
+
|
86 |
+
image_width = image.size(2) # Assuming image is a tensor (C, H, W) after transform
|
87 |
+
|
88 |
+
text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
|
89 |
+
text_length = len(text_encoded)
|
90 |
+
|
91 |
+
return image, text_encoded, image_width, text_length
|
92 |
+
|
93 |
+
def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
94 |
+
"""
|
95 |
+
Custom collate function for the DataLoader to handle variable-width images
|
96 |
+
and variable-length text sequences for CTC loss.
|
97 |
+
"""
|
98 |
+
images, texts, image_widths, text_lengths = zip(*batch)
|
99 |
+
|
100 |
+
# Pad images to the maximum width in the current batch
|
101 |
+
max_batch_width = max(image_widths)
|
102 |
+
padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
|
103 |
+
images_batch = torch.stack(padded_images, 0) # Stack to (N, C, H, max_W)
|
104 |
+
|
105 |
+
# Concatenate all text sequences and get their lengths
|
106 |
+
texts_batch = torch.cat(texts, 0)
|
107 |
+
text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
|
108 |
+
image_widths_tensor = torch.tensor(image_widths, dtype=torch.long) # Actual widths
|
109 |
+
|
110 |
+
return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
|
111 |
+
|
112 |
+
|
113 |
+
def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
|
114 |
+
"""
|
115 |
+
Loads training and testing dataframes.
|
116 |
+
Assumes CSVs have 'filename' and 'name' columns.
|
117 |
+
"""
|
118 |
+
train_df = pd.read_csv(train_csv_path)
|
119 |
+
test_df = pd.read_csv(test_csv_path)
|
120 |
+
return train_df, test_df
|
121 |
+
|
122 |
+
def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
|
123 |
+
char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
|
124 |
+
"""
|
125 |
+
Creates PyTorch DataLoader objects for OCR training and testing datasets,
|
126 |
+
using specific image directories for train/test.
|
127 |
+
"""
|
128 |
+
train_dataset = OCRDataset(train_df, TRAIN_IMAGES_DIR, char_indexer)
|
129 |
+
test_dataset = OCRDataset(test_df, TEST_IMAGES_DIR, char_indexer)
|
130 |
+
|
131 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
132 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
133 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
|
134 |
+
num_workers=0, collate_fn=ocr_collate_fn)
|
135 |
+
=======
|
136 |
+
#data_handler_ocr.py
|
137 |
+
|
138 |
+
import pandas as pd
|
139 |
+
import torch
|
140 |
+
from torch.utils.data import Dataset, DataLoader
|
141 |
+
from torchvision import transforms
|
142 |
+
import os
|
143 |
+
from PIL import Image
|
144 |
+
import numpy as np
|
145 |
+
import torch.nn.functional as F
|
146 |
+
|
147 |
+
# Import utility functions and config
|
148 |
+
from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
|
149 |
+
from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
|
150 |
+
|
151 |
+
class CharIndexer:
|
152 |
+
"""Manages character-to-index and index-to-character mappings."""
|
153 |
+
def __init__(self, chars: str, blank_token: str):
|
154 |
+
self.char_to_idx = {char: i for i, char in enumerate(chars)}
|
155 |
+
self.idx_to_char = {i: char for i, char in enumerate(chars)}
|
156 |
+
self.blank_token_idx = len(chars) # Index for the blank token
|
157 |
+
self.idx_to_char[self.blank_token_idx] = blank_token # Add blank token to idx_to_char
|
158 |
+
self.num_classes = len(chars) + 1 # Total classes including blank
|
159 |
+
|
160 |
+
def encode(self, text: str) -> list[int]:
|
161 |
+
"""Converts a text string to a list of integer indices."""
|
162 |
+
return [self.char_to_idx[char] for char in text]
|
163 |
+
|
164 |
+
def decode(self, indices: list[int]) -> str:
|
165 |
+
"""Converts a list of integer indices back to a text string."""
|
166 |
+
# CTC decoding often produces repeated characters and blank tokens.
|
167 |
+
# This simple decoder removes blanks and duplicates.
|
168 |
+
decoded_text = []
|
169 |
+
for i, idx in enumerate(indices):
|
170 |
+
if idx == self.blank_token_idx:
|
171 |
+
continue
|
172 |
+
# Remove consecutive duplicates
|
173 |
+
if i > 0 and indices[i-1] == idx:
|
174 |
+
continue
|
175 |
+
decoded_text.append(self.idx_to_char[idx])
|
176 |
+
return "".join(decoded_text)
|
177 |
+
|
178 |
+
class OCRDataset(Dataset):
|
179 |
+
"""
|
180 |
+
Custom PyTorch Dataset for the Handwritten Name Recognition task.
|
181 |
+
Loads images and their corresponding text labels.
|
182 |
+
"""
|
183 |
+
def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
|
184 |
+
"""
|
185 |
+
Initializes the OCR Dataset.
|
186 |
+
Args:
|
187 |
+
dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
|
188 |
+
char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
|
189 |
+
transform (callable, optional): Optional transform to be applied on an image.
|
190 |
+
"""
|
191 |
+
self.data = dataframe
|
192 |
+
self.char_indexer = char_indexer
|
193 |
+
self.image_dir = image_dir
|
194 |
+
self.transform = transform
|
195 |
+
|
196 |
+
|
197 |
+
def __len__(self) -> int:
|
198 |
+
return len(self.data)
|
199 |
+
|
200 |
+
def __getitem__(self, idx):
|
201 |
+
raw_filename_entry = self.data.iloc[idx]['FILENAME']
|
202 |
+
ground_truth_text = self.data.iloc[idx]['IDENTITY']
|
203 |
+
|
204 |
+
filename = raw_filename_entry.split(',')[0].strip() # .strip() removes any whitespace
|
205 |
+
# Construct the full image path
|
206 |
+
img_path = os.path.join(self.image_dir, filename)
|
207 |
+
# Ensure ground_truth_text is a string
|
208 |
+
ground_truth_text = str(ground_truth_text)
|
209 |
+
|
210 |
+
# Load and transform image
|
211 |
+
try:
|
212 |
+
image = Image.open(img_path).convert('L') # Convert to grayscale
|
213 |
+
except FileNotFoundError:
|
214 |
+
print(f"Error: Image file not found at {img_path}. Skipping this item.")
|
215 |
+
raise # Re-raise to let the main traceback be seen.
|
216 |
+
|
217 |
+
if self.transform:
|
218 |
+
image = self.transform(image)
|
219 |
+
|
220 |
+
image_width = image.size(2) # Assuming image is a tensor (C, H, W) after transform
|
221 |
+
|
222 |
+
text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
|
223 |
+
text_length = len(text_encoded)
|
224 |
+
|
225 |
+
return image, text_encoded, image_width, text_length
|
226 |
+
|
227 |
+
def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
228 |
+
"""
|
229 |
+
Custom collate function for the DataLoader to handle variable-width images
|
230 |
+
and variable-length text sequences for CTC loss.
|
231 |
+
"""
|
232 |
+
images, texts, image_widths, text_lengths = zip(*batch)
|
233 |
+
|
234 |
+
# Pad images to the maximum width in the current batch
|
235 |
+
max_batch_width = max(image_widths)
|
236 |
+
padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
|
237 |
+
images_batch = torch.stack(padded_images, 0) # Stack to (N, C, H, max_W)
|
238 |
+
|
239 |
+
# Concatenate all text sequences and get their lengths
|
240 |
+
texts_batch = torch.cat(texts, 0)
|
241 |
+
text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
|
242 |
+
image_widths_tensor = torch.tensor(image_widths, dtype=torch.long) # Actual widths
|
243 |
+
|
244 |
+
return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
|
245 |
+
|
246 |
+
|
247 |
+
def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
|
248 |
+
"""
|
249 |
+
Loads training and testing dataframes.
|
250 |
+
Assumes CSVs have 'filename' and 'name' columns.
|
251 |
+
"""
|
252 |
+
train_df = pd.read_csv(train_csv_path)
|
253 |
+
test_df = pd.read_csv(test_csv_path)
|
254 |
+
return train_df, test_df

def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
                           char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
    """
    Creates PyTorch DataLoader objects for the OCR training and testing datasets,
    using the dedicated train/test image directories.
    """
    train_dataset = OCRDataset(train_df, TRAIN_IMAGES_DIR, char_indexer)
    test_dataset = OCRDataset(test_df, TEST_IMAGES_DIR, char_indexer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, collate_fn=ocr_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=0, collate_fn=ocr_collate_fn)

    return train_loader, test_loader
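
# --- Usage sketch (illustrative only) ---
# A minimal example of wiring the pieces above together. The CSV paths and the
# CharIndexer constructor call below are assumptions for illustration; the real
# app builds these from config.py and app.py.
if __name__ == "__main__":
    train_df, test_df = load_ocr_dataframes("data/train.csv", "data/test.csv")  # placeholder paths
    char_indexer = CharIndexer("ABCDEFGHIJKLMNOPQRSTUVWXYZ-' ")                 # hypothetical constructor arguments
    train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, batch_size=32)

    images, texts, image_widths, text_lengths = next(iter(train_loader))
    print(images.shape)       # (N, 1, H, max_W): images padded to the widest sample by ocr_collate_fn
    print(texts.shape)        # 1-D tensor of concatenated label indices for the batch
    print(text_lengths[:5])   # per-sample label lengths consumed by CTC loss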
model_ocr.py
ADDED
@@ -0,0 +1,584 @@
# model_ocr.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # Kept for type hinting
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import editdistance

# Import config values and the character indexer.
# These imports must stay in sync with config.py.
from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
from data_handler_ocr import CharIndexer
# binarize_image, resize_image_for_ocr and normalize_image_for_model are imported
# for completeness; preprocessing is normally handled by the DataLoader transforms.
from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model

class CNN_Backbone(nn.Module):
    """
    CNN feature extractor for OCR, designed to produce features suitable for an RNN.
    The output feature map has height 1 after the final pooling/reduction.
    """
    def __init__(self, input_channels=1, output_channels=512):
        super(CNN_Backbone, self).__init__()
        self.cnn = nn.Sequential(
            # First block
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H: 32 -> 16, W: W_in -> W_in/2

            # Second block
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H: 16 -> 8, W: W_in/2 -> W_in/4

            # Third block (two conv layers)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Halve the height while keeping the width (stride 1 plus padding adds ~1 column)
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # H: 8 -> 4, W: W/4 -> W/4 + 1 (approx.)

            # Fourth block
            nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # AdaptiveAvgPool2d collapses the height dimension to 1 while preserving the
            # width, which the RNN then uses as the sequence axis.
            nn.AdaptiveAvgPool2d((1, None))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W), e.g. (B, 1, 32, W_img)

        # Pass through the CNN layers
        conv_features = self.cnn(x)  # (N, output_channels, 1, W_prime)

        # Squeeze the height dimension (which is 1): (N, C_out, 1, W_prime) -> (N, C_out, W_prime)
        conv_features = conv_features.squeeze(2)

        # Permute to the RNN layout (sequence_length, batch_size, input_size):
        # (N, C_out, W_prime) -> (W_prime, N, C_out)
        conv_features = conv_features.permute(2, 0, 1)

        return conv_features
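
# Shape sanity-check sketch (illustrative, not used by the app): for a batch of
# four 32-pixel-high grayscale crops, the backbone emits one 512-dim feature
# vector per horizontal position, in the (T, N, C) layout the LSTM below expects.
def _demo_cnn_backbone_shapes():
    backbone = CNN_Backbone(input_channels=1, output_channels=512)
    dummy = torch.randn(4, 1, 32, 128)   # (N, C, H, W)
    feats = backbone(dummy)              # (W_prime, N, 512) after the squeeze and permute
    print(feats.shape)                   # torch.Size([33, 4, 512]) for this input width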

class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM layer for sequence modeling."""
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
        super(BidirectionalLSTM, self).__init__()
        # batch_first=False expects input shaped (sequence_length, batch_size, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            bidirectional=True, dropout=dropout, batch_first=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)  # discard the (h_n, c_n) hidden-state tuple
        return output

class CRNN(nn.Module):
    """
    Convolutional Recurrent Neural Network for OCR.
    Combines a CNN for feature extraction, LSTMs for sequence modeling,
    and a final linear layer for character prediction.
    """
    def __init__(self, num_classes: int, cnn_output_channels: int = 512,
                 rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
        super(CRNN, self).__init__()
        self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
        # The LSTM input size is the number of channels produced by the CNN
        self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
        # A bidirectional LSTM outputs hidden_size * 2 features per time step
        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W), e.g. (B, 1, 32, W_img)

        # 1. Extract features with the CNN: (W_prime, N, C_out) after the permute in CNN_Backbone
        conv_features = self.cnn(x)

        # 2. Model the sequence with the RNN: (W_prime, N, rnn_hidden_size * 2)
        rnn_features = self.rnn(conv_features)

        # 3. Project each time step to class logits: (W_prime, N, num_classes)
        output = self.fc(rnn_features)

        return output
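
# Forward-pass sketch (illustrative): a CRNN built for NUM_CLASSES characters maps a
# batch of line images to per-time-step logits in the (T, N, C) layout expected by
# nn.CTCLoss and by ctc_greedy_decode below.
def _demo_crnn_forward():
    model = CRNN(num_classes=NUM_CLASSES)
    dummy = torch.randn(2, 1, 32, 160)   # (N, C, H, W): two 32-pixel-high line crops
    logits = model(dummy)                # (T, N, NUM_CLASSES)
    print(logits.shape)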

# --- Decoding Function ---
def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
    """
    Performs greedy decoding on the CTC output.
    output: (sequence_length, batch_size, num_classes) - raw logits
    """
    # Apply log_softmax so the argmax is taken over (log-)probabilities
    log_probs = F.log_softmax(output, dim=2)

    # Permute to (batch_size, sequence_length, num_classes) and take the argmax over
    # the class dimension: the most probable character index at each time step.
    predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()

    decoded_texts = []
    for seq in predicted_indices:
        # char_indexer.decode handles blank removal and duplicate collapsing
        decoded_texts.append(char_indexer.decode(seq.tolist()))
    return decoded_texts
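
# Worked decoding example (illustrative): greedy CTC decoding collapses repeated
# indices, then strips blanks, which is the behaviour char_indexer.decode is
# expected to implement. With blank=0 and the toy mapping 1 -> 'h', 2 -> 'i':
def _demo_ctc_collapse():
    frame_argmax = [0, 1, 1, 0, 2, 2, 0]   # per-time-step argmax indices
    mapping = {1: 'h', 2: 'i'}
    decoded, previous = [], None
    for idx in frame_argmax:
        if idx != previous and idx != 0:    # drop repeats first, then drop blanks
            decoded.append(mapping[idx])
        previous = idx
    print(''.join(decoded))                 # -> "hi"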

# --- Evaluation Function ---
def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
    model.eval()  # Set model to evaluation mode
    # CTCLoss needs the blank token index, which the char_indexer provides
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
    total_loss = 0
    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for inputs, targets_concat, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets_concat = targets_concat.to(device)
            target_lengths = target_lengths.to(device)

            output = model(inputs)  # (seq_len, batch_size, num_classes)

            # input_lengths for CTCLoss: the sequence length (T) produced by the model,
            # i.e. output.shape[0], repeated once per item in the batch.
            outputs_seq_len_for_ctc = torch.full(
                size=(output.shape[1],),      # batch_size
                fill_value=output.shape[0],   # sequence length (T) from the model output
                dtype=torch.long,
                device=device
            )

            # CTC loss expects log_softmax over the output logits
            log_probs_for_loss = F.log_softmax(output, dim=2)  # (T, N, C)

            loss = criterion(log_probs_for_loss, targets_concat, outputs_seq_len_for_ctc, target_lengths)
            total_loss += loss.item() * inputs.size(0)  # Weight by batch size for a correct average

            # Decode predictions for metrics
            decoded_preds = ctc_greedy_decode(output, char_indexer)

            # Reconstruct ground truths from the concatenated 1-D target tensor produced by
            # ocr_collate_fn, walking through it with the per-sample target lengths.
            ground_truths = []
            offset = 0
            for length in target_lengths.tolist():
                ground_truths.append(char_indexer.decode(targets_concat[offset:offset + length].tolist()))
                offset += length

            all_predictions.extend(decoded_preds)
            all_ground_truths.extend(ground_truths)

    avg_loss = total_loss / len(dataloader.dataset)

    # Character Error Rate (CER): total edit distance over total ground-truth characters
    cer_sum = 0
    total_chars = 0
    for pred, gt in zip(all_predictions, all_ground_truths):
        cer_sum += editdistance.eval(pred, gt)
        total_chars += len(gt)
    char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0

    # Exact Match Accuracy (word-level accuracy)
    exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)

    return avg_loss, char_error_rate, exact_match_accuracy
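
# Metric sketch (illustrative): CER is the summed edit distance divided by the total
# number of ground-truth characters, while exact-match accuracy only counts perfect
# transcriptions. "JONH" vs "JOHN" costs 2 edits, so the pair below gives CER 0.25
# and exact-match accuracy 0.5.
def _demo_metrics():
    preds, gts = ["JONH", "MARY"], ["JOHN", "MARY"]
    cer = sum(editdistance.eval(p, g) for p, g in zip(preds, gts)) / sum(len(g) for g in gts)
    acc = accuracy_score(gts, preds)
    print(cer, acc)   # 0.25 0.5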

# --- Training Function ---
def train_ocr_model(model: nn.Module, train_loader: DataLoader,
                    test_loader: DataLoader, char_indexer: CharIndexer,
                    epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
    """
    Trains the OCR model using CTC loss.
    """
    # CTCLoss needs the blank token index
    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Fixed learning rate for now
    # ReduceLROnPlateau lowers the LR when the test loss stops improving
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)

    model.to(device)   # Ensure the model is on the correct device
    model.train()      # Set model to training mode

    training_history = {
        'train_loss': [],
        'test_loss': [],
        'test_cer': [],
        'test_exact_match_accuracy': []
    }

    for epoch in range(epochs):
        running_loss = 0.0
        pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
        for images, texts_encoded, _, text_lengths in pbar_train:
            images = images.to(device)
            # Keep target tensors on the same device for the CTCLoss calculation
            texts_encoded = texts_encoded.to(device)
            text_lengths = text_lengths.to(device)

            optimizer.zero_grad()      # Clear gradients from the previous step
            outputs = model(images)    # (sequence_length_from_cnn, batch_size, num_classes)

            # outputs.shape[0] is the sequence length (T) produced by the model;
            # CTC loss expects input_lengths as a (batch_size,) tensor filled with it.
            outputs_seq_len_for_ctc = torch.full(
                size=(outputs.shape[1],),      # batch_size
                fill_value=outputs.shape[0],   # sequence length (T) from the model output
                dtype=torch.long,
                device=device
            )

            # CTC loss expects log_softmax over the output logits
            log_probs_for_loss = F.log_softmax(outputs, dim=2)  # (T, N, C)

            # Use outputs_seq_len_for_ctc as the input_lengths argument
            loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
            loss.backward()    # Backpropagate
            optimizer.step()   # Update model weights

            running_loss += loss.item() * images.size(0)  # Weight by batch size for a correct average
            pbar_train.set_postfix(loss=loss.item())

        epoch_train_loss = running_loss / len(train_loader.dataset)
        training_history['train_loss'].append(epoch_train_loss)

        # Evaluate on the test set using the dedicated function
        model.eval()
        test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
        training_history['test_loss'].append(test_loss)
        training_history['test_cer'].append(test_cer)
        training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)

        # Adjust the learning rate based on the test loss
        scheduler.step(test_loss)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
              f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")

        if progress_callback:
            # Report the current epoch and key metrics to the UI progress bar
            progress_val = (epoch + 1) / epochs
            progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")

        model.train()  # Back to training mode after evaluation

    return model, training_history

def save_ocr_model(model: nn.Module, path: str):
    """Saves the state dictionary of the trained OCR model."""
    torch.save(model.state_dict(), path)
    print(f"OCR model saved to {path}")

def load_ocr_model(model: nn.Module, path: str):
    """
    Loads a trained OCR model's state dictionary.
    map_location lets models trained on GPU be loaded on CPU, and vice versa.
    """
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))  # Always load to CPU first
    model.eval()  # Set to evaluation mode
    print(f"OCR model loaded from {path}")
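
# Putting it together (illustrative): a minimal training/checkpoint driver. The epoch
# count and checkpoint path are placeholders, and the loaders plus char_indexer are
# assumed to come from data_handler_ocr.create_ocr_dataloaders.
def _demo_train_and_checkpoint(train_loader: DataLoader, test_loader: DataLoader, char_indexer: CharIndexer):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CRNN(num_classes=NUM_CLASSES)
    model, history = train_ocr_model(model, train_loader, test_loader, char_indexer,
                                     epochs=10, device=device)
    save_ocr_model(model, "crnn_ocr.pth")
    load_ocr_model(model, "crnn_ocr.pth")   # reload to verify the checkpoint round-trips
    return model, history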
requirements.txt
ADDED
@@ -0,0 +1,32 @@
#requirements.txt
# This file lists all the Python libraries required to run the Handwritten Name OCR application.
# Install using: pip install -r requirements.txt

streamlit>=1.33.0
pandas>=2.2.2
numpy>=1.26.4
Pillow>=10.3.0
opencv-python-headless>=4.9.0.80
torch>=2.2.2
torchvision>=0.17.2      # PyTorch companion library for vision tasks (datasets, transforms)
matplotlib>=3.8.4        # For plotting training history
tqdm>=4.66.2             # For displaying progress bars during training
editdistance>=0.8.1      # For calculating character error rate (CER)
scikit-learn>=1.4.2      # For exact-match accuracy (sklearn.metrics.accuracy_score)
utils_ocr.py
ADDED
@@ -0,0 +1,184 @@
#utils_ocr.py

import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms

# --- Image Preprocessing for OCR ---

def load_image_as_grayscale(image_path: str) -> Image.Image:
    """Loads an image from a path and converts it to a grayscale PIL Image."""
    # PIL handles loading robustly; 'L' mode is 8-bit grayscale
    img = Image.open(image_path).convert('L')
    return img

def binarize_image(image_pil: Image.Image) -> Image.Image:
    """Binarizes a grayscale PIL Image (black and white)."""
    # Convert PIL to OpenCV format (numpy array)
    img_np = np.array(image_pil)
    # Apply Otsu's thresholding for adaptive binarization
    _, img_bin = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Invert colors: handwritten text is usually dark on light, while OCR models often
    # prefer light text on a dark background. Check your training data's style.
    # This assumes dark text on a light background and inverts to white text on black.
    img_bin = 255 - img_bin
    return Image.fromarray(img_bin)

def resize_image_for_ocr(image_pil: Image.Image, target_height: int) -> Image.Image:
    """
    Resizes a PIL Image to a target height while maintaining the aspect ratio.
    (Width padding, when needed, is handled separately by pad_image_tensor.)
    """
    original_width, original_height = image_pil.size
    # New width follows from the target height and the original aspect ratio
    new_width = int(original_width * (target_height / original_height))
    resized_img = image_pil.resize((new_width, target_height), Image.LANCZOS)
    return resized_img

def normalize_image_for_model(image_pil: Image.Image) -> torch.Tensor:
    """
    Converts a PIL Image to a PyTorch tensor and normalizes pixel values.
    """
    # ToTensor scales pixel values to [0, 1]
    tensor_transform = transforms.ToTensor()
    img_tensor = tensor_transform(image_pil)
    # For grayscale images, mean and std are single values; adjust them if your
    # training data uses different normalization.
    img_tensor = transforms.Normalize((0.5,), (0.5,))(img_tensor)  # Normalize to [-1, 1]
    return img_tensor

def preprocess_user_image_for_ocr(uploaded_image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """
    Combines all preprocessing steps for a single user-uploaded image
    to prepare it for the OCR model.
    """
    # Ensure it's grayscale
    img_gray = uploaded_image_pil.convert('L')

    # Binarize
    img_bin = binarize_image(img_gray)

    # Resize (maintain aspect ratio)
    img_resized = resize_image_for_ocr(img_bin, target_height)

    # Normalize and convert to tensor
    img_tensor = normalize_image_for_model(img_resized)

    # Add batch dimension: (C, H, W) -> (1, C, H, W)
    img_tensor = img_tensor.unsqueeze(0)

    return img_tensor

def pad_image_tensor(image_tensor: torch.Tensor, max_width: int) -> torch.Tensor:
    """
    Pads a single image tensor to max_width with zeros.
    Input tensor shape: (C, H, W)
    Output tensor shape: (C, H, max_width)
    """
    C, H, W = image_tensor.shape
    if W > max_width:
        # If the image is wider than max_width, crop it. A more robust solution might
        # split text lines or use a different resizing strategy.
        print(f"Warning: Image width {W} exceeds max_width {max_width}. Cropping.")
        return image_tensor[:, :, :max_width]  # Simple cropping
    padding = max_width - W
    # Pad on the right; F.pad takes (pad_left, pad_right) for the last dimension
    padded_tensor = F.pad(image_tensor, (0, padding), 'constant', 0)
    return padded_tensor
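
# Usage sketch (illustrative): preprocessing a single uploaded image for inference.
# The file path is a placeholder and target_height=32 mirrors the IMG_HEIGHT value
# assumed by the model's CNN comments; the real app reads it from config.py.
def _demo_preprocess(path: str = "example_name.png"):
    img = Image.open(path)
    tensor = preprocess_user_image_for_ocr(img, target_height=32)
    print(tensor.shape)   # (1, 1, 32, W): a batch of one grayscale line image
    return tensor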