File size: 3,312 Bytes
8900f0a
 
 
 
 
 
15dba6b
 
8900f0a
15dba6b
 
 
 
8900f0a
 
 
15dba6b
 
 
8900f0a
15dba6b
8900f0a
15dba6b
 
8900f0a
15dba6b
 
 
 
 
 
 
 
8900f0a
15dba6b
8900f0a
15dba6b
 
8900f0a
15dba6b
 
 
 
 
 
 
 
 
 
 
 
 
 
8900f0a
15dba6b
8900f0a
15dba6b
 
 
8900f0a
15dba6b
 
 
8900f0a
 
15dba6b
8900f0a
15dba6b
 
8900f0a
15dba6b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#utils_ocr.py

import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import os

# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
from config import IMG_HEIGHT, MAX_IMG_WIDTH

# --- Image Preprocessing Functions ---

def load_image_as_grayscale(image_path: str) -> Image.Image:
    """
    Load an image from disk and return it as a grayscale ('L' mode) PIL Image.

    Args:
        image_path: Path to the image file.

    Returns:
        A new, fully loaded PIL Image in 'L' (8-bit grayscale) mode. It does
        not hold a reference to the underlying file.

    Raises:
        FileNotFoundError: If nothing exists at ``image_path``.
        Any error raised by ``PIL.Image.open`` (e.g. for unreadable or
        non-image files) propagates unchanged.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    # Image.open is lazy and keeps the file handle open (for multi-frame
    # formats, even after load). Using a context manager guarantees the
    # handle is closed. convert() returns a new, fully loaded image, so it is
    # safe to use after the file is closed.
    with Image.open(image_path) as img:
        return img.convert('L')  # 'L' for grayscale

def binarize_image(img: Image.Image) -> Image.Image:
    """
    Binarize a PIL Image using Otsu's automatic thresholding.

    The output is inverted (``cv2.THRESH_BINARY_INV``). Pixels at or below
    the Otsu threshold become 255 and brighter pixels become 0. Dark text on
    a light background therefore comes out as white text on black.

    Args:
        img: Input PIL Image. Non-grayscale images (RGB, RGBA, palette,
            1-bit, ...) are converted to 'L' mode first.

    Returns:
        An 'L'-mode PIL Image whose pixels are only 0 or 255.
    """
    # OpenCV's Otsu mode requires a single-channel image. User uploads are
    # frequently RGB/RGBA/palette, which would otherwise make cv2.threshold
    # fail, so coerce to 8-bit grayscale first.
    if img.mode != 'L':
        img = img.convert('L')

    # Convert PIL Image to OpenCV format (2-D uint8 numpy array)
    img_np = np.array(img)

    # Threshold value 0 is ignored: THRESH_OTSU computes it from the histogram.
    _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

    # Convert back to PIL Image
    return Image.fromarray(binary_img)

def resize_image_for_ocr(img: Image.Image, img_height: int, max_width: int = MAX_IMG_WIDTH) -> Image.Image:
    """
    Resize a PIL Image to a fixed height, scaling the width proportionally.

    The width is capped at ``max_width``. When the proportional width would
    exceed the cap, the image is resized directly to ``(max_width, img_height)``.
    This compresses it horizontally, so the aspect ratio is NOT preserved in
    that case. The width is never allowed to drop below 1 pixel.

    Args:
        img: Input PIL Image.
        img_height: Target height in pixels (must be positive).
        max_width: Maximum output width in pixels. Defaults to
            ``MAX_IMG_WIDTH`` from config.

    Returns:
        A new PIL Image of size ``(w, img_height)`` with ``1 <= w <= max_width``.

    Raises:
        ValueError: If the input image has zero height, or if ``img_height``
            or ``max_width`` is not positive.
    """
    width, height = img.size
    if height <= 0:
        raise ValueError(f"Cannot resize an image with zero height (size={img.size})")
    if img_height <= 0:
        raise ValueError(f"img_height must be positive, got {img_height}")
    if max_width <= 0:
        raise ValueError(f"max_width must be positive, got {max_width}")

    # Calculate new width based on target height, maintaining aspect ratio
    new_width = int(width * (img_height / height))

    # Cap at max_width (distorts aspect ratio for very wide inputs). Keep at
    # least 1 px, because very tall/thin inputs can truncate to 0 and PIL
    # rejects zero-sized targets.
    new_width = max(1, min(new_width, max_width))

    # LANCZOS gives high-quality results, especially when downsampling
    return img.resize((new_width, img_height), Image.Resampling.LANCZOS)

def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Map a tensor with values in [0, 1] to the range [-1, 1] for the model.

    The tensor is treated as produced by ``transforms.ToTensor``. It is
    shifted by a mean of 0.5 and divided by a standard deviation of 0.5.
    The input tensor is not modified; a new tensor is returned.
    """
    mean = 0.5
    std = 0.5
    # (x - 0.5) / 0.5 sends 0 -> -1, 0.5 -> 0 and 1 -> 1
    centered = img_tensor - mean
    return centered / std

def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """
    Turn a user-supplied PIL Image into a single-image batch for the OCR model.

    The steps are applied in order:
      1. Otsu binarization (``binarize_image``).
      2. Resize to ``target_height`` with the width capped at MAX_IMG_WIDTH
         (``resize_image_for_ocr``).
      3. Conversion to a float tensor in [0, 1] (``transforms.ToTensor``).
      4. Rescaling to [-1, 1] (``normalize_image_for_model``).

    Returns:
        A tensor with a leading batch dimension of size 1, i.e. ``(1, C, H, W)``.
    """
    binary = binarize_image(image_pil)
    resized = resize_image_for_ocr(binary, target_height)
    as_tensor = transforms.ToTensor()(resized)  # PIL Image -> Tensor in [0, 1]
    normalized = normalize_image_for_model(as_tensor)  # -> Tensor in [-1, 1]

    # (C, H, W) -> (1, C, H, W) so the model receives a batch of one
    return normalized.unsqueeze(dim=0)