# utils_ocr.py
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
from config import IMG_HEIGHT, MAX_IMG_WIDTH

# --- Image Preprocessing Functions ---
def load_image_as_grayscale(image_path: str) -> Image.Image:
    """
    Open the image at *image_path* and return it as a grayscale PIL Image.

    Args:
        image_path: Filesystem path to the image.

    Returns:
        The loaded image converted to mode 'L' (8-bit grayscale).

    Raises:
        FileNotFoundError: If no file exists at *image_path*.
    """
    if os.path.exists(image_path):
        # Mode 'L' = single-channel 8-bit grayscale.
        return Image.open(image_path).convert('L')
    raise FileNotFoundError(f"Image not found at: {image_path}")
def binarize_image(img: Image.Image) -> Image.Image:
    """
    Binarize a grayscale PIL Image with Otsu's method.

    Uses inverted thresholding (THRESH_BINARY_INV), so dark ink becomes
    white (255) and light background becomes black (0).

    Args:
        img: Grayscale PIL Image (mode 'L').

    Returns:
        The binarized image as a PIL Image.
    """
    # PIL -> numpy for OpenCV processing.
    pixels = np.array(img)
    # The explicit threshold (0) is ignored when THRESH_OTSU is set;
    # Otsu's method selects the threshold from the histogram.
    _, inverted = cv2.threshold(
        pixels, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    # numpy -> PIL for the rest of the pipeline.
    return Image.fromarray(inverted)
def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
    """
    Resize a PIL Image to a fixed height, preserving aspect ratio.

    The aspect-ratio-implied width is capped at MAX_IMG_WIDTH, so the
    result is at most (MAX_IMG_WIDTH x img_height). NOTE(review): for very
    wide inputs the cap distorts the aspect ratio rather than cropping —
    this matches the original behavior, whose crop branch was unreachable.

    Args:
        img: Source PIL Image.
        img_height: Target height in pixels.

    Returns:
        The resized PIL Image of size (new_width, img_height).
    """
    width, height = img.size
    # Width that preserves the aspect ratio at the target height; keep at
    # least 1 px so Image.resize never receives a zero dimension.
    new_width = max(1, int(width * (img_height / height)))
    # Cap the width so the model input never exceeds MAX_IMG_WIDTH.
    # (The original resized twice and carried a dead crop branch: after the
    # clamp, resized_img.width could never exceed MAX_IMG_WIDTH, so the
    # first resize was discarded and the crop never ran.)
    new_width = min(new_width, MAX_IMG_WIDTH)
    # LANCZOS for high-quality downsampling.
    return img.resize((new_width, img_height), Image.Resampling.LANCZOS)
def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Map a [0, 1] image tensor into the [-1, 1] range expected by the model.

    Args:
        img_tensor: Grayscale image tensor with values in [0, 1]
            (e.g. produced by torchvision's ToTensor).

    Returns:
        A tensor of the same shape, shifted and scaled to [-1, 1].
    """
    # Standard (x - mean) / std normalization with mean = std = 0.5,
    # which maps [0, 1] onto [-1, 1].
    mean = 0.5
    std = 0.5
    return (img_tensor - mean) / std
def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """
    Run the full OCR preprocessing pipeline on a user-uploaded PIL Image.

    Steps: Otsu binarization -> aspect-preserving resize (width capped by
    MAX_IMG_WIDTH) -> ToTensor ([0, 1]) -> normalization to [-1, 1], then a
    leading batch dimension is added for single-image inference.

    Args:
        image_pil: Input PIL Image (presumably grayscale — see
            load_image_as_grayscale; confirm against callers).
        target_height: Height the image is resized to before inference.

    Returns:
        A tensor of shape (1, C, H, W) ready for the OCR model.
    """
    # Apply each stage explicitly rather than via transforms.Compose;
    # the execution order and results are identical.
    binary = binarize_image(image_pil)                   # PIL -> PIL
    resized = resize_image_for_ocr(binary, target_height)  # PIL -> PIL
    tensor = transforms.ToTensor()(resized)              # PIL -> Tensor in [0, 1]
    normalized = normalize_image_for_model(tensor)       # [0, 1] -> [-1, 1]
    # (C, H, W) -> (1, C, H, W): batch dimension for single-image inference.
    return normalized.unsqueeze(0)