# historical-ocr / image_segmentation.py
"""
Image segmentation utility for OCR preprocessing.
Separates text regions from image regions to improve OCR accuracy on mixed-content documents.
Uses content-aware adaptive segmentation for improved results across document types.
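
Typical usage (the sample path is illustrative):

    from image_segmentation import process_segmented_image
    results = process_segmented_image("input/page.jpg")
    print(results.get('text_regions_count', 0), results['output_files'])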
"""
import cv2
import numpy as np
from PIL import Image
import io
import base64
import logging
from pathlib import Path
from typing import Tuple, List, Dict, Union, Optional
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def segment_image_for_ocr(image_path: Union[str, Path], vision_enabled: bool = True, preserve_content: bool = True) -> Dict[str, Union[Image.Image, str]]:
"""
Prepare image for OCR processing using content-aware segmentation.
Uses adaptive region detection based on text density analysis.
Args:
image_path: Path to the image file
vision_enabled: Whether the vision model is enabled
preserve_content: Whether to preserve original content without enhancement
Returns:
Dict containing segmentation results
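        On success the dict contains 'text_regions' and 'image_regions' (PIL
        visualization images), 'text_mask_base64' (a base64 PNG data URL),
        'combined_result' (the original PIL image), 'text_regions_coordinates'
        (a list of (x, y, w, h) boxes), 'region_images' (per-region dicts) and
        'segmentation_type' ('adaptive' or 'simplified'); failure paths return
        the same keys with None values and empty lists.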
"""
# Convert to Path object if string
image_file = Path(image_path) if isinstance(image_path, str) else image_path
# Log start of processing
logger.info(f"Preparing image for Mistral OCR: {image_file.name}")
try:
# Open original image with PIL
with Image.open(image_file) as pil_img:
# Check for low entropy images when vision is disabled
if not vision_enabled:
from utils.image_utils import calculate_image_entropy
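                # calculate_image_entropy is assumed to return the Shannon entropy
                # of the image histogram; low values suggest little pictorial detail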
ent = calculate_image_entropy(pil_img)
if ent < 3.5: # Likely line-art or blank page
logger.info(f"Low entropy image detected ({ent:.2f}), classifying as illustration")
                    # Return a copy so the image survives the "with" block closing the file
                    return {
                        'text_regions': None,
                        'image_regions': pil_img.copy(),
                        'text_mask_base64': None,
                        'combined_result': None,
                        'text_regions_coordinates': [],
                        'region_images': [],
                        'segmentation_type': 'illustration'
                    }
            # Convert to RGB; convert() returns a copy detached from the file,
            # so results remain usable after the "with" block closes the image
            pil_img = pil_img.convert('RGB')
# Get image dimensions
img_np = np.array(pil_img)
img_width, img_height = pil_img.size
# Analyze text density to determine if advanced segmentation is needed
# This replaces document-specific logic with content-aware analysis
from utils.image_utils import estimate_text_density
text_density = estimate_text_density(img_np)
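        # estimate_text_density is assumed to return a dict with at least a
        # 'pattern' label (e.g. 'uniform' or 'varied') and an integer
        # 'uppercase_sections' count, as used in the branch below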
# Use adaptive approach for documents with unusual text distribution
if text_density['pattern'] == 'varied' or text_density['uppercase_sections'] > 0:
logger.info(f"Using adaptive segmentation for document with varied text density pattern={text_density['pattern']}, uppercase_sections={text_density['uppercase_sections']}")
# Detect content regions based on text density
from utils.text_utils import detect_content_regions
regions = detect_content_regions(img_np)
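            # detect_content_regions is assumed to return axis-aligned boxes as
            # (x, y, w, h) tuples, presumably in reading order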
# Create visualization with green borders around the text regions
vis_img = img_np.copy()
# Draw regions on visualization
for x, y, w, h in regions:
cv2.rectangle(vis_img, (x, y), (x+w, y+h), (0, 255, 0), 3)
# Add text to indicate we're using adaptive processing
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(vis_img, "Adaptive region processing", (30, 60), font, 1, (0, 255, 0), 2)
# Create visualization images
text_regions_vis = Image.fromarray(vis_img)
image_regions_vis = text_regions_vis.copy()
# Create a mask highlighting the text regions
text_mask = np.zeros((img_height, img_width), dtype=np.uint8)
for x, y, w, h in regions:
text_mask[y:y+h, x:x+w] = 255
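            # Encode the mask as PNG, then base64, so it can be embedded as a data URL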
_, buffer = cv2.imencode('.png', text_mask)
text_mask_base64 = base64.b64encode(buffer).decode('utf-8')
# Extract region images
region_images = []
for i, (x, y, w, h) in enumerate(regions):
region = img_np[y:y+h, x:x+w].copy()
region_pil = Image.fromarray(region)
region_info = {
'image': region,
'pil_image': region_pil,
'coordinates': (x, y, w, h),
                    'padded_coordinates': (x, y, w, h),  # same as coordinates; no padding applied here
'order': i
}
region_images.append(region_info)
# Return the adaptive segmentation results
return {
'text_regions': text_regions_vis,
'image_regions': image_regions_vis,
'text_mask_base64': f"data:image/png;base64,{text_mask_base64}",
'combined_result': pil_img,
'text_regions_coordinates': regions,
'region_images': region_images,
'segmentation_type': 'adaptive'
}
else:
# SIMPLIFIED APPROACH for most documents
# Let Mistral OCR handle the entire document understanding process
                logger.info("Using standard approach for document with uniform text density")
# For visualization, mark the entire image as a text region
full_image_region = [(0, 0, img_width, img_height)]
# Create visualization with a simple border
vis_img = img_np.copy()
cv2.rectangle(vis_img, (5, 5), (img_width-5, img_height-5), (0, 255, 0), 5)
# Add text to indicate this is using Mistral's native processing
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(vis_img, "Processed by Mistral OCR", (30, 60), font, 1, (0, 255, 0), 2)
# Create visualizations and masks
text_regions_vis = Image.fromarray(vis_img)
image_regions_vis = text_regions_vis.copy()
# Create a mask of the entire image (just for visualization)
text_mask = np.ones((img_height, img_width), dtype=np.uint8) * 255
_, buffer = cv2.imencode('.png', text_mask)
text_mask_base64 = base64.b64encode(buffer).decode('utf-8')
# Return the original image as the combined result
return {
'text_regions': text_regions_vis,
'image_regions': image_regions_vis,
'text_mask_base64': f"data:image/png;base64,{text_mask_base64}",
'combined_result': pil_img,
'text_regions_coordinates': full_image_region,
'region_images': [{
'image': img_np,
'pil_image': pil_img,
'coordinates': (0, 0, img_width, img_height),
'padded_coordinates': (0, 0, img_width, img_height),
'order': 0
}],
'segmentation_type': 'simplified'
}
except Exception as e:
logger.error(f"Error segmenting image {image_file.name}: {str(e)}")
        # Return an empty result with the same schema if processing fails
        return {
            'text_regions': None,
            'image_regions': None,
            'text_mask_base64': None,
            'combined_result': None,
            'text_regions_coordinates': [],
            'region_images': [],
            'segmentation_type': None
        }
def process_segmented_image(image_path: Union[str, Path], output_dir: Optional[Path] = None, preserve_content: bool = True) -> Dict:
"""
Process an image using segmentation for improved OCR, saving visualization outputs.
Args:
image_path: Path to the image file
        output_dir: Optional directory to save visualization outputs
        preserve_content: Whether to preserve original content without enhancement
Returns:
Dictionary with processing results and paths to output files
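        Example result (paths and values are illustrative):
            {'original_image': 'input/page.jpg',
             'output_files': {'text_regions': 'output/segmentation/page_text_regions.jpg', ...},
             'text_regions_count': 1,
             'text_regions_coordinates': [(0, 0, 2480, 3508)]}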
"""
# Convert to Path object if string
image_file = Path(image_path) if isinstance(image_path, str) else image_path
# Create output directory if not provided
if output_dir is None:
output_dir = Path("output") / "segmentation"
output_dir.mkdir(parents=True, exist_ok=True)
# Process the image with segmentation
    segmentation_results = segment_image_for_ocr(image_file, preserve_content=preserve_content)
# Prepare results dictionary
results = {
'original_image': str(image_file),
'output_files': {}
}
# Save visualization outputs if segmentation was successful
if segmentation_results['text_regions'] is not None:
# Save text regions visualization
text_regions_path = output_dir / f"{image_file.stem}_text_regions.jpg"
segmentation_results['text_regions'].save(text_regions_path)
results['output_files']['text_regions'] = str(text_regions_path)
# Save image regions visualization
image_regions_path = output_dir / f"{image_file.stem}_image_regions.jpg"
segmentation_results['image_regions'].save(image_regions_path)
results['output_files']['image_regions'] = str(image_regions_path)
# Save combined result
combined_path = output_dir / f"{image_file.stem}_combined.jpg"
segmentation_results['combined_result'].save(combined_path)
results['output_files']['combined_result'] = str(combined_path)
# Save text mask visualization
text_mask_path = output_dir / f"{image_file.stem}_text_mask.png"
# Save text mask from base64
if segmentation_results['text_mask_base64']:
            # Strip the "data:image/png;base64," prefix before decoding
            base64_data = segmentation_results['text_mask_base64'].split(',')[1]
with open(text_mask_path, 'wb') as f:
f.write(base64.b64decode(base64_data))
results['output_files']['text_mask'] = str(text_mask_path)
# Add detected text regions count
results['text_regions_count'] = len(segmentation_results['text_regions_coordinates'])
results['text_regions_coordinates'] = segmentation_results['text_regions_coordinates']
return results
if __name__ == "__main__":
# Simple test - process a sample image if run directly
import sys
if len(sys.argv) > 1:
image_path = sys.argv[1]
else:
        image_path = "input/handwritten-journal.jpg"  # Example image path
logger.info(f"Testing image segmentation on {image_path}")
results = process_segmented_image(image_path)
# Print results summary
logger.info(f"Segmentation complete. Found {results.get('text_regions_count', 0)} text regions.")
    logger.info(f"Output files saved to: {list(results.get('output_files', {}).values())}")