""" | |
Image segmentation utility for OCR preprocessing. | |
Separates text regions from image regions to improve OCR accuracy on mixed-content documents. | |
Based on Mistral AI cookbook examples. | |
""" | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import io | |
import base64 | |
import logging | |
from pathlib import Path | |
from typing import Tuple, List, Dict, Union, Optional | |
# Configure logging | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |


def segment_image_for_ocr(image_path: Union[str, Path], vision_enabled: bool = True, preserve_content: bool = True) -> Dict[str, Union[Image.Image, str]]:
    """
    Segment an image into text and image regions for improved OCR processing.

    Args:
        image_path: Path to the image file
        vision_enabled: Whether the vision model is enabled
        preserve_content: If True, skip contrast enhancement and keep the original pixels untouched

    Returns:
        Dict containing:
        - 'text_regions': PIL Image with highlighted text regions
        - 'image_regions': PIL Image with highlighted image regions
        - 'text_mask_base64': Base64 string of text mask for visualization
        - 'combined_result': PIL Image with combined processing approach
        - 'text_regions_coordinates': List of (x, y, w, h) tuples for detected text regions
        - 'region_images': List of dicts with cropped region images for per-region OCR
    """
    # Convert to Path object if string
    image_file = Path(image_path) if isinstance(image_path, str) else image_path

    # Log start of processing
    logger.info(f"Segmenting image for OCR: {image_file.name}")

    try:
        # Open original image with PIL for compatibility
        with Image.open(image_file) as pil_img:
            # --- 2 · Stop "text page detected as image" when vision model is off ---
            if not vision_enabled:
                # Import the entropy calculator from utils.image_utils
                from utils.image_utils import calculate_image_entropy
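                # (Assumed behaviour) calculate_image_entropy is expected to return the
                # Shannon entropy of the grayscale histogram, H = -sum(p_i * log2(p_i))
                # over the normalized 256-bin distribution, so blank pages and simple
                # line art score low while dense text or photographs score high.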
                # Calculate entropy to determine if this is line art or blank
                ent = calculate_image_entropy(pil_img)
                if ent < 3.5:  # Heuristically low → line-art or blank page
                    logger.info(f"Low entropy image detected ({ent:.2f}), classifying as illustration")
                    # Return minimal result for illustration; copy the image so it stays
                    # usable after the context manager closes the file handle
                    return {
                        'text_regions': None,
                        'image_regions': pil_img.copy(),
                        'text_mask_base64': None,
                        'combined_result': None,
                        'text_regions_coordinates': [],
                        'region_images': []
                    }
            # Convert to RGB if not already
            if pil_img.mode != 'RGB':
                pil_img = pil_img.convert('RGB')

            # Convert PIL image to OpenCV format
            # (note: img_rgb actually holds BGR channel order, as OpenCV expects)
            img = np.array(pil_img)
            img_rgb = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            # Create grayscale version for text detection
            gray = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY)

            # Step 1: Apply adaptive thresholding to identify potential text areas
            # This works well for printed text against contrasting backgrounds
            binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                           cv2.THRESH_BINARY_INV, 11, 2)

            # Step 2: Perform morphological operations to connect text components
            # Use a combination of horizontal and vertical kernels for better text detection
            # in historical documents with mixed content
            horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 1))
            vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 3))

            # Apply horizontal dilation to connect characters in a line
            horiz_dilation = cv2.dilate(binary, horiz_kernel, iterations=1)
            # Apply vertical dilation to connect lines in a paragraph
            vert_dilation = cv2.dilate(binary, vert_kernel, iterations=1)
            # Combine both dilations for better region detection
            dilation = cv2.bitwise_or(horiz_dilation, vert_dilation)

            # Step 3: Find contours which will correspond to text blocks
            contours, _ = cv2.findContours(dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Prepare masks to separate text and image regions
            text_mask = np.zeros_like(gray)

            # Step 4: Filter contours based on size to identify text regions
            min_area = 50  # Lower minimum area to catch smaller text blocks in historical documents
            max_area = img.shape[0] * img.shape[1] * 0.4  # Reduced max to avoid capturing too much

            text_regions = []
            for contour in contours:
                area = cv2.contourArea(contour)
                # Filter by area to avoid noise
                if min_area < area < max_area:
                    # Get the bounding rectangle
                    x, y, w, h = cv2.boundingRect(contour)
                    # Calculate aspect ratio - text regions typically have wider aspect ratio
                    aspect_ratio = w / h
                    # Calculate density of dark pixels in the region (text is typically dense)
                    roi = binary[y:y+h, x:x+w]
                    dark_pixel_density = np.sum(roi > 0) / (w * h)
                    # Special handling for historical documents:
                    # check vertical position - text is often at the bottom in historical prints
                    y_position_ratio = y / img.shape[0]  # Normalized y position (0 at top, 1 at bottom)
                    # Bottom regions get preferential treatment as text
                    is_bottom_region = y_position_ratio > 0.7
                    # Check if part of a text block cluster (horizontal proximity)
                    is_text_cluster = False
                    # Check already identified text regions for proximity
                    for tx, ty, tw, th in text_regions:
                        # Check if horizontally aligned and close
                        if abs((ty + th/2) - (y + h/2)) < max(th, h) and \
                           abs((tx + tw) - x) < 20:  # Near each other horizontally
                            is_text_cluster = True
                            break
                    # More inclusive classification for historical documents; a region is text if:
                    # 1. it has typical text characteristics, OR
                    # 2. it sits near the bottom (likely text in historical prints), OR
                    # 3. it is part of a text cluster
                    is_text_region = ((aspect_ratio > 1.05 or aspect_ratio < 0.9) and dark_pixel_density > 0.1) or \
                                     (is_bottom_region and dark_pixel_density > 0.08) or \
                                     is_text_cluster
                    if is_text_region:
                        # Add to text regions list
                        text_regions.append((x, y, w, h))
                        # Add to text mask
                        cv2.rectangle(text_mask, (x, y), (x+w, y+h), 255, -1)

            # Step 5: Create visualization for debugging
            text_regions_vis = img_rgb.copy()
            for x, y, w, h in text_regions:
                cv2.rectangle(text_regions_vis, (x, y), (x+w, y+h), (0, 255, 0), 2)
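
            # NOTE: the adaptive, contour-based classification above is recomputed by the
            # fixed-layout "enhanced approach" below; text_mask, text_regions and
            # text_regions_vis are all rebuilt there, so this first pass mainly documents
            # the heuristic criteria.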
            # ENHANCED APPROACH FOR HISTORICAL DOCUMENTS:
            # Identify different regions, including titles at the top of the document.
            # First, look for potential title text at the top of the document
            image_height = img.shape[0]
            image_width = img.shape[1]

            # Examine the top 20% of the image for potential title text
            title_section_height = int(image_height * 0.2)
            title_mask = np.zeros_like(gray)
            title_mask[:title_section_height, :] = 255

            # Find potential title blocks in the top section
            title_contours, _ = cv2.findContours(
                cv2.bitwise_and(dilation, title_mask),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE
            )

            # Extract title regions with more permissive criteria
            title_regions = []
            for contour in title_contours:
                area = cv2.contourArea(contour)
                # Use more permissive criteria for title regions
                if area > min_area * 0.8:  # Smaller minimum area for titles
                    x, y, w, h = cv2.boundingRect(contour)
                    # Title regions typically have wider aspect ratio
                    aspect_ratio = w / h
                    # More permissive density check for titles that might be stylized
                    roi = binary[y:y+h, x:x+w]
                    dark_pixel_density = np.sum(roi > 0) / (w * h)
                    # Check if this might be a title:
                    # titles tend to be wide, centered, and at the top
                    is_wide = aspect_ratio > 2.0
                    is_centered = abs((x + w/2) - (image_width/2)) < (image_width * 0.3)
                    is_at_top = y < title_section_height
                    # If it looks like a title or has good text characteristics
                    if (is_wide and is_centered and is_at_top) or \
                       (is_at_top and dark_pixel_density > 0.1):
                        title_regions.append((x, y, w, h))
            # Now handle the main content with the standard approach.
            # Use fixed regions for the main content - typically below the title;
            # for primary content, assume most text sits in the bottom 30% of the page
            text_section_start = int(image_height * 0.7)  # Main text section starts 70% of the way down

            # Create text mask combining the title regions and main text area
            text_mask = np.zeros_like(gray)
            text_mask[text_section_start:, :] = 255

            # Add title regions to the text mask
            for x, y, w, h in title_regions:
                # Add some padding around title regions
                pad_x = max(5, int(w * 0.05))
                pad_y = max(5, int(h * 0.05))
                x_start = max(0, x - pad_x)
                y_start = max(0, y - pad_y)
                x_end = min(image_width, x + w + pad_x)
                y_end = min(image_height, y + h + pad_y)
                # Add title region to the text mask
                text_mask[y_start:y_end, x_start:x_end] = 255

            # Image mask is the inverse of text mask - for visualization only
            image_mask = np.zeros_like(gray)
            image_mask[text_mask == 0] = 255

            # For main text regions, find blocks of text in the bottom part:
            # create a temporary mask covering the main text section
            temp_mask = np.zeros_like(gray)
            temp_mask[text_section_start:, :] = 255

            # Find text regions for visualization purposes
            text_regions = []
            # Start with any title regions we found
            text_regions.extend(title_regions)

            # Then find text regions in the main content area
            text_region_contours, _ = cv2.findContours(
                cv2.bitwise_and(dilation, temp_mask),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE
            )
            # Add each detected region
            for contour in text_region_contours:
                x, y, w, h = cv2.boundingRect(contour)
                if w > 10 and h > 5:  # Minimum size to be considered text
                    text_regions.append((x, y, w, h))

            # Fall back to the entire bottom section as a single text region if none were detected
            if len(text_regions) == 0:
                x, y = 0, text_section_start
                w, h = img.shape[1], img.shape[0] - text_section_start
                text_regions.append((x, y, w, h))
            # Create visualizations.
            # CRITICAL for OCR: the image content itself is never modified here;
            # these overlays are drawn on copies for debugging only.
            # Image regions: the top section above the main text area, outlined in red
            image_regions_vis = img_rgb.copy()
            cv2.rectangle(image_regions_vis, (0, 0), (img.shape[1], text_section_start), (0, 0, 255), 2)

            # Text regions: green boxes around each detected text region
            text_regions_vis = img_rgb.copy()
            for x, y, w, h in text_regions:
                cv2.rectangle(text_regions_vis, (x, y), (x+w, y+h), (0, 255, 0), 2)
            # Create a minimally enhanced version of the original image
            # that preserves ALL content (both text and image)
            combined_result = img_rgb.copy()

            # Apply gentle contrast enhancement only when preserve_content is False
            if not preserve_content:
                # Use a subtle CLAHE enhancement to improve OCR without losing content
                lab_img = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2LAB)
                l, a, b = cv2.split(lab_img)
                # Very mild CLAHE settings to preserve text
                clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
                cl = clahe.apply(l)
                # Merge channels back
                enhanced_lab = cv2.merge((cl, a, b))
                combined_result = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR)
            # Extract individual region images for separate OCR processing
            region_images = []
            if text_regions:
                for idx, (x, y, w, h) in enumerate(text_regions):
                    # Add padding around region (10% of width/height)
                    pad_x = max(5, int(w * 0.1))
                    pad_y = max(5, int(h * 0.1))
                    # Ensure coordinates stay within image bounds
                    x_start = max(0, x - pad_x)
                    y_start = max(0, y - pad_y)
                    x_end = min(img_rgb.shape[1], x + w + pad_x)
                    y_end = min(img_rgb.shape[0], y + h + pad_y)
                    # Extract region with padding
                    region = img_rgb[y_start:y_end, x_start:x_end].copy()
                    # Store region with its coordinates
                    region_info = {
                        'image': region,
                        'coordinates': (x, y, w, h),
                        'padded_coordinates': (x_start, y_start, x_end - x_start, y_end - y_start),
                        'order': idx
                    }
                    region_images.append(region_info)
            # Convert visualization results back to PIL Images
            text_regions_pil = Image.fromarray(cv2.cvtColor(text_regions_vis, cv2.COLOR_BGR2RGB))
            image_regions_pil = Image.fromarray(cv2.cvtColor(image_regions_vis, cv2.COLOR_BGR2RGB))
            combined_result_pil = Image.fromarray(cv2.cvtColor(combined_result, cv2.COLOR_BGR2RGB))

            # Create base64 representation of text mask for visualization
            _, buffer = cv2.imencode('.png', text_mask)
            text_mask_base64 = base64.b64encode(buffer).decode('utf-8')

            # Convert region images to PIL format
            region_pil_images = []
            for region_info in region_images:
                region_pil = Image.fromarray(cv2.cvtColor(region_info['image'], cv2.COLOR_BGR2RGB))
                region_info['pil_image'] = region_pil
                region_pil_images.append(region_info)

            # Return the segmentation results
            return {
                'text_regions': text_regions_pil,
                'image_regions': image_regions_pil,
                'text_mask_base64': f"data:image/png;base64,{text_mask_base64}",
                'combined_result': combined_result_pil,
                'text_regions_coordinates': text_regions,
                'region_images': region_pil_images
            }
    except Exception as e:
        logger.error(f"Error segmenting image {image_file.name}: {str(e)}")
        # Return None values if processing fails
        return {
            'text_regions': None,
            'image_regions': None,
            'text_mask_base64': None,
            'combined_result': None,
            'text_regions_coordinates': [],
            'region_images': []  # keep the key set consistent with the success path
        }


def process_segmented_image(image_path: Union[str, Path], output_dir: Optional[Path] = None, preserve_content: bool = True) -> Dict:
    """
    Process an image using segmentation for improved OCR, saving visualization outputs.

    Args:
        image_path: Path to the image file
        output_dir: Optional directory to save visualization outputs
        preserve_content: Passed through to segment_image_for_ocr; if True, no contrast enhancement is applied

    Returns:
        Dictionary with processing results and paths to output files
    """
    # Convert to Path object if string
    image_file = Path(image_path) if isinstance(image_path, str) else image_path

    # Create output directory if not provided
    if output_dir is None:
        output_dir = Path("output") / "segmentation"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Process the image with segmentation (forward the preserve_content flag)
    segmentation_results = segment_image_for_ocr(image_file, preserve_content=preserve_content)

    # Prepare results dictionary
    results = {
        'original_image': str(image_file),
        'output_files': {}
    }
    # Save visualization outputs if segmentation was successful
    if segmentation_results['text_regions'] is not None:
        # Save text regions visualization
        text_regions_path = output_dir / f"{image_file.stem}_text_regions.jpg"
        segmentation_results['text_regions'].save(text_regions_path)
        results['output_files']['text_regions'] = str(text_regions_path)

        # Save image regions visualization
        image_regions_path = output_dir / f"{image_file.stem}_image_regions.jpg"
        segmentation_results['image_regions'].save(image_regions_path)
        results['output_files']['image_regions'] = str(image_regions_path)

        # Save combined result
        combined_path = output_dir / f"{image_file.stem}_combined.jpg"
        segmentation_results['combined_result'].save(combined_path)
        results['output_files']['combined_result'] = str(combined_path)

        # Save text mask visualization, decoded from the base64 data URL
        text_mask_path = output_dir / f"{image_file.stem}_text_mask.png"
        if segmentation_results['text_mask_base64']:
            base64_data = segmentation_results['text_mask_base64'].split(',')[1]
            with open(text_mask_path, 'wb') as f:
                f.write(base64.b64decode(base64_data))
            results['output_files']['text_mask'] = str(text_mask_path)

        # Add detected text regions count
        results['text_regions_count'] = len(segmentation_results['text_regions_coordinates'])
        results['text_regions_coordinates'] = segmentation_results['text_regions_coordinates']

    return results


if __name__ == "__main__":
    # Simple test - process a sample image if run directly
    import sys

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
    else:
        # Default to testing with the magician image
        image_path = "input/magician-or-bottle-cungerer.jpg"

    logger.info(f"Testing image segmentation on {image_path}")
    results = process_segmented_image(image_path)

    # Print results summary
    logger.info(f"Segmentation complete. Found {results.get('text_regions_count', 0)} text regions.")
    logger.info(f"Output files saved to: {list(results.get('output_files', {}).values())}")