Spaces:
Build error
Build error
File size: 3,312 Bytes
8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b 8900f0a 15dba6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
#utils_ocr.py
import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import os
# Import config for IMG_HEIGHT and MAX_IMG_WIDTH
from config import IMG_HEIGHT, MAX_IMG_WIDTH
# --- Image Preprocessing Functions ---
def load_image_as_grayscale(image_path: str) -> Image.Image:
"""Loads an image from path and converts it to grayscale PIL Image."""
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found at: {image_path}")
return Image.open(image_path).convert('L') # 'L' for grayscale
def binarize_image(img: Image.Image) -> Image.Image:
    """Binarize a grayscale PIL image using Otsu's automatic threshold.

    THRESH_BINARY_INV inverts the result (dark ink becomes white on a
    black background). Returns a PIL image.
    """
    # OpenCV operates on numpy arrays, so round-trip through np.array.
    as_array = np.array(img)
    # The explicit threshold (0) is ignored: THRESH_OTSU computes it
    # automatically from the image histogram.
    _, thresholded = cv2.threshold(
        as_array, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    return Image.fromarray(thresholded)
def resize_image_for_ocr(img: Image.Image, img_height: int) -> Image.Image:
    """Resize a PIL image to a fixed height while preserving aspect ratio.

    The proportional width is clamped to the range [1, MAX_IMG_WIDTH]:
    very tall/narrow inputs would otherwise round down to a zero width
    (which PIL rejects), and very wide inputs are capped so the model
    input stays bounded.

    Note: the original implementation contained an unreachable crop
    branch — after clamping new_width to MAX_IMG_WIDTH the resized width
    could never exceed it — so a single resize is equivalent.

    Args:
        img: source PIL image.
        img_height: target height in pixels.

    Returns:
        The resized PIL image, LANCZOS-resampled for quality.
    """
    width, height = img.size
    # Scale width proportionally to the target height.
    new_width = int(width * (img_height / height))
    # Clamp: at least 1 px wide, at most MAX_IMG_WIDTH.
    new_width = max(1, min(new_width, MAX_IMG_WIDTH))
    return img.resize((new_width, img_height), Image.Resampling.LANCZOS)
def normalize_image_for_model(img_tensor: torch.Tensor) -> torch.Tensor:
    """Map a grayscale tensor from [0, 1] into the [-1, 1] range.

    Applies (x - 0.5) / 0.5 elementwise — the standard single-channel
    normalization with mean 0.5 and std 0.5. Assumes the input is already
    in [0, 1] (e.g. produced by ToTensor).
    """
    # Out-of-place ops: the caller's tensor is not modified.
    return img_tensor.sub(0.5).div(0.5)
def preprocess_user_image_for_ocr(image_pil: Image.Image, target_height: int) -> torch.Tensor:
    """Run the full OCR preprocessing chain on a user-supplied PIL image.

    Steps, in order: Otsu binarization -> height/width-constrained
    resize -> ToTensor ([0, 1]) -> normalization ([-1, 1]) -> add a
    batch dimension.

    Returns:
        A (1, C, H, W) float tensor ready for single-image inference.
    """
    binarized = binarize_image(image_pil)
    resized = resize_image_for_ocr(binarized, target_height)
    # ToTensor converts the PIL image to a float tensor in [0, 1].
    tensor = transforms.ToTensor()(resized)
    normalized = normalize_image_for_model(tensor)
    # (C, H, W) -> (1, C, H, W): the model expects a leading batch dim.
    return normalized.unsqueeze(0)