marianeft committed
Commit 0c48050 · verified · 1 Parent(s): 35d1d63

Update model_ocr.py

Files changed (1):
  1. model_ocr.py +5 -303
model_ocr.py CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 # model_ocr.py
 
 import torch
@@ -11,13 +10,9 @@ from sklearn.metrics import accuracy_score
 import editdistance
 
 # Import config and char_indexer
-# Ensure these imports align with your current config.py
 from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
 from data_handler_ocr import CharIndexer
-# You might also need to import binarize_image, resize_image_for_ocr, normalize_image_for_model
-# if they are used directly in model_ocr.py for internal preprocessing (e.g., in evaluate_model if not using DataLoader)
-# For now, assuming they are handled by DataLoader transforms.
-from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model # Add this for completeness if needed elsewhere
+from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model
 
 
 class CNN_Backbone(nn.Module):
@@ -44,7 +39,6 @@ class CNN_Backbone(nn.Module):
             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
             nn.ReLU(True),
             # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
-            # The original comment (W/4 + 1) is due to padding=1 and stride=1 on width, which is fine.
             nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
 
             # Fourth block
@@ -97,7 +91,7 @@ class CRNN(nn.Module):
         # Input to LSTM is the number of channels from the CNN output
         self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
         # Output of bidirectional LSTM is hidden_size * 2
