# utils_ocr.py
import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import os
# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
from config import IMG_HEIGHT, MAX_IMG_WIDTH
# --- Image Preprocessing Functions ---
def load_image_as_grayscale(image_path: str) -> Image.Image:
    """Loads an image from a path and converts it to a grayscale PIL Image."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    return Image.open(image_path).convert('L')  # 'L' = 8-bit grayscale

def binarize_image(img: Image.Image) -> Image.Image:
    """
    Binarizes a grayscale PIL Image using Otsu's method.
    Returns a PIL Image.
    """
    # Convert the PIL Image to an OpenCV-compatible numpy array
    img_np = np.array(img)
    # Apply Otsu's binarization
    _, binary_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Convert back to a PIL Image
    return Image.fromarray(binary_img)
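
# A quick illustrative check of binarize_image (not part of the pipeline):
# with THRESH_BINARY_INV, dark ink becomes white (255) and the light
# background becomes black (0). The array below is a made-up example.
#
#   arr = np.full((10, 10), 230, dtype=np.uint8)   # light background
#   arr[2:8, 2:8] = 20                             # dark "glyph"
#   out = np.array(binarize_image(Image.fromarray(arr)))
#   int(out[5, 5]), int(out[0, 0])                 # -> (255, 0)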

def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
    """
    Resizes a PIL Image to a fixed height while maintaining aspect ratio.
    If the resulting width would exceed MAX_IMG_WIDTH, the width is capped
    at MAX_IMG_WIDTH instead (very wide images are compressed horizontally).
    """
    width, height = img.size
    # Calculate the new width from the target height, maintaining aspect ratio
    new_width = int(width * (img_height / height))
    # Cap the width so the model input never exceeds MAX_IMG_WIDTH
    new_width = min(new_width, MAX_IMG_WIDTH)
    # Use LANCZOS for high-quality downsampling
    return img.resize((new_width, img_height), Image.Resampling.LANCZOS)
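
# Illustrative sizing example (the exact outcome depends on the MAX_IMG_WIDTH
# value in config): a 1000x50 line image resized to height 32 keeps its aspect
# ratio and gets width int(1000 * 32 / 50) = 640, unless MAX_IMG_WIDTH is
# smaller, in which case the width is capped (squeezed, not cropped).
#
#   wide = Image.new('L', (1000, 50), color=255)
#   resize_image_for_ocr(wide, 32).size           # -> (640, 32) if MAX_IMG_WIDTH >= 640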

def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalizes a grayscale image tensor for input into the model,
    mapping pixel values from [0, 1] to [-1, 1].
    Assumes the tensor already holds values in [0, 1] (e.g., after ToTensor).
    """
    # (pixel_value - mean) / std with mean = 0.5 and std = 0.5 maps [0, 1] to [-1, 1]
    return (img_tensor - 0.5) / 0.5

def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """
    Applies all necessary preprocessing steps to a user-uploaded PIL Image
    to prepare it for the OCR model.
    """
    # Transformation pipeline mirroring the training dataset, plus ToTensor
    transform_pipeline = transforms.Compose([
        transforms.Lambda(binarize_image),                                         # PIL Image -> PIL Image
        transforms.Lambda(lambda img: resize_image_for_ocr(img, target_height)),   # PIL Image -> PIL Image (height fixed, width capped)
        transforms.ToTensor(),                                                     # PIL Image -> Tensor in [0, 1]
        transforms.Lambda(normalize_image_for_model),                              # Tensor [0, 1] -> Tensor [-1, 1]
    ])
    processed_image = transform_pipeline(image_pil)
    # Add a batch dimension: (C, H, W) -> (1, C, H, W) for single-image inference
    return processed_image.unsqueeze(0)
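
# --- Example usage (illustrative sketch) ---
# Shows how the helpers above could be chained for single-image inference.
# "sample_line.png" is a placeholder path, and the trained OCR model itself is
# not defined in this module; feed the returned batch into that model.
if __name__ == "__main__":
    image = load_image_as_grayscale("sample_line.png")  # hypothetical input file
    batch = preprocess_user_image_for_ocr(image, IMG_HEIGHT)
    # Expected shape: (1, 1, IMG_HEIGHT, W), values in [-1, 1]
    print(batch.shape, float(batch.min()), float(batch.max()))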