historical-ocr / image_segmentation.py
milwright's picture
Save current segmentation approach before refactoring
73375a3
raw
history blame
13 kB
"""
Image segmentation utility for OCR preprocessing.
Separates text regions from image regions to improve OCR accuracy on mixed-content documents.
Based on Mistral AI cookbook examples.
"""
import cv2
import numpy as np
from PIL import Image
import io
import base64
import logging
from pathlib import Path
from typing import Tuple, List, Dict, Union, Optional
# Configure logging
# NOTE(review): basicConfig configures the root logger at import time, which
# affects any application importing this module; consider moving it under
# the `if __name__ == "__main__":` guard.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def determine_segmentation_approach(image_path: Union[str, Path]) -> str:
    """
    Choose the segmentation approach for a document based on its path.

    The entire path (directories included, not only the file name) is
    lower-cased and searched, so any path containing both "baldwin" and
    "north" selects the document-specific approach.

    Args:
        image_path: Path to the image file

    Returns:
        str: 'original' for Baldwin/North documents, 'simplified' otherwise
    """
    lowered_path = str(image_path).lower()
    # Baldwin North documents tested better with the structured ("original")
    # segmentation; everything else trusts Mistral OCR with the whole page.
    wants_original = all(token in lowered_path for token in ("baldwin", "north"))
    return "original" if wants_original else "simplified"
def _no_text_regions_result(image_regions: Optional[Image.Image] = None) -> Dict:
    """Build the result dict used when no text regions are produced.

    Args:
        image_regions: Image reported under 'image_regions' (the page itself on
            the low-entropy path), or None when segmentation failed.

    NOTE(review): unlike the success results, this dict has no 'region_images'
    key; callers should read it with .get('region_images').
    """
    return {
        'text_regions': None,
        'image_regions': image_regions,
        'text_mask_base64': None,
        'combined_result': None,
        'text_regions_coordinates': []
    }