-        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
+        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
@@ -207,7 +201,7 @@ def train_ocr_model(model: nn.Module, train_loader: DataLoader,
     criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
     optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
     # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5) # Removed verbose=True
 
     model.to(device) # Ensure model is on the correct device
     model.train() # Set model to training mode
@@ -287,298 +281,6 @@ def load_ocr_model(model: nn.Module, path: str):
     Loads a trained OCR model's state dictionary.
     Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
     """
-    model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
+    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
     model.eval() # Set to evaluation mode
-=======
-# model_ocr.py
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.utils.data import DataLoader # Keep DataLoader for type hinting
-from tqdm import tqdm
-from sklearn.metrics import accuracy_score
-import editdistance
-
-# Import config and char_indexer
-# Ensure these imports align with your current config.py
-from config import IMG_HEIGHT, NUM_CLASSES, BLANK_TOKEN
-from data_handler_ocr import CharIndexer
-# You might also need to import binarize_image, resize_image_for_ocr, normalize_image_for_model
-# if they are used directly in model_ocr.py for internal preprocessing (e.g., in evaluate_model if not using DataLoader)
-# For now, assuming they are handled by DataLoader transforms.
-from utils_ocr import binarize_image, resize_image_for_ocr, normalize_image_for_model # Add this for completeness if needed elsewhere
-
-
-class CNN_Backbone(nn.Module):
-    """
-    CNN feature extractor for OCR. Designed to produce features suitable for RNN.
-    Output feature map should have height 1 after the final pooling/reduction.
-    """
-    def __init__(self, input_channels=1, output_channels=512):
-        super(CNN_Backbone, self).__init__()
-        self.cnn = nn.Sequential(
-            # First block
-            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
-            nn.ReLU(True),
-            nn.MaxPool2d(kernel_size=2, stride=2), # H: 32 -> 16, W: W_in -> W_in/2
-
-            # Second block
-            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
-            nn.ReLU(True),
-            nn.MaxPool2d(kernel_size=2, stride=2), # H: 16 -> 8, W: W_in/2 -> W_in/4
-
-            # Third block (with two conv layers)
-            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
-            nn.ReLU(True),
-            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
-            nn.ReLU(True),
-            # This MaxPool2d effectively brings height from 8 to 4, with a small width adjustment due to padding
-            # The original comment (W/4 + 1) is due to padding=1 and stride=1 on width, which is fine.
-            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # H: 8 -> 4, W: (W/4) -> (W/4 + 1) (approx)
-
-            # Fourth block
-            nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
-            nn.ReLU(True),
-            # This AdaptiveAvgPool2d makes sure the height dimension becomes 1
-            # while preserving the width. This is crucial for RNN input.
-            nn.AdaptiveAvgPool2d((1, None)) # Output height 1, preserve width
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
-
-        # Pass through the CNN layers
-        conv_features = self.cnn(x) # Output: (N, cnn_out_channels, 1, W_prime)
-
-        # Squeeze the height dimension (which is 1)
-        # This transforms (N, C_out, 1, W_prime) to (N, C_out, W_prime)
-        conv_features = conv_features.squeeze(2)
-
-        # Permute for RNN input: (sequence_length, batch_size, input_size)
-        # This transforms (N, C_out, W_prime) to (W_prime, N, C_out)
-        conv_features = conv_features.permute(2, 0, 1)
-
-        # Return the CNN features, ready for the RNN layer in CRNN
-        return conv_features
-
-class BidirectionalLSTM(nn.Module):
-    """Bidirectional LSTM layer for sequence modeling."""
-    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.5):
-        super(BidirectionalLSTM, self).__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
-                            bidirectional=True, dropout=dropout, batch_first=False)
-        # batch_first=False expects input as (sequence_length, batch_size, input_size)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        output, _ = self.lstm(x) # [0] returns the output, [1] returns (h_n, c_n)
-        return output
-
-class CRNN(nn.Module):
-    """
-    Convolutional Recurrent Neural Network for OCR.
-    Combines CNN for feature extraction, LSTMs for sequence modeling,
-    and a final linear layer for character prediction.
-    """
-    def __init__(self, num_classes: int, cnn_output_channels: int = 512,
-                 rnn_hidden_size: int = 256, rnn_num_layers: int = 2):
-        super(CRNN, self).__init__()
-        self.cnn = CNN_Backbone(output_channels=cnn_output_channels)
-        # Input to LSTM is the number of channels from the CNN output
-        self.rnn = BidirectionalLSTM(cnn_output_channels, rnn_hidden_size, rnn_num_layers)
-        # Output of bidirectional LSTM is hidden_size * 2
-        self.fc = nn.Linear(rnn_hidden_size * 2, num_classes) # Final linear layer for classes
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x: (N, C, H, W) e.g., (B, 1, 32, W_img)
-
-        # 1. Pass through the CNN to extract features
-        conv_features = self.cnn(x) # Output: (W_prime, N, C_out) after permute in CNN_Backbone
-
-        # 2. Pass CNN features through the RNN (LSTM)
-        rnn_features = self.rnn(conv_features) # Output: (W_prime, N, rnn_hidden_size * 2)
-
-        # 3. Pass RNN features through the final fully connected layer
-        # Apply the linear layer to each time step independently
-        # output will be (W_prime, N, num_classes)
-        output = self.fc(rnn_features)
-
-        return output
-
-
-# --- Decoding Function ---
-def ctc_greedy_decode(output: torch.Tensor, char_indexer: CharIndexer) -> list[str]:
-    """
-    Performs greedy decoding on the CTC output.
-    output: (sequence_length, batch_size, num_classes) - raw logits
-    """
-    # Apply log_softmax to get probabilities for argmax
-    log_probs = F.log_softmax(output, dim=2)
-
-    # Permute to (batch_size, sequence_length, num_classes) for argmax along class dim
-    # This gives us the index of the most probable character at each time step for each sample in the batch.
-    predicted_indices = torch.argmax(log_probs.permute(1, 0, 2), dim=2).cpu().numpy()
-
-    decoded_texts = []
-    for seq in predicted_indices:
-        # Use char_indexer's decode method, which handles blank removal and duplicate collapse
-        decoded_texts.append(char_indexer.decode(seq.tolist())) # Convert numpy array to list
-    return decoded_texts
-
-# --- Evaluation Function ---
-def evaluate_model(model: nn.Module, dataloader: DataLoader, char_indexer: CharIndexer, device: str):
-    model.eval() # Set model to evaluation mode
-    # CTCLoss needs the blank token index, which is available from char_indexer
-    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
-    total_loss = 0
-    all_predictions = []
-    all_ground_truths = []
-
-    with torch.no_grad(): # Disable gradient calculation for evaluation
-        for inputs, targets_padded, _, target_lengths in tqdm(dataloader, desc="Evaluating"):
-            inputs = inputs.to(device)
-            targets_padded = targets_padded.to(device)
-            target_lengths = target_lengths.to(device)
-
-            output = model(inputs) # (seq_len, batch_size, num_classes)
-
-            # Calculate input_lengths for CTCLoss. This is the sequence length produced by the CNN/RNN.
-            # It's the `output.shape[0]` (sequence_length) for each item in the batch.
-            outputs_seq_len_for_ctc = torch.full(
-                size=(output.shape[1],), # batch_size
-                fill_value=output.shape[0], # actual sequence length (T) from model output
-                dtype=torch.long,
-                device=device
-            )
-
-            # CTC Loss calculation requires log_softmax on the output logits
-            log_probs_for_loss = F.log_softmax(output, dim=2) # (T, N, C)
-
-            loss = criterion(log_probs_for_loss, targets_padded, outputs_seq_len_for_ctc, target_lengths)
-            total_loss += loss.item() * inputs.size(0) # Multiply by batch size for correct average
-
-            # Decode predictions for metrics
-            decoded_preds = ctc_greedy_decode(output, char_indexer)
-
-            # Reconstruct ground truths from encoded tensors
-            ground_truths = []
-            # Loop through each sample in the batch
-            for i in range(targets_padded.size(0)):
-                # Extract the actual target sequence for the i-th sample using its length
-                # Convert to list before passing to char_indexer.decode
-                ground_truths.append(char_indexer.decode(targets_padded[i, :target_lengths[i]].tolist()))
-
-            all_predictions.extend(decoded_preds)
-            all_ground_truths.extend(ground_truths)
-
-    avg_loss = total_loss / len(dataloader.dataset)
-
-    # Calculate Character Error Rate (CER)
-    cer_sum = 0
-    total_chars = 0
-    for pred, gt in zip(all_predictions, all_ground_truths):
-        cer_sum += editdistance.eval(pred, gt)
-        total_chars += len(gt)
-    char_error_rate = cer_sum / total_chars if total_chars > 0 else 0.0
-
-    # Calculate Exact Match Accuracy (Word-level Accuracy)
-    exact_match_accuracy = accuracy_score(all_ground_truths, all_predictions)
-
-    return avg_loss, char_error_rate, exact_match_accuracy
-
-# --- Training Function ---
-def train_ocr_model(model: nn.Module, train_loader: DataLoader,
-                    test_loader: DataLoader, char_indexer: CharIndexer,
-                    epochs: int, device: str, progress_callback=None) -> tuple[nn.Module, dict]:
-    """
-    Trains the OCR model using CTC loss.
-    """
-    # CTCLoss needs the blank token index
-    criterion = nn.CTCLoss(blank=char_indexer.blank_token_idx, zero_infinity=True)
-    optimizer = optim.Adam(model.parameters(), lr=0.001) # Using a fixed LR for now
-    # Using ReduceLROnPlateau to adjust LR based on test loss (monitor 'min' loss)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)
-
-    model.to(device) # Ensure model is on the correct device
-    model.train() # Set model to training mode
-
-    training_history = {
-        'train_loss': [],
-        'test_loss': [],
-        'test_cer': [],
-        'test_exact_match_accuracy': []
-    }
-
-    for epoch in range(epochs):
-        running_loss = 0.0
-        pbar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)")
-        for images, texts_encoded, _, text_lengths in pbar_train:
-            images = images.to(device)
-            # Ensure target tensors are on the correct device for CTCLoss calculation
-            texts_encoded = texts_encoded.to(device)
-            text_lengths = text_lengths.to(device)
-
-            optimizer.zero_grad() # Clear gradients from previous step
-            outputs = model(images) # (sequence_length_from_cnn, batch_size, num_classes)
-
-            # `outputs.shape[0]` is the actual sequence length (T) produced by the model.
-            # CTC loss expects `input_lengths` to be a tensor of shape (batch_size,) with these values.
-            outputs_seq_len_for_ctc = torch.full(
-                size=(outputs.shape[1],), # batch_size
-                fill_value=outputs.shape[0], # actual sequence length (T) from model output
-                dtype=torch.long,
-                device=device
-            )
-
-            # CTC Loss calculation requires log_softmax on the output logits
-            log_probs_for_loss = F.log_softmax(outputs, dim=2) # (T, N, C)
-
-            # Use outputs_seq_len_for_ctc for the input_lengths argument
-            loss = criterion(log_probs_for_loss, texts_encoded, outputs_seq_len_for_ctc, text_lengths)
-            loss.backward() # Backpropagate
-            optimizer.step() # Update model weights
-
-            running_loss += loss.item() * images.size(0) # Multiply by batch size for correct average
-            pbar_train.set_postfix(loss=loss.item())
-
-        epoch_train_loss = running_loss / len(train_loader.dataset)
-        training_history['train_loss'].append(epoch_train_loss)
-
-        # Evaluate on test set using the dedicated function
-        # Ensure model is in eval mode before calling evaluate_model
-        model.eval()
-        test_loss, test_cer, test_exact_match_accuracy = evaluate_model(model, test_loader, char_indexer, device)
-        training_history['test_loss'].append(test_loss)
-        training_history['test_cer'].append(test_cer)
-        training_history['test_exact_match_accuracy'].append(test_exact_match_accuracy)
-
-        # Adjust learning rate based on test loss (this is where scheduler.step() is called)
-        scheduler.step(test_loss)
-
-        print(f"Epoch {epoch+1}/{epochs}: Train Loss={epoch_train_loss:.4f}, "
-              f"Test Loss={test_loss:.4f}, Test CER={test_cer:.4f}, Test Exact Match Acc={test_exact_match_accuracy:.4f}")
-
-        if progress_callback:
-            # Update progress bar with current epoch and key metrics
-            progress_val = (epoch + 1) / epochs
-            progress_callback(progress_val, text=f"Epoch {epoch+1}/{epochs} done. Test CER: {test_cer:.4f}, Test Exact Match Acc: {test_exact_match_accuracy:.4f}")
-
-        model.train() # Set model back to training mode after evaluation
-
-    return model, training_history
-
-def save_ocr_model(model: nn.Module, path: str):
-    """Saves the state dictionary of the trained OCR model."""
-    torch.save(model.state_dict(), path)
-    print(f"OCR model saved to {path}")
-
-def load_ocr_model(model: nn.Module, path: str):
-    """
-    Loads a trained OCR model's state dictionary.
-    Includes map_location to handle loading models trained on GPU to CPU, and vice versa.
-    """
-    model.load_state_dict(torch.load(path, map_location=torch.device('cpu'))) # Always load to CPU first
-    model.eval() # Set to evaluation mode
->>>>>>> ee59e5b21399d8b323cff452a961ea2fd6c65308
-print(f"OCR model loaded from {path}")
+    print(f"OCR model loaded from {path}")
 
 