# Extracted from a Hugging Face Space file view (Space status: Running;
# commit 83dd2a8; file size 4,877 bytes; 144 lines in the original listing).
import logging
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import torch
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class FormIQModel:
    """Wrapper around a LayoutLMv3 token-classification model for form documents.

    Loads a processor/model pair via ``from_pretrained``, runs inference on a
    single document image, and converts per-token predictions into a list of
    field dictionaries.
    """

    def __init__(
        self,
        model_name: str = "microsoft/layoutlmv3-base",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Initialize the FormIQ model with LayoutLMv3.

        Args:
            model_name: Name or path of the pre-trained checkpoint to load.
            device: Device to run the model on ('cuda' or 'cpu'). The default is
                chosen once, when the class is defined, from
                ``torch.cuda.is_available()``.
        """
        self.device = device
        self.processor = LayoutLMv3Processor.from_pretrained(model_name)
        # NOTE(review): the "microsoft/layoutlmv3-base" checkpoint presumably has
        # no fine-tuned token-classification head, so the classifier (and its
        # id2label mapping) would be freshly initialized — confirm a fine-tuned
        # checkpoint is passed in production.
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
        self.model.to(device)
        # Inference only: make sure dropout etc. are disabled.
        self.model.eval()
        logger.info("Model initialized on %s", device)

    def preprocess_image(self, image: Image.Image) -> Dict[str, torch.Tensor]:
        """Preprocess the input image for the model.

        Args:
            image: PIL Image to process. Non-RGB images (grayscale, palette,
                RGBA, ...) are converted to RGB first.

        Returns:
            Dictionary of processed inputs (as returned by the processor), with
            every tensor moved to ``self.device``.

        Raises:
            Any exception raised by the processor or by ``Tensor.to``; it is
            logged with its traceback and re-raised unchanged.
        """
        try:
            # The image processor normalizes with 3-channel statistics, so
            # images in other modes must be converted before encoding.
            if image.mode != "RGB":
                image = image.convert("RGB")
            encoding = self.processor(
                image,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            return {key: value.to(self.device) for key, value in encoding.items()}
        except Exception as e:
            logger.exception("Error preprocessing image: %s", e)
            raise

    def predict(
        self,
        image: Image.Image,
        confidence_threshold: float = 0.5
    ) -> Dict[str, Any]:
        """Extract information from the document image.

        Args:
            image: PIL Image of the document.
            confidence_threshold: Minimum softmax probability a token's top
                label must reach to be reported as a field.

        Returns:
            ``{"fields": [...], "metadata": {"confidence_scores": [...],
            "model_version": str}}`` where ``fields`` comes from
            ``_process_predictions`` and ``confidence_scores`` holds the top
            probability of every token.

        Raises:
            Any exception from preprocessing or the forward pass; it is logged
            with its traceback and re-raised unchanged.
        """
        try:
            inputs = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # A single image is encoded, so take batch element 0 explicitly
            # instead of squeeze(): squeeze() would also collapse a length-1
            # sequence into a 0-d array that zip() cannot iterate.
            probs = torch.softmax(outputs.logits[0], dim=-1)
            # One reduction yields both the predicted label and its probability.
            top_probs, top_ids = probs.max(dim=-1)
            predictions = top_ids.cpu().numpy()
            scores = top_probs.cpu().numpy()
            # Per-token layout boxes produced by the processor, if present.
            bboxes = inputs["bbox"][0].cpu().numpy() if "bbox" in inputs else None
            # NOTE(review): special tokens (e.g. CLS/SEP) and sub-word pieces
            # are still included in the predictions — TODO filter/merge them.
            extracted_fields = self._process_predictions(
                predictions, scores, confidence_threshold, bboxes=bboxes
            )
            return {
                "fields": extracted_fields,
                "metadata": {
                    "confidence_scores": scores.tolist(),
                    # NOTE(review): model_type is the architecture name, not a
                    # version identifier; key kept for backward compatibility.
                    "model_version": self.model.config.model_type
                }
            }
        except Exception as e:
            logger.exception("Error during prediction: %s", e)
            raise

    def _process_predictions(
        self,
        predictions: np.ndarray,
        scores: np.ndarray,
        confidence_threshold: float,
        bboxes: Optional[np.ndarray] = None,
    ) -> List[Dict[str, Any]]:
        """Process raw model predictions into structured output.

        Args:
            predictions: Array of predicted class indices, one per token.
            scores: Array of confidence scores aligned with ``predictions``.
            confidence_threshold: Minimum confidence score; tokens scoring
                below it are dropped.
            bboxes: Optional per-token boxes aligned with ``predictions`` (each
                row of 4 coordinates). When omitted, ``"bbox"`` is ``None``.

        Returns:
            List of ``{"label", "confidence", "bbox"}`` dictionaries, in token
            order, for every token at or above the threshold.
        """
        # TODO: Implement field-specific post-processing
        # This is a placeholder implementation
        id2label = self.model.config.id2label
        # Tolerate 0-d inputs (a single token) so zip() always has something to iterate.
        predictions = np.atleast_1d(predictions)
        scores = np.atleast_1d(scores)
        processed_fields = []
        for index, (pred, score) in enumerate(zip(predictions, scores)):
            if score < confidence_threshold:
                continue
            processed_fields.append({
                # id2label is keyed by plain ints; cast the numpy integer.
                "label": id2label[int(pred)],
                "confidence": float(score),
                "bbox": (
                    [int(coord) for coord in bboxes[index]]
                    if bboxes is not None
                    else None
                ),
            })
        return processed_fields

    def validate_extraction(
        self,
        extracted_fields: Dict[str, Any],
        document_type: str
    ) -> Dict[str, Any]:
        """Validate extracted fields based on document type rules.

        Currently a placeholder: both arguments are ignored and the extraction
        is always reported as valid.

        Args:
            extracted_fields: Dictionary of extracted fields (e.g. the result
                of ``predict``).
            document_type: Type of document (e.g., 'invoice', 'receipt').

        Returns:
            Dictionary with ``is_valid`` (bool), ``validation_errors`` (list)
            and ``confidence_score`` (float).
        """
        # TODO: Implement field validation logic
        # This is a placeholder implementation
        return {
            "is_valid": True,
            "validation_errors": [],
            "confidence_score": 1.0
        }