Chamin09 commited on
Commit
70d7f43
·
verified ·
1 Parent(s): 86d4eab

Create document_ai.py

Browse files
Files changed (1) hide show
  1. models/document_ai.py +137 -0
models/document_ai.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ import os
5
+ import sys
6
+
7
+ # Try to import pytesseract, but handle if it's not available
8
+ try:
9
+ import pytesseract
10
+ TESSERACT_AVAILABLE = True
11
+ except ImportError:
12
+ TESSERACT_AVAILABLE = False
13
+
14
+ # Check if tesseract is installed
15
+ if TESSERACT_AVAILABLE:
16
+ try:
17
+ pytesseract.get_tesseract_version()
18
+ except Exception:
19
+ TESSERACT_AVAILABLE = False
20
+
21
+ # Initialize the model and processor with caching
22
+ processor = None
23
+ model = None
24
+
25
+ def get_document_ai_models():
26
+ """Get or initialize document AI models with proper caching."""
27
+ global processor, model
28
+ if processor is None:
29
+ from transformers import LayoutLMv2Processor
30
+ processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
31
+ if model is None:
32
+ from transformers import LayoutLMv2ForSequenceClassification
33
+ model = LayoutLMv2ForSequenceClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
34
+ return processor, model
35
+
36
+ def extract_text_with_tesseract(image):
37
+ """Extract text using Tesseract OCR."""
38
+ if not TESSERACT_AVAILABLE:
39
+ raise RuntimeError("tesseract is not installed or it's not in your PATH. See README file for more information.")
40
+
41
+ if isinstance(image, np.ndarray):
42
+ pil_image = Image.fromarray(image).convert("RGB")
43
+ else:
44
+ pil_image = image.convert("RGB")
45
+
46
+ # Use pytesseract for OCR
47
+ text = pytesseract.image_to_string(pil_image)
48
+
49
+ # Get word boxes for structure
50
+ boxes = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
51
+
52
+ # Extract words and their positions
53
+ words = []
54
+ word_boxes = []
55
+
56
+ for i in range(len(boxes['text'])):
57
+ if boxes['text'][i].strip() != '':
58
+ words.append(boxes['text'][i])
59
+ x, y, w, h = boxes['left'][i], boxes['top'][i], boxes['width'][i], boxes['height'][i]
60
+ word_boxes.append([x, y, x + w, y + h])
61
+
62
+ return words, word_boxes
63
+
64
+ def extract_text_with_transformers(image):
65
+ """Extract text using transformers models when Tesseract is not available."""
66
+ try:
67
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
68
+
69
+ # Initialize the processor and model
70
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
71
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
72
+
73
+ # Prepare the image
74
+ if isinstance(image, np.ndarray):
75
+ pil_image = Image.fromarray(image).convert("RGB")
76
+ else:
77
+ pil_image = image.convert("RGB")
78
+
79
+ # Process the image
80
+ pixel_values = processor(pil_image, return_tensors="pt").pixel_values
81
+ generated_ids = model.generate(pixel_values)
82
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
83
+
84
+ # Split into words
85
+ words = generated_text.split()
86
+
87
+ # Since we don't have bounding boxes, return empty boxes
88
+ word_boxes = [[0, 0, 0, 0] for _ in words]
89
+
90
+ return words, word_boxes
91
+
92
+ except Exception as e:
93
+ # If transformers OCR fails, return a simple error message
94
+ return ["Error extracting text with transformers OCR:", str(e)], [[0, 0, 0, 0], [0, 0, 0, 0]]
95
+
96
+ def extract_text_and_layout(image):
97
+ """
98
+ Extract text and layout information using OCR.
99
+
100
+ Args:
101
+ image: PIL Image object
102
+
103
+ Returns:
104
+ Dictionary with extracted text and layout information
105
+ """
106
+ # Convert numpy array to PIL Image if needed
107
+ if isinstance(image, np.ndarray):
108
+ image = Image.fromarray(image).convert("RGB")
109
+
110
+ try:
111
+ # Try Tesseract first
112
+ if TESSERACT_AVAILABLE:
113
+ words, boxes = extract_text_with_tesseract(image)
114
+ else:
115
+ # Fall back to transformers OCR
116
+ words, boxes = extract_text_with_transformers(image)
117
+ except Exception as e:
118
+ # If both methods fail, return the error
119
+ return {
120
+ 'words': [f"Error extracting text: {str(e)}"],
121
+ 'boxes': [[0, 0, 0, 0]],
122
+ 'success': False
123
+ }
124
+
125
+ # If no words were found, return empty result
126
+ if not words:
127
+ return {
128
+ 'words': [],
129
+ 'boxes': [],
130
+ 'success': False
131
+ }
132
+
133
+ return {
134
+ 'words': words,
135
+ 'boxes': boxes,
136
+ 'success': True
137
+ }