Spaces:
Build error
Update data_handler_ocr.py
Browse files

data_handler_ocr.py  CHANGED  (+58 -177)
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 #data_handler_ocr.py
 
 import pandas as pd
@@ -11,34 +10,52 @@ import numpy as np
 import torch.nn.functional as F
 
 # Import utility functions and config
-from config import
+from config import VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
 from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
 
 class CharIndexer:
     """Manages character-to-index and index-to-character mappings."""
-    def __init__(self,
-        self.
-        self.
-        self.
-
-
+    def __init__(self, vocabulary_string: str, blank_token_symbol: str):
+        self.chars = sorted(list(set(vocabulary_string)))
+        self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
+        self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
+
+        if blank_token_symbol not in self.char_to_idx:
+            raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
+
+        self.blank_token_idx = self.char_to_idx[blank_token_symbol]
+        self.num_classes = len(self.chars)
+
+        if self.blank_token_idx >= self.num_classes:
+            raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
+
+        print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
+        print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
 
     def encode(self, text: str) -> list[int]:
         """Converts a text string to a list of integer indices."""
-
+        encoded_list = []
+        for char in text:
+            if char in self.char_to_idx:
+                encoded_list.append(self.char_to_idx[char])
+            else:
+                print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
+                encoded_list.append(self.blank_token_idx)
+        return encoded_list
 
     def decode(self, indices: list[int]) -> str:
         """Converts a list of integer indices back to a text string."""
-        # CTC decoding often produces repeated characters and blank tokens.
-        # This simple decoder removes blanks and duplicates.
         decoded_text = []
         for i, idx in enumerate(indices):
            if idx == self.blank_token_idx:
                continue
-            # Remove consecutive duplicates
            if i > 0 and indices[i-1] == idx:
                continue
-
+            if idx in self.idx_to_char:
+                decoded_text.append(self.idx_to_char[idx])
+            else:
+                print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
+
        return "".join(decoded_text)
 
 class OCRDataset(Dataset):
@@ -47,43 +64,44 @@ class OCRDataset(Dataset):
     Loads images and their corresponding text labels.
     """
     def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
-        """
-        Initializes the OCR Dataset.
-        Args:
-            dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
-            char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
-            transform (callable, optional): Optional transform to be applied on an image.
-        """
         self.data = dataframe
         self.char_indexer = char_indexer
         self.image_dir = image_dir
-
+
+        if transform is None:
+            self.transform = transforms.Compose([
+
+                transforms.Lambda(lambda img: binarize_image(img)),
+                transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)),
+                transforms.ToTensor(),
+                transforms.Lambda(normalize_image_for_model)
+            ])
+        else:
+            self.transform = transform
 
 
     def __len__(self) -> int:
         return len(self.data)
 
     def __getitem__(self, idx):
-        raw_filename_entry = self.data.
-        ground_truth_text = self.data.
+        raw_filename_entry = self.data.loc[idx, 'FILENAME']
+        ground_truth_text = self.data.loc[idx, 'IDENTITY']
+
+        filename_only = raw_filename_entry.split(',')[0].strip()
 
-
-        # Construct the full image path
-        img_path = os.path.join(self.image_dir, filename)
-        # Ensure ground_truth_text is a string
+        img_path = os.path.join(self.image_dir, filename_only)
         ground_truth_text = str(ground_truth_text)
 
-        # Load and transform image
         try:
-            image =
+            image = load_image_as_grayscale(img_path)
         except FileNotFoundError:
-            print(f"Error: Image file not found at {img_path}.
-            raise
+            print(f"Error: Image file not found at {img_path}. Please check your dataset and config.py paths.")
+            raise
 
         if self.transform:
             image = self.transform(image)
 
-        image_width = image.
+        image_width = image.shape[2]
 
         text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
         text_length = len(text_encoded)
@@ -97,15 +115,13 @@ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     images, texts, image_widths, text_lengths = zip(*batch)
 
-    # Pad images to the maximum width in the current batch
     max_batch_width = max(image_widths)
     padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
-    images_batch = torch.stack(padded_images, 0)
+    images_batch = torch.stack(padded_images, 0)
 
-    # Concatenate all text sequences and get their lengths
     texts_batch = torch.cat(texts, 0)
     text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
-    image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
+    image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
 
     return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
 
@@ -113,10 +129,10 @@ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Loads training and testing dataframes.
-    Assumes CSVs have '
+    Assumes CSVs have 'FILENAME' and 'IDENTITY' columns and are comma-delimited with no header.
     """
-    train_df = pd.read_csv(train_csv_path)
-    test_df = pd.read_csv(test_csv_path)
+    train_df = pd.read_csv(train_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
+    test_df = pd.read_csv(test_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
     return train_df, test_df
 
 def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
@@ -125,146 +141,11 @@ def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
     Creates PyTorch DataLoader objects for OCR training and testing datasets,
     using specific image directories for train/test.
     """
-    train_dataset = OCRDataset(train_df,
-    test_dataset = OCRDataset(test_df,
+    train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR)
+    test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR)
 
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                               num_workers=0, collate_fn=ocr_collate_fn)
     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                               num_workers=0, collate_fn=ocr_collate_fn)
-
-#data_handler_ocr.py
-
-import pandas as pd
-import torch
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-import os
-from PIL import Image
-import numpy as np
-import torch.nn.functional as F
-
-# Import utility functions and config
-from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
-from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
-
-class CharIndexer:
-    """Manages character-to-index and index-to-character mappings."""
-    def __init__(self, chars: str, blank_token: str):
-        self.char_to_idx = {char: i for i, char in enumerate(chars)}
-        self.idx_to_char = {i: char for i, char in enumerate(chars)}
-        self.blank_token_idx = len(chars) # Index for the blank token
-        self.idx_to_char[self.blank_token_idx] = blank_token # Add blank token to idx_to_char
-        self.num_classes = len(chars) + 1 # Total classes including blank
-
-    def encode(self, text: str) -> list[int]:
-        """Converts a text string to a list of integer indices."""
-        return [self.char_to_idx[char] for char in text]
-
-    def decode(self, indices: list[int]) -> str:
-        """Converts a list of integer indices back to a text string."""
-        # CTC decoding often produces repeated characters and blank tokens.
-        # This simple decoder removes blanks and duplicates.
-        decoded_text = []
-        for i, idx in enumerate(indices):
-            if idx == self.blank_token_idx:
-                continue
-            # Remove consecutive duplicates
-            if i > 0 and indices[i-1] == idx:
-                continue
-            decoded_text.append(self.idx_to_char[idx])
-        return "".join(decoded_text)
-
-class OCRDataset(Dataset):
-    """
-    Custom PyTorch Dataset for the Handwritten Name Recognition task.
-    Loads images and their corresponding text labels.
-    """
-    def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
-        """
-        Initializes the OCR Dataset.
-        Args:
-            dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
-            char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
-            transform (callable, optional): Optional transform to be applied on an image.
-        """
-        self.data = dataframe
-        self.char_indexer = char_indexer
-        self.image_dir = image_dir
-        self.transform = transform
-
-
-    def __len__(self) -> int:
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        raw_filename_entry = self.data.iloc[idx]['FILENAME']
-        ground_truth_text = self.data.iloc[idx]['IDENTITY']
-
-        filename = raw_filename_entry.split(',')[0].strip() # .strip() removes any whitespace
-        # Construct the full image path
-        img_path = os.path.join(self.image_dir, filename)
-        # Ensure ground_truth_text is a string
-        ground_truth_text = str(ground_truth_text)
-
-        # Load and transform image
-        try:
-            image = Image.open(img_path).convert('L') # Convert to grayscale
-        except FileNotFoundError:
-            print(f"Error: Image file not found at {img_path}. Skipping this item.")
-            raise # Re-raise to let the main traceback be seen.
-
-        if self.transform:
-            image = self.transform(image)
-
-        image_width = image.size(2) # Assuming image is a tensor (C, H, W) after transform
-
-        text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
-        text_length = len(text_encoded)
-
-        return image, text_encoded, image_width, text_length
-
-def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Custom collate function for the DataLoader to handle variable-width images
-    and variable-length text sequences for CTC loss.
-    """
-    images, texts, image_widths, text_lengths = zip(*batch)
-
-    # Pad images to the maximum width in the current batch
-    max_batch_width = max(image_widths)
-    padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
-    images_batch = torch.stack(padded_images, 0) # Stack to (N, C, H, max_W)
-
-    # Concatenate all text sequences and get their lengths
-    texts_batch = torch.cat(texts, 0)
-    text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
-    image_widths_tensor = torch.tensor(image_widths, dtype=torch.long) # Actual widths
-
-    return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
-
-
-def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
-    """
-    Loads training and testing dataframes.
-    Assumes CSVs have 'filename' and 'name' columns.
-    """
-    train_df = pd.read_csv(train_csv_path)
-    test_df = pd.read_csv(test_csv_path)
-    return train_df, test_df
-
-def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
-                           char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
-    """
-    Creates PyTorch DataLoader objects for OCR training and testing datasets,
-    using specific image directories for train/test.
-    """
-    train_dataset = OCRDataset(train_df, TRAIN_IMAGES_DIR, char_indexer)
-    test_dataset = OCRDataset(test_df, TEST_IMAGES_DIR, char_indexer)
-
-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
-                              num_workers=0, collate_fn=ocr_collate_fn)
-    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
-                              num_workers=0, collate_fn=ocr_collate_fn)
->>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
-    return train_loader, test_loader
+    return train_loader, test_loader
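
For orientation, and not part of the commit itself: a minimal sketch of how the updated module would typically be wired together. It assumes config.py exports VOCABULARY and BLANK_TOKEN_SYMBOL as in the new import line above; the CSV filenames are placeholders, not paths taken from this repo.

    # Hypothetical usage of the updated data_handler_ocr.py API (sketch only).
    from config import VOCABULARY, BLANK_TOKEN_SYMBOL  # assumed to exist, as in the new import line
    from data_handler_ocr import CharIndexer, load_ocr_dataframes, create_ocr_dataloaders

    char_indexer = CharIndexer(VOCABULARY, BLANK_TOKEN_SYMBOL)

    # Placeholder CSV paths; the real files need FILENAME/IDENTITY rows as load_ocr_dataframes expects.
    train_df, test_df = load_ocr_dataframes("train_labels.csv", "test_labels.csv")
    train_loader, test_loader = create_ocr_dataloaders(train_df, test_df, char_indexer, batch_size=32)

    # One batch, in the order returned by ocr_collate_fn.
    images, targets, image_widths, target_lengths = next(iter(train_loader))
    print(images.shape)    # (N, 1, IMG_HEIGHT, max batch width), assuming the default transform yields (1, H, W) tensors
    print(targets.shape)   # 1-D: all label sequences concatenated, length == target_lengths.sum()
    print(image_widths)    # per-sample widths before padding (pixel columns, not CTC time steps)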
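
As a quick illustration of the greedy CTC-style collapse that CharIndexer.decode now performs (blank indices dropped, consecutive repeats merged), here is a sketch with a made-up three-character vocabulary in which '-' serves as the blank symbol; the real vocabulary comes from config.VOCABULARY.

    from data_handler_ocr import CharIndexer

    # Illustrative vocabulary only; sorted(set("AB-")) gives ['-', 'A', 'B'], so the blank gets index 0.
    ci = CharIndexer("AB-", "-")
    print(ci.encode("AAB"))               # [1, 1, 2]
    print(ci.decode([1, 1, 0, 1, 2, 2]))  # repeats collapse, blanks drop -> "AAB"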