marianeft commited on
Commit
35d1d63
·
verified ·
1 Parent(s): 1ae7e6f

Update data_handler_ocr.py

Browse files
Files changed (1) hide show
  1. data_handler_ocr.py +58 -177
data_handler_ocr.py CHANGED
@@ -1,4 +1,3 @@
1
- <<<<<<< HEAD
2
  #data_handler_ocr.py
3
 
4
  import pandas as pd
@@ -11,34 +10,52 @@ import numpy as np
11
  import torch.nn.functional as F
12
 
13
  # Import utility functions and config
14
- from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
15
  from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
16
 
17
  class CharIndexer:
18
  """Manages character-to-index and index-to-character mappings."""
19
- def __init__(self, chars: str, blank_token: str):
20
- self.char_to_idx = {char: i for i, char in enumerate(chars)}
21
- self.idx_to_char = {i: char for i, char in enumerate(chars)}
22
- self.blank_token_idx = len(chars) # Index for the blank token
23
- self.idx_to_char[self.blank_token_idx] = blank_token # Add blank token to idx_to_char
24
- self.num_classes = len(chars) + 1 # Total classes including blank
 
 
 
 
 
 
 
 
 
 
25
 
26
  def encode(self, text: str) -> list[int]:
27
  """Converts a text string to a list of integer indices."""
28
- return [self.char_to_idx[char] for char in text]
 
 
 
 
 
 
 
29
 
30
  def decode(self, indices: list[int]) -> str:
31
  """Converts a list of integer indices back to a text string."""
32
- # CTC decoding often produces repeated characters and blank tokens.
33
- # This simple decoder removes blanks and duplicates.
34
  decoded_text = []
35
  for i, idx in enumerate(indices):
36
  if idx == self.blank_token_idx:
37
  continue
38
- # Remove consecutive duplicates
39
  if i > 0 and indices[i-1] == idx:
40
  continue
41
- decoded_text.append(self.idx_to_char[idx])
 
 
 
 
42
  return "".join(decoded_text)
43
 
44
  class OCRDataset(Dataset):
@@ -47,43 +64,44 @@ class OCRDataset(Dataset):
47
  Loads images and their corresponding text labels.
48
  """
49
  def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
50
- """
51
- Initializes the OCR Dataset.
52
- Args:
53
- dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
54
- char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
55
- transform (callable, optional): Optional transform to be applied on an image.
56
- """
57
  self.data = dataframe
58
  self.char_indexer = char_indexer
59
  self.image_dir = image_dir
60
- self.transform = transform
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def __len__(self) -> int:
64
  return len(self.data)
65
 
66
  def __getitem__(self, idx):
67
- raw_filename_entry = self.data.iloc[idx]['FILENAME']
68
- ground_truth_text = self.data.iloc[idx]['IDENTITY']
 
 
69
 
70
- filename = raw_filename_entry.split(',')[0].strip() # .strip() removes any whitespace
71
- # Construct the full image path
72
- img_path = os.path.join(self.image_dir, filename)
73
- # Ensure ground_truth_text is a string
74
  ground_truth_text = str(ground_truth_text)
75
 
76
- # Load and transform image
77
  try:
78
- image = Image.open(img_path).convert('L') # Convert to grayscale
79
  except FileNotFoundError:
80
- print(f"Error: Image file not found at {img_path}. Skipping this item.")
81
- raise # Re-raise to let the main traceback be seen.
82
 
83
  if self.transform:
84
  image = self.transform(image)
85
 
86
- image_width = image.size(2) # Assuming image is a tensor (C, H, W) after transform
87
 
88
  text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
89
  text_length = len(text_encoded)
@@ -97,15 +115,13 @@ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso
97
  """
98
  images, texts, image_widths, text_lengths = zip(*batch)
99
 
100
- # Pad images to the maximum width in the current batch
101
  max_batch_width = max(image_widths)
102
  padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
103
- images_batch = torch.stack(padded_images, 0) # Stack to (N, C, H, max_W)
104
 
105
- # Concatenate all text sequences and get their lengths
106
  texts_batch = torch.cat(texts, 0)
107
  text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
108
- image_widths_tensor = torch.tensor(image_widths, dtype=torch.long) # Actual widths
109
 
110
  return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
111
 
@@ -113,10 +129,10 @@ def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso
113
  def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
