# BrailleMenuGenV2/models/document_ai.py
import torch
from PIL import Image
import numpy as np
# Try to import pytesseract, but handle the case where it is not installed
try:
    import pytesseract
    TESSERACT_AVAILABLE = True
except ImportError:
    TESSERACT_AVAILABLE = False
# The Python package can be present without the tesseract binary,
# so confirm the binary is actually reachable on PATH
if TESSERACT_AVAILABLE:
    try:
        pytesseract.get_tesseract_version()
    except Exception:
        TESSERACT_AVAILABLE = False
# The model and processor are cached at module level so they are only loaded once
processor = None
model = None


def get_document_ai_models():
    """Get or initialize the LayoutLMv2 processor and model, with caching."""
    global processor, model
    if processor is None:
        from transformers import LayoutLMv2Processor
        processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
    if model is None:
        from transformers import LayoutLMv2ForSequenceClassification
        model = LayoutLMv2ForSequenceClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
    return processor, model
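

# Illustrative sketch (not part of the original module): one way the cached
# LayoutLMv2 pair above could be wired up to classify a whole document image.
# Caveats: LayoutLMv2Processor runs its own Tesseract-based OCR by default
# (so the tesseract binary must be installed), the model itself requires
# detectron2, and the base checkpoint's classification head is untrained,
# so this demonstrates the plumbing rather than meaningful predictions.
def classify_document_image(image):
    """Hypothetical example: sequence-classify a document image with LayoutLMv2."""
    proc, mdl = get_document_ai_models()
    encoding = proc(image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = mdl(**encoding)
    # Return the index of the highest-scoring class
    return outputs.logits.argmax(-1).item()
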
def extract_text_with_tesseract(image):
    """Extract text and word bounding boxes using Tesseract OCR."""
    if not TESSERACT_AVAILABLE:
        raise RuntimeError(
            "tesseract is not installed or is not on your PATH. "
            "See the README file for more information."
        )
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image).convert("RGB")
    else:
        pil_image = image.convert("RGB")
    # image_to_data returns parallel lists of words and their pixel coordinates
    boxes = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
    # Collect non-empty words and their [x1, y1, x2, y2] boxes
    words = []
    word_boxes = []
    for i in range(len(boxes['text'])):
        if boxes['text'][i].strip() != '':
            words.append(boxes['text'][i])
            x, y, w, h = boxes['left'][i], boxes['top'][i], boxes['width'][i], boxes['height'][i]
            word_boxes.append([x, y, x + w, y + h])
    return words, word_boxes
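

# Illustrative sketch (assumption, not in the original file): image_to_data
# also returns block/paragraph/line indices, so the flat word list above can
# be regrouped into reading-order lines when callers need line structure.
def group_tesseract_words_into_lines(data):
    """Hypothetical helper: group an image_to_data DICT result into text lines."""
    lines = {}
    for i, word in enumerate(data['text']):
        if word.strip():
            key = (data['block_num'][i], data['par_num'][i], data['line_num'][i])
            lines.setdefault(key, []).append(word)
    # Keys sort by (block, paragraph, line), which matches reading order
    return [' '.join(ws) for _, ws in sorted(lines.items())]
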
def extract_text_with_transformers(image):
    """Extract text with a transformers OCR model when Tesseract is unavailable."""
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        # Initialize the TrOCR processor and model (printed-text checkpoint)
        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
        model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
        # Prepare the image
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image).convert("RGB")
        else:
            pil_image = image.convert("RGB")
        # Run the encoder-decoder model to generate the recognized text
        pixel_values = processor(pil_image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # TrOCR gives no positional information, so split into words and
        # return placeholder bounding boxes
        words = generated_text.split()
        word_boxes = [[0, 0, 0, 0] for _ in words]
        return words, word_boxes
    except Exception as e:
        # If transformers OCR fails, surface the error message as the "text"
        return ["Error extracting text with transformers OCR:", str(e)], [[0, 0, 0, 0], [0, 0, 0, 0]]
def extract_text_and_layout(image):
    """
    Extract text and layout information using OCR.

    Args:
        image: PIL Image or numpy array

    Returns:
        Dictionary with the extracted words, their bounding boxes, and a success flag
    """
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    try:
        if TESSERACT_AVAILABLE:
            # Prefer Tesseract, which also yields word bounding boxes
            words, boxes = extract_text_with_tesseract(image)
        else:
            # Fall back to transformers-based OCR (placeholder boxes only)
            words, boxes = extract_text_with_transformers(image)
    except Exception as e:
        # If OCR fails entirely, return the error as the extracted text
        return {
            'words': [f"Error extracting text: {str(e)}"],
            'boxes': [[0, 0, 0, 0]],
            'success': False
        }
    # If no words were found, return an empty, unsuccessful result
    if not words:
        return {
            'words': [],
            'boxes': [],
            'success': False
        }
    return {
        'words': words,
        'boxes': boxes,
        'success': True
    }
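

# Minimal usage sketch (assumption: "menu.png" is a placeholder path, not a
# file that ships with this repo); it shows the shape of the returned dict.
if __name__ == "__main__":
    sample = Image.open("menu.png")
    result = extract_text_and_layout(sample)
    if result['success']:
        print(f"Extracted {len(result['words'])} words")
        print(result['words'][:10])
        print(result['boxes'][:10])
    else:
        print("Extraction failed:", result['words'])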