import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

logger = logging.getLogger(__name__)


class FormIQModel:
    def __init__(
        self,
        model_name: str = "microsoft/layoutlmv3-base",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the FormIQ model with LayoutLMv3.

        Args:
            model_name: Name of the pre-trained model to use
            device: Device to run the model on ('cuda' or 'cpu')
        """
        self.device = device
        self.processor = LayoutLMv3Processor.from_pretrained(model_name)
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
        self.model.to(device)
        logger.info(f"Model initialized on {device}")
    def preprocess_image(self, image: Image.Image) -> Dict[str, torch.Tensor]:
        """Preprocess the input image for the model.

        Args:
            image: PIL Image to process

        Returns:
            Dictionary of processed inputs
        """
        try:
            # Process image and text (the processor runs OCR by default)
            encoding = self.processor(
                image,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            # Move tensors to device
            encoding = {k: v.to(self.device) for k, v in encoding.items()}
            return encoding
        except Exception as e:
            logger.error(f"Error preprocessing image: {str(e)}")
            raise
    def predict(
        self,
        image: Image.Image,
        confidence_threshold: float = 0.5,
    ) -> Dict[str, Any]:
        """Extract information from the document image.

        Args:
            image: PIL Image of the document
            confidence_threshold: Minimum confidence score for predictions

        Returns:
            Dictionary containing extracted fields and their metadata
        """
        try:
            # Preprocess image
            inputs = self.preprocess_image(image)

            # Get model predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = outputs.logits.argmax(-1).squeeze().cpu().numpy()
                scores = torch.softmax(outputs.logits, dim=-1).max(-1)[0].squeeze().cpu().numpy()

            # Process predictions
            extracted_fields = self._process_predictions(predictions, scores, confidence_threshold)

            return {
                "fields": extracted_fields,
                "metadata": {
                    "confidence_scores": scores.tolist(),
                    "model_version": self.model.config.model_type,
                },
            }
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise
    def _process_predictions(
        self,
        predictions: np.ndarray,
        scores: np.ndarray,
        confidence_threshold: float,
    ) -> List[Dict[str, Any]]:
        """Process raw model predictions into structured output.

        Args:
            predictions: Array of predicted class indices
            scores: Array of confidence scores
            confidence_threshold: Minimum confidence score

        Returns:
            List of dictionaries containing field information
        """
        # TODO: Implement field-specific post-processing
        # This is a placeholder implementation
        processed_fields = []
        for pred, score in zip(predictions, scores):
            if score >= confidence_threshold:
                field_info = {
                    "label": self.model.config.id2label[int(pred)],
                    "confidence": float(score),
                    "bbox": None,  # TODO: Add bounding box information
                }
                processed_fields.append(field_info)
        return processed_fields
    def validate_extraction(
        self,
        extracted_fields: Dict[str, List[Dict[str, Any]]],
        document_type: str,
    ) -> Dict[str, Any]:
        """Validate extracted fields based on document type rules.

        Args:
            extracted_fields: Dictionary of extracted fields
            document_type: Type of document (e.g., 'invoice', 'receipt')

        Returns:
            Dictionary containing validation results
        """
        # TODO: Implement field validation logic
        # This is a placeholder implementation
        return {
            "is_valid": True,
            "validation_errors": [],
            "confidence_score": 1.0,
        }
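

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): it assumes a
    # document image exists at the hypothetical path "sample_invoice.png" and
    # that pytesseract is installed, since LayoutLMv3Processor applies OCR by
    # default. The base checkpoint is not fine-tuned for token classification,
    # so the extracted labels are only meaningful after fine-tuning.
    logging.basicConfig(level=logging.INFO)

    model = FormIQModel()
    image = Image.open("sample_invoice.png").convert("RGB")

    result = model.predict(image, confidence_threshold=0.5)
    print(f"Extracted {len(result['fields'])} fields")

    validation = model.validate_extraction(result, document_type="invoice")
    print(f"Validation passed: {validation['is_valid']}")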