114
  """
115
  Loads training and testing dataframes.
116
- Assumes CSVs have 'filename' and 'name' columns.
117
  """
118
- train_df = pd.read_csv(train_csv_path)
119
- test_df = pd.read_csv(test_csv_path)
120
  return train_df, test_df
121
 
122
  def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
@@ -125,146 +141,11 @@ def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
125
  Creates PyTorch DataLoader objects for OCR training and testing datasets,
126
  using specific image directories for train/test.
127
  """
128
- train_dataset = OCRDataset(train_df, TRAIN_IMAGES_DIR, char_indexer)
129
- test_dataset = OCRDataset(test_df, TEST_IMAGES_DIR, char_indexer)
130
 
131
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
132
  num_workers=0, collate_fn=ocr_collate_fn)
133
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
134
  num_workers=0, collate_fn=ocr_collate_fn)
135
- =======
136
- #data_handler_ocr.py
137
-
138
- import pandas as pd
139
- import torch
140
- from torch.utils.data import Dataset, DataLoader
141
- from torchvision import transforms
142
- import os
143
- from PIL import Image
144
- import numpy as np
145
- import torch.nn.functional as F
146
-
147
- # Import utility functions and config
148
- from config import CHARS, BLANK_TOKEN, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
149
- from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
150
-
151
- class CharIndexer:
152
- """Manages character-to-index and index-to-character mappings."""
153
- def __init__(self, chars: str, blank_token: str):
154
- self.char_to_idx = {char: i for i, char in enumerate(chars)}
155
- self.idx_to_char = {i: char for i, char in enumerate(chars)}
156
- self.blank_token_idx = len(chars) # Index for the blank token
157
- self.idx_to_char[self.blank_token_idx] = blank_token # Add blank token to idx_to_char
158
- self.num_classes = len(chars) + 1 # Total classes including blank
159
-
160
- def encode(self, text: str) -> list[int]:
161
- """Converts a text string to a list of integer indices."""
162
- return [self.char_to_idx[char] for char in text]
163
-
164
- def decode(self, indices: list[int]) -> str:
165
- """Converts a list of integer indices back to a text string."""
166
- # CTC decoding often produces repeated characters and blank tokens.
167
- # This simple decoder removes blanks and duplicates.
168
- decoded_text = []
169
- for i, idx in enumerate(indices):
170
- if idx == self.blank_token_idx:
171
- continue
172
- # Remove consecutive duplicates
173
- if i > 0 and indices[i-1] == idx:
174
- continue
175
- decoded_text.append(self.idx_to_char[idx])
176
- return "".join(decoded_text)
177
-
178
- class OCRDataset(Dataset):
179
- """
180
- Custom PyTorch Dataset for the Handwritten Name Recognition task.
181
- Loads images and their corresponding text labels.
182
- """
183
- def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
184
- """
185
- Initializes the OCR Dataset.
186
- Args:
187
- dataframe (pd.DataFrame): A DataFrame containing 'image_path' and 'label' columns.
188
- char_indexer (CharIndexer): An instance of CharIndexer for character encoding.
189
- transform (callable, optional): Optional transform to be applied on an image.
190
- """
191
- self.data = dataframe
192
- self.char_indexer = char_indexer
193
- self.image_dir = image_dir
194
- self.transform = transform
195
-
196
-
197
- def __len__(self) -> int:
198
- return len(self.data)
199
-
200
- def __getitem__(self, idx):
201
- raw_filename_entry = self.data.iloc[idx]['FILENAME']
202
- ground_truth_text = self.data.iloc[idx]['IDENTITY']
203
-
204
- filename = raw_filename_entry.split(',')[0].strip() # .strip() removes any whitespace
205
- # Construct the full image path
206
- img_path = os.path.join(self.image_dir, filename)
207
- # Ensure ground_truth_text is a string
208
- ground_truth_text = str(ground_truth_text)
209
-
210
- # Load and transform image
211
- try:
212
- image = Image.open(img_path).convert('L') # Convert to grayscale
213
- except FileNotFoundError:
214
- print(f"Error: Image file not found at {img_path}. Skipping this item.")
215
- raise # Re-raise to let the main traceback be seen.
216
-
217
- if self.transform:
218
- image = self.transform(image)
219
-
220
- image_width = image.size(2) # Assuming image is a tensor (C, H, W) after transform
221
-
222
- text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
223
- text_length = len(text_encoded)
224
-
225
- return image, text_encoded, image_width, text_length
226
-
227
- def ocr_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
228
- """
229
- Custom collate function for the DataLoader to handle variable-width images
230
- and variable-length text sequences for CTC loss.
231
- """
232
- images, texts, image_widths, text_lengths = zip(*batch)
233
-
234
- # Pad images to the maximum width in the current batch
235
- max_batch_width = max(image_widths)
236
- padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
237
- images_batch = torch.stack(padded_images, 0) # Stack to (N, C, H, max_W)
238
-
239
- # Concatenate all text sequences and get their lengths
240
- texts_batch = torch.cat(texts, 0)
241
- text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
242
- image_widths_tensor = torch.tensor(image_widths, dtype=torch.long) # Actual widths
243
-
244
- return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
245
-
246
-
247
- def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
248
- """
249
- Loads training and testing dataframes.
250
- Assumes CSVs have 'filename' and 'name' columns.
251
- """
252
- train_df = pd.read_csv(train_csv_path)
253
- test_df = pd.read_csv(test_csv_path)
254
- return train_df, test_df
255
-
256
- def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
257
- char_indexer: CharIndexer, batch_size: int) -> tuple[DataLoader, DataLoader]:
258
- """
259
- Creates PyTorch DataLoader objects for OCR training and testing datasets,
260
- using specific image directories for train/test.
261
- """
262
- train_dataset = OCRDataset(train_df, TRAIN_IMAGES_DIR, char_indexer)
263
- test_dataset = OCRDataset(test_df, TEST_IMAGES_DIR, char_indexer)
264
-
265
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
266
- num_workers=0, collate_fn=ocr_collate_fn)
267
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
268
- num_workers=0, collate_fn=ocr_collate_fn)
269
- >>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
270
- return train_loader, test_loader
 
 
1
  #data_handler_ocr.py
2
 
3
  import pandas as pd
 
10
  import torch.nn.functional as F
11
 
12
  # Import utility functions and config
13
+ from config import VOCABULARY, BLANK_TOKEN, BLANK_TOKEN_SYMBOL, IMG_HEIGHT, TRAIN_IMAGES_DIR, TEST_IMAGES_DIR
14
  from utils_ocr import load_image_as_grayscale, binarize_image, resize_image_for_ocr, normalize_image_for_model
15
 
16
  class CharIndexer:
17
  """Manages character-to-index and index-to-character mappings."""
18
+ def __init__(self, vocabulary_string: str, blank_token_symbol: str):
19
+ self.chars = sorted(list(set(vocabulary_string)))
20
+ self.char_to_idx = {char: i for i, char in enumerate(self.chars)}
21
+ self.idx_to_char = {i: char for i, char in enumerate(self.chars)}
22
+
23
+ if blank_token_symbol not in self.char_to_idx:
24
+ raise ValueError(f"Blank token symbol '{blank_token_symbol}' not found in provided vocabulary string: '{vocabulary_string}'")
25
+
26
+ self.blank_token_idx = self.char_to_idx[blank_token_symbol]
27
+ self.num_classes = len(self.chars)
28
+
29
+ if self.blank_token_idx >= self.num_classes:
30
+ raise ValueError(f"Blank token index ({self.blank_token_idx}) is out of range for num_classes ({self.num_classes}). This indicates a configuration mismatch.")
31
+
32
+ print(f"CharIndexer initialized: num_classes={self.num_classes}, blank_token_idx={self.blank_token_idx}")
33
+ print(f"Mapped blank symbol: '{self.idx_to_char[self.blank_token_idx]}'")
34
 
35
  def encode(self, text: str) -> list[int]:
36
  """Converts a text string to a list of integer indices."""
37
+ encoded_list = []
38
+ for char in text:
39
+ if char in self.char_to_idx:
40
+ encoded_list.append(self.char_to_idx[char])
41
+ else:
42
+ print(f"Warning: Character '{char}' not found in CharIndexer vocabulary. Mapping to blank token.")
43
+ encoded_list.append(self.blank_token_idx)
44
+ return encoded_list
45
 
46
  def decode(self, indices: list[int]) -> str:
47
  """Converts a list of integer indices back to a text string."""
 
 
48
  decoded_text = []
49
  for i, idx in enumerate(indices):
50
  if idx == self.blank_token_idx:
51
  continue
 
52
  if i > 0 and indices[i-1] == idx:
53
  continue
54
+ if idx in self.idx_to_char:
55
+ decoded_text.append(self.idx_to_char[idx])
56
+ else:
57
+ print(f"Warning: Index {idx} not found in CharIndexer's idx_to_char mapping during decoding.")
58
+
59
  return "".join(decoded_text)
60
 
61
  class OCRDataset(Dataset):
 
64
  Loads images and their corresponding text labels.
65
  """
