# formiq/src/models/layoutlm.py
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image
import numpy as np
from typing import Any, Dict, List, Tuple, Optional
import logging
logger = logging.getLogger(__name__)
class FormIQModel:
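    """Wrapper around LayoutLMv3 for extracting structured fields from document images."""
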
def __init__(
self,
model_name: str = "microsoft/layoutlmv3-base",
device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
"""Initialize the FormIQ model with LayoutLMv3.
Args:
model_name: Name of the pre-trained model to use
device: Device to run the model on ('cuda' or 'cpu')
"""
self.device = device
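        # LayoutLMv3Processor applies OCR by default (apply_ocr=True), which requires
        # pytesseract and the Tesseract binary to be available.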
self.processor = LayoutLMv3Processor.from_pretrained(model_name)
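        # The base checkpoint ships without a fine-tuned token-classification head,
        # so predictions are only meaningful after fine-tuning on labeled form data.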
self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()  # inference-only in this wrapper; disable dropout for stable scores
logger.info(f"Model initialized on {device}")
def preprocess_image(self, image: Image.Image) -> Dict[str, torch.Tensor]:
"""Preprocess the input image for the model.
Args:
image: PIL Image to process
Returns:
Dictionary of processed inputs
"""
try:
# Process image and text
encoding = self.processor(
image,
return_tensors="pt",
truncation=True,
max_length=512
)
# Move tensors to device
encoding = {k: v.to(self.device) for k, v in encoding.items()}
return encoding
except Exception as e:
logger.error(f"Error preprocessing image: {str(e)}")
raise
def predict(
self,
image: Image.Image,
confidence_threshold: float = 0.5
    ) -> Dict[str, List[Dict[str, Any]]]:
"""Extract information from the document image.
Args:
image: PIL Image of the document
confidence_threshold: Minimum confidence score for predictions
Returns:
Dictionary containing extracted fields and their metadata
"""
try:
# Preprocess image
inputs = self.preprocess_image(image)
# Get model predictions
with torch.no_grad():
outputs = self.model(**inputs)
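                # logits shape: (batch, seq_len, num_labels); take the per-token argmax
                # and keep the softmax probability of the chosen label as its confidence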
predictions = outputs.logits.argmax(-1).squeeze().cpu().numpy()
scores = torch.softmax(outputs.logits, dim=-1).max(-1)[0].squeeze().cpu().numpy()
# Process predictions
extracted_fields = self._process_predictions(predictions, scores, confidence_threshold)
return {
"fields": extracted_fields,
"metadata": {
"confidence_scores": scores.tolist(),
"model_version": self.model.config.model_type
}
}
except Exception as e:
logger.error(f"Error during prediction: {str(e)}")
raise
def _process_predictions(
self,
predictions: np.ndarray,
scores: np.ndarray,
confidence_threshold: float
    ) -> List[Dict[str, Any]]:
"""Process raw model predictions into structured output.
Args:
predictions: Array of predicted class indices
scores: Array of confidence scores
confidence_threshold: Minimum confidence score
Returns:
List of dictionaries containing field information
"""
# TODO: Implement field-specific post-processing
# This is a placeholder implementation
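        # A fuller implementation would group B-/I- tagged tokens into entities and
        # attach the word-level bounding boxes returned by the processor.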
processed_fields = []
for pred, score in zip(predictions, scores):
if score >= confidence_threshold:
field_info = {
"label": self.model.config.id2label[pred],
"confidence": float(score),
"bbox": None # TODO: Add bounding box information
}
processed_fields.append(field_info)
return processed_fields
def validate_extraction(
self,
        extracted_fields: Dict[str, List[Dict[str, Any]]],
        document_type: str
    ) -> Dict[str, Any]:
"""Validate extracted fields based on document type rules.
Args:
extracted_fields: Dictionary of extracted fields
document_type: Type of document (e.g., 'invoice', 'receipt')
Returns:
Dictionary containing validation results
"""
# TODO: Implement field validation logic
# This is a placeholder implementation
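        # Real validation might check required fields per document type (e.g. an
        # invoice needs a total amount and a date) and cross-field consistency.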
return {
"is_valid": True,
"validation_errors": [],
"confidence_score": 1.0
}
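

# Minimal usage sketch, assuming a local "sample_invoice.png" exists and Tesseract is
# installed for the processor's OCR step (filename and threshold are illustrative).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model = FormIQModel()
    document = Image.open("sample_invoice.png").convert("RGB")

    result = model.predict(document, confidence_threshold=0.5)
    for field in result["fields"]:
        print(f"{field['label']}: confidence={field['confidence']:.2f}")

    validation = model.validate_extraction(result, document_type="invoice")
    print(f"Valid: {validation['is_valid']}, errors: {validation['validation_errors']}")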