""" | |
Image segmentation utility for OCR preprocessing. | |
Separates text regions from image regions to improve OCR accuracy on mixed-content documents. | |
Uses content-aware adaptive segmentation for improved results across document types. | |
""" | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import io | |
import base64 | |
import logging | |
from pathlib import Path | |
from typing import Tuple, List, Dict, Union, Optional | |
# Configure logging | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
def segment_image_for_ocr(image_path: Union[str, Path], vision_enabled: bool = True,
                          preserve_content: bool = True) -> Dict:
    """
    Prepare an image for OCR processing using content-aware segmentation.

    Uses adaptive region detection based on text density analysis.

    Args:
        image_path: Path to the image file
        vision_enabled: Whether the vision model is enabled
        preserve_content: Whether to preserve original content without enhancement

    Returns:
        Dict containing segmentation results
    """
    # Convert to Path object if string
    image_file = Path(image_path) if isinstance(image_path, str) else image_path

    logger.info(f"Preparing image for Mistral OCR: {image_file.name}")

    try:
        # Open the original image with PIL
        with Image.open(image_file) as pil_img:
            # Check for low-entropy images when vision is disabled
            # (see the hedged entropy sketch after this function)
            if not vision_enabled:
                from utils.image_utils import calculate_image_entropy
                ent = calculate_image_entropy(pil_img)
                if ent < 3.5:  # Likely line art or a blank page
                    logger.info(f"Low entropy image detected ({ent:.2f}), classifying as illustration")
                    return {
                        'text_regions': None,
                        # Copy so the image outlives the context manager
                        'image_regions': pil_img.copy(),
                        'text_mask_base64': None,
                        'combined_result': None,
                        'text_regions_coordinates': []
                    }
            # Convert to RGB if needed
            if pil_img.mode != 'RGB':
                pil_img = pil_img.convert('RGB')

            # Get image data and dimensions (PIL size is (width, height))
            img_np = np.array(pil_img)
            img_width, img_height = pil_img.size

            # Analyze text density to decide whether advanced segmentation is
            # needed; this replaces document-specific logic with content-aware
            # analysis (see the hedged density sketch after this function).
            from utils.image_utils import estimate_text_density
            text_density = estimate_text_density(img_np)

            # Use the adaptive approach for documents with unusual text distribution
            if text_density['pattern'] == 'varied' or text_density['uppercase_sections'] > 0:
                logger.info(
                    "Using adaptive segmentation for document with varied text density: "
                    f"pattern={text_density['pattern']}, "
                    f"uppercase_sections={text_density['uppercase_sections']}"
                )

                # Detect content regions based on text density
                # (see the hedged region-detection sketch after this function)
                from utils.text_utils import detect_content_regions
                regions = detect_content_regions(img_np)
                # Create a visualization with green borders around the text regions
                vis_img = img_np.copy()
                for x, y, w, h in regions:
                    cv2.rectangle(vis_img, (x, y), (x + w, y + h), (0, 255, 0), 3)

                # Label the visualization to indicate adaptive processing
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(vis_img, "Adaptive region processing", (30, 60), font, 1, (0, 255, 0), 2)

                # Create visualization images
                text_regions_vis = Image.fromarray(vis_img)
                image_regions_vis = text_regions_vis.copy()

                # Create a mask highlighting the text regions
                text_mask = np.zeros((img_height, img_width), dtype=np.uint8)
                for x, y, w, h in regions:
                    text_mask[y:y + h, x:x + w] = 255
                _, buffer = cv2.imencode('.png', text_mask)
                text_mask_base64 = base64.b64encode(buffer).decode('utf-8')

                # Extract one image per detected region, preserving reading order
                region_images = []
                for i, (x, y, w, h) in enumerate(regions):
                    region = img_np[y:y + h, x:x + w].copy()
                    region_images.append({
                        'image': region,
                        'pil_image': Image.fromarray(region),
                        'coordinates': (x, y, w, h),
                        'padded_coordinates': (x, y, w, h),
                        'order': i
                    })

                # Return the adaptive segmentation results
                return {
                    'text_regions': text_regions_vis,
                    'image_regions': image_regions_vis,
                    'text_mask_base64': f"data:image/png;base64,{text_mask_base64}",
                    'combined_result': pil_img,
                    'text_regions_coordinates': regions,
                    'region_images': region_images,
                    'segmentation_type': 'adaptive'
                }
            else:
                # Simplified approach for most documents: let Mistral OCR handle
                # the entire document understanding process
                logger.info("Using standard approach for document with uniform text density")

                # For visualization, mark the entire image as a single text region
                full_image_region = [(0, 0, img_width, img_height)]

                # Create a visualization with a simple border
                vis_img = img_np.copy()
                cv2.rectangle(vis_img, (5, 5), (img_width - 5, img_height - 5), (0, 255, 0), 5)

                # Label the visualization to indicate Mistral's native processing
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(vis_img, "Processed by Mistral OCR", (30, 60), font, 1, (0, 255, 0), 2)

                # Create visualizations and masks
                text_regions_vis = Image.fromarray(vis_img)
                image_regions_vis = text_regions_vis.copy()

                # Create a mask covering the entire image (visualization only)
                text_mask = np.full((img_height, img_width), 255, dtype=np.uint8)
                _, buffer = cv2.imencode('.png', text_mask)
                text_mask_base64 = base64.b64encode(buffer).decode('utf-8')

                # Return the original image as the combined result
                return {
                    'text_regions': text_regions_vis,
                    'image_regions': image_regions_vis,
                    'text_mask_base64': f"data:image/png;base64,{text_mask_base64}",
                    'combined_result': pil_img,
                    'text_regions_coordinates': full_image_region,
                    'region_images': [{
                        'image': img_np,
                        'pil_image': pil_img,
                        'coordinates': (0, 0, img_width, img_height),
                        'padded_coordinates': (0, 0, img_width, img_height),
                        'order': 0
                    }],
                    'segmentation_type': 'simplified'
                }
    except Exception as e:
        logger.error(f"Error segmenting image {image_file.name}: {e}")
        # Return None values if processing fails
        return {
            'text_regions': None,
            'image_regions': None,
            'text_mask_base64': None,
            'combined_result': None,
            'text_regions_coordinates': []
        }
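
# The entropy gate above relies on utils.image_utils.calculate_image_entropy,
# which is not defined in this file. The sketch below shows one plausible
# implementation (Shannon entropy of the grayscale histogram, in bits); the
# real helper may differ, and this stand-in is illustrative only.
def _sketch_calculate_image_entropy(pil_img: Image.Image) -> float:
    """Hypothetical stand-in for utils.image_utils.calculate_image_entropy."""
    gray = np.array(pil_img.convert('L'))
    # Normalized 256-bin histogram of pixel intensities
    hist, _ = np.histogram(gray, bins=256, range=(0, 256))
    probs = hist / max(hist.sum(), 1)
    probs = probs[probs > 0]  # drop empty bins to avoid log(0)
    return float(-np.sum(probs * np.log2(probs)))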
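
# segment_image_for_ocr also branches on utils.image_utils.estimate_text_density,
# not shown here either. The call site only requires a dict with 'pattern' and
# 'uppercase_sections' keys; a rough sketch under that assumption classifies
# row-wise ink density as 'uniform' or 'varied'. The real helper presumably does
# richer analysis (e.g. actual case detection); 'uppercase_sections' is stubbed
# to 0 here.
def _sketch_estimate_text_density(img_np: np.ndarray) -> Dict:
    """Hypothetical stand-in for utils.image_utils.estimate_text_density."""
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    # Treat dark pixels as "ink" via Otsu binarization
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Mean ink density per horizontal band of the page
    bands = np.array_split(binary.mean(axis=1) / 255.0, 16)
    band_density = np.array([band.mean() for band in bands])
    # High variance across bands suggests unevenly distributed text
    pattern = 'varied' if band_density.std() > 0.1 else 'uniform'
    return {'pattern': pattern, 'uppercase_sections': 0}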
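
# Finally, the adaptive branch calls utils.text_utils.detect_content_regions,
# expected to return (x, y, w, h) tuples. One common recipe for such regions,
# offered here only as an illustrative sketch: binarize, dilate so glyphs merge
# into blocks, then take contour bounding boxes in top-to-bottom order.
def _sketch_detect_content_regions(img_np: np.ndarray) -> List[Tuple[int, int, int, int]]:
    """Hypothetical stand-in for utils.text_utils.detect_content_regions."""
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Dilate with a wide kernel so nearby glyphs merge into contiguous blocks
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 7))
    dilated = cv2.dilate(binary, kernel, iterations=2)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = [cv2.boundingRect(c) for c in contours]
    # Drop specks, then read regions top to bottom
    return sorted((b for b in boxes if b[2] * b[3] > 500), key=lambda b: b[1])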
def process_segmented_image(image_path: Union[str, Path], output_dir: Optional[Path] = None,
                            preserve_content: bool = True) -> Dict:
    """
    Process an image using segmentation for improved OCR, saving visualization outputs.

    Args:
        image_path: Path to the image file
        output_dir: Optional directory to save visualization outputs
        preserve_content: Whether to preserve original content without enhancement

    Returns:
        Dictionary with processing results and paths to output files
    """
    # Convert to Path object if string
    image_file = Path(image_path) if isinstance(image_path, str) else image_path

    # Default the output directory if not provided, and make sure it exists
    if output_dir is None:
        output_dir = Path("output") / "segmentation"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Process the image with segmentation, passing the preservation flag through
    segmentation_results = segment_image_for_ocr(image_file, preserve_content=preserve_content)
    # Prepare results dictionary
    results = {
        'original_image': str(image_file),
        'output_files': {}
    }

    # Save visualization outputs if segmentation was successful
    if segmentation_results['text_regions'] is not None:
        # Save text regions visualization
        text_regions_path = output_dir / f"{image_file.stem}_text_regions.jpg"
        segmentation_results['text_regions'].save(text_regions_path)
        results['output_files']['text_regions'] = str(text_regions_path)

        # Save image regions visualization
        image_regions_path = output_dir / f"{image_file.stem}_image_regions.jpg"
        segmentation_results['image_regions'].save(image_regions_path)
        results['output_files']['image_regions'] = str(image_regions_path)

        # Save combined result
        combined_path = output_dir / f"{image_file.stem}_combined.jpg"
        segmentation_results['combined_result'].save(combined_path)
        results['output_files']['combined_result'] = str(combined_path)

        # Decode and save the text mask from its base64 data URI
        if segmentation_results['text_mask_base64']:
            text_mask_path = output_dir / f"{image_file.stem}_text_mask.png"
            base64_data = segmentation_results['text_mask_base64'].split(',')[1]
            with open(text_mask_path, 'wb') as f:
                f.write(base64.b64decode(base64_data))
            results['output_files']['text_mask'] = str(text_mask_path)

        # Add detected text regions count and coordinates
        results['text_regions_count'] = len(segmentation_results['text_regions_coordinates'])
        results['text_regions_coordinates'] = segmentation_results['text_regions_coordinates']

    return results
if __name__ == "__main__":
    # Simple test: process a sample image if run directly
    import sys

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
    else:
        image_path = "input/handwritten-journal.jpg"  # Example image path

    logger.info(f"Testing image segmentation on {image_path}")
    results = process_segmented_image(image_path)

    # Print results summary
    logger.info(f"Segmentation complete. Found {results.get('text_regions_count', 0)} text regions.")
    logger.info(f"Output files saved to: {list(results.get('output_files', {}).values())}")