def _encode_mask_data_uri(mask: np.ndarray) -> str:
    """Encode a mask array as a ``data:image/png;base64,...`` URI.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode('.png', mask)
    if not ok:
        raise ValueError("cv2.imencode failed to encode the text mask as PNG")
    return f"data:image/png;base64,{base64.b64encode(buffer).decode('utf-8')}"


def _simplified_segmentation(page: Image.Image, img_np: np.ndarray) -> Dict:
    """Mark the whole page as a single text region and let Mistral OCR handle layout.

    Args:
        page: RGB page image.
        img_np: ``np.array(page)``.
    """
    img_width, img_height = page.size
    full_image_region = [(0, 0, img_width, img_height)]

    # Visualization only: an inset green border plus a label.
    vis_img = img_np.copy()
    cv2.rectangle(vis_img, (5, 5), (img_width - 5, img_height - 5), (0, 255, 0), 5)
    cv2.putText(vis_img, "Processed by Mistral OCR", (30, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    text_regions_vis = Image.fromarray(vis_img)

    # A mask covering the entire image (also only for visualization).
    text_mask = np.full((img_height, img_width), 255, dtype=np.uint8)

    return {
        'text_regions': text_regions_vis,
        'image_regions': text_regions_vis.copy(),
        'text_mask_base64': _encode_mask_data_uri(text_mask),
        'combined_result': page,
        'text_regions_coordinates': full_image_region,
        'region_images': [{
            'image': img_np,
            'pil_image': page,
            'coordinates': (0, 0, img_width, img_height),
            'padded_coordinates': (0, 0, img_width, img_height),
            'order': 0
        }]
    }


def _document_specific_segmentation(page: Image.Image, img_np: np.ndarray) -> Dict:
    """Split the page into three full-width, overlapping horizontal bands.

    The bands are a header (top ~30%), a middle (~20%-60%) and a body (50% to
    the bottom edge). They overlap so that text near a boundary appears in at
    least one band whole. Used for baldwin-north and similar documents.

    Args:
        page: RGB page image.
        img_np: ``np.array(page)``.
    """
    img_width, img_height = page.size
    header_height = int(img_height * 0.3)
    middle_start = int(img_height * 0.2)   # overlaps the header
    middle_height = int(img_height * 0.4)
    body_start = int(img_height * 0.5)     # overlaps the middle band
    body_height = img_height - body_start  # runs to the bottom edge

    # Regions are (x, y, width, height).
    regions = [
        (0, 0, img_width, header_height),
        (0, middle_start, img_width, middle_height),
        (0, body_start, img_width, body_height)
    ]

    # Visualization: green outline per region plus a label.
    vis_img = img_np.copy()
    for x, y, w, h in regions:
        cv2.rectangle(vis_img, (x, y), (x + w, y + h), (0, 255, 0), 3)
    cv2.putText(vis_img, "Document-specific processing", (30, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    text_regions_vis = Image.fromarray(vis_img)

    text_mask = np.zeros((img_height, img_width), dtype=np.uint8)
    for x, y, w, h in regions:
        text_mask[y:y + h, x:x + w] = 255

    # Crops are taken from the clean page, not the annotated visualization.
    region_images = []
    for order, (x, y, w, h) in enumerate(regions):
        crop = img_np[y:y + h, x:x + w].copy()
        region_images.append({
            'image': crop,
            'pil_image': Image.fromarray(crop),
            'coordinates': (x, y, w, h),
            'padded_coordinates': (x, y, w, h),
            'order': order
        })

    return {
        'text_regions': text_regions_vis,
        'image_regions': text_regions_vis.copy(),
        'text_mask_base64': _encode_mask_data_uri(text_mask),
        'combined_result': page,
        'text_regions_coordinates': regions,
        'region_images': region_images
    }


def segment_image_for_ocr(image_path: Union[str, Path], vision_enabled: bool = True, preserve_content: bool = True) -> Dict[str, Union[Image.Image, str]]:
    """
    Prepare image for OCR processing using the most appropriate segmentation approach.

    For most documents this uses a minimal approach that trusts Mistral OCR
    to handle document understanding and layout analysis. For document types
    selected by determine_segmentation_approach, the page is split into fixed,
    overlapping horizontal bands.

    Args:
        image_path: Path to the image file
        vision_enabled: When False, pages whose entropy
            (utils.image_utils.calculate_image_entropy) is below 3.5 are
            classified as illustrations and returned without text regions.
        preserve_content: Not currently used by this function; kept for API
            compatibility.

    Returns:
        Dict with keys 'text_regions' and 'image_regions' (visualization
        images), 'text_mask_base64' (PNG data URI), 'combined_result' (the RGB
        page), 'text_regions_coordinates' (list of (x, y, w, h)) and, on
        success, 'region_images' (list of per-region dicts). On a low-entropy
        page, 'image_regions' holds the page and the other fields are
        None/[]. On any error, all fields are None/[]. In both of these cases
        'region_images' is absent. Exceptions are logged, not raised.
    """
    # Convert to Path object if string
    image_file = Path(image_path) if isinstance(image_path, str) else image_path
    approach = determine_segmentation_approach(image_file)
    logger.info(f"Preparing image for Mistral OCR: {image_file.name} (using {approach} approach)")

    try:
        with Image.open(image_file) as pil_img:
            if not vision_enabled:
                from utils.image_utils import calculate_image_entropy
                ent = calculate_image_entropy(pil_img)
                if ent < 3.5:  # Likely line-art or blank page
                    logger.info(f"Low entropy image detected ({ent:.2f}), classifying as illustration")
                    # copy() loads the pixels into a detached image, so the
                    # result stays usable after the `with` closes the file.
                    return _no_text_regions_result(pil_img.copy())
            # convert() always returns a new image (a copy when already RGB),
            # so `page` does not depend on the file handle closed below.
            page = pil_img.convert('RGB')

        img_np = np.array(page)
        if approach == "simplified":
            return _simplified_segmentation(page, img_np)
        return _document_specific_segmentation(page, img_np)
    except Exception as e:
        # Best-effort contract: log with traceback and return an empty result.
        logger.exception(f"Error segmenting image {image_file.name}: {e}")
        return _no_text_regions_result()
def process_segmented_image(image_path: Union[str, Path], output_dir: Optional[Path] = None, preserve_content: bool = True) -> Dict:
    """
    Segment an image for OCR and save the visualization outputs to disk.

    Args:
        image_path: Path to the image file
        output_dir: Directory for visualization outputs (str or Path). Defaults
            to ``output/segmentation``. Created with parents if it is missing.
        preserve_content: Forwarded to segment_image_for_ocr.

    Returns:
        Dictionary with:
            'original_image': the input path as a string,
            'output_files': mapping of output kind ('text_regions',
                'image_regions', 'combined_result', 'text_mask') to the saved
                file path. Empty when segmentation produced no text-region
                visualization.
            'text_regions_count': number of text regions,
            'text_regions_coordinates': list of (x, y, w, h) tuples.
    """
    image_file = Path(image_path) if isinstance(image_path, str) else image_path
    output_dir = Path("output") / "segmentation" if output_dir is None else Path(output_dir)
    # Always ensure the directory exists, including a caller-supplied one,
    # so the saves below cannot fail on a missing directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    segmentation_results = segment_image_for_ocr(image_file, preserve_content=preserve_content)

    results = {
        'original_image': str(image_file),
        'output_files': {}
    }
    output_files = results['output_files']

    if segmentation_results['text_regions'] is not None:
        stem = image_file.stem
        # (result key, filename suffix) for each PIL image saved as JPEG.
        for key, suffix in (('text_regions', 'text_regions'),
                            ('image_regions', 'image_regions'),
                            ('combined_result', 'combined')):
            image_out = output_dir / f"{stem}_{suffix}.jpg"
            segmentation_results[key].save(image_out)
            output_files[key] = str(image_out)

        # The mask is stored as a "data:image/png;base64,..." URI; decode its payload.
        mask_uri = segmentation_results['text_mask_base64']
        if mask_uri:
            text_mask_path = output_dir / f"{stem}_text_mask.png"
            _, _, payload = mask_uri.partition(',')
            text_mask_path.write_bytes(base64.b64decode(payload))
            output_files['text_mask'] = str(text_mask_path)

    coordinates = segmentation_results['text_regions_coordinates']
    results['text_regions_count'] = len(coordinates)
    results['text_regions_coordinates'] = coordinates
    return results
if __name__ == "__main__":
# Simple test - process a sample image if run directly
import sys
if len(sys.argv) > 1:
image_path = sys.argv[1]
else:
image_path = "input/handwritten-journal.jpg" # Example image path"
logger.info(f"Testing image segmentation on {image_path}")
results = process_segmented_image(image_path)
# Print results summary
logger.info(f"Segmentation complete. Found {results.get('text_regions_count', 0)} text regions.")
logger.info(f"Output files saved to: {[path for path in results.get('output_files', {}).values()]}")