66
  def __init__(self, dataframe: pd.DataFrame, char_indexer: CharIndexer, image_dir: str, transform=None):
 
 
 
 
 
 
 
67
  self.data = dataframe
68
  self.char_indexer = char_indexer
69
  self.image_dir = image_dir
70
+
71
+ if transform is None:
72
+ self.transform = transforms.Compose([
73
+
74
+ transforms.Lambda(lambda img: binarize_image(img)),
75
+ transforms.Lambda(lambda img: resize_image_for_ocr(img, IMG_HEIGHT)),
76
+ transforms.ToTensor(),
77
+ transforms.Lambda(normalize_image_for_model)
78
+ ])
79
+ else:
80
+ self.transform = transform
81
 
82
 
83
  def __len__(self) -> int:
84
  return len(self.data)
85
 
86
  def __getitem__(self, idx):
87
+ raw_filename_entry = self.data.loc[idx, 'FILENAME']
88
+ ground_truth_text = self.data.loc[idx, 'IDENTITY']
89
+
90
+ filename_only = raw_filename_entry.split(',')[0].strip()
91
 
92
+ img_path = os.path.join(self.image_dir, filename_only)
 
 
 
93
  ground_truth_text = str(ground_truth_text)
94
 
 
95
  try:
96
+ image = load_image_as_grayscale(img_path)
97
  except FileNotFoundError:
98
+ print(f"Error: Image file not found at {img_path}. Please check your dataset and config.py paths.")
99
+ raise
100
 
101
  if self.transform:
102
  image = self.transform(image)
103
 
104
+ image_width = image.shape[2]
105
 
106
  text_encoded = torch.tensor(self.char_indexer.encode(ground_truth_text), dtype=torch.long)
107
  text_length = len(text_encoded)
 
115
  """
116
  images, texts, image_widths, text_lengths = zip(*batch)
117
 
 
118
  max_batch_width = max(image_widths)
119
  padded_images = [F.pad(img, (0, max_batch_width - img.shape[2]), 'constant', 0) for img in images]
120
+ images_batch = torch.stack(padded_images, 0)
121
 
 
122
  texts_batch = torch.cat(texts, 0)
123
  text_lengths_tensor = torch.tensor(text_lengths, dtype=torch.long)
124
+ image_widths_tensor = torch.tensor(image_widths, dtype=torch.long)
125
 
126
  return images_batch, texts_batch, image_widths_tensor, text_lengths_tensor
127
 
 
129
  def load_ocr_dataframes(train_csv_path: str, test_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
130
  """
131
  Loads training and testing dataframes.
132
+ Assumes CSVs have 'FILENAME' and 'IDENTITY' columns and are comma-delimited with no header.
133
  """
134
+ train_df = pd.read_csv(train_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
135
+ test_df = pd.read_csv(test_csv_path, delimiter=',', names=['FILENAME', 'IDENTITY'], header=None, encoding='utf-8')
136
  return train_df, test_df
137
 
138
  def create_ocr_dataloaders(train_df: pd.DataFrame, test_df: pd.DataFrame,
 
141
  Creates PyTorch DataLoader objects for OCR training and testing datasets,
142
  using specific image directories for train/test.
143
  """
144
+ train_dataset = OCRDataset(dataframe=train_df, char_indexer=char_indexer, image_dir=TRAIN_IMAGES_DIR)
145
+ test_dataset = OCRDataset(dataframe=test_df, char_indexer=char_indexer, image_dir=TEST_IMAGES_DIR)
146
 
147
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
148
  num_workers=0, collate_fn=ocr_collate_fn)
149
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
150
  num_workers=0, collate_fn=ocr_collate_fn)
151
+ return train_loader, test_loader