ChintanSatva committed
Commit cc3cef4 · verified · 1 Parent(s): 344effa

Update app.py

Files changed (1):
  1. app.py +245 -144

app.py CHANGED
@@ -16,9 +16,8 @@ import asyncio
 import psutil
 import cachetools
 import hashlib
-from vllm import LLM
 
-app = FastAPI()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -27,19 +26,28 @@ logger = logging.getLogger(__name__)
 # Set Tesseract path
 pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
 
-# Initialize BitNet model for CPU-only
 try:
     llm = LLM(
-        model="username/bitnet-finetuned-invoice",  # Replace with your fine-tuned BitNet model
         device="cpu",
-        enforce_eager=True,  # Disable CUDA graph compilation
-        tensor_parallel_size=1,  # Single CPU process
-        disable_custom_all_reduce=True,  # Avoid GPU optimizations
-        max_model_len=2048,  # Fit within 16GB RAM
     )
 except Exception as e:
-    logger.error(f"Failed to load BitNet model: {str(e)}")
-    raise HTTPException(status_code=500, detail="BitNet model initialization failed")
 
 # In-memory caches (1-hour TTL)
 raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
@@ -47,9 +55,12 @@ structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
 
 def log_memory_usage():
     """Log current memory usage."""
-    process = psutil.Process()
-    mem_info = process.memory_info()
-    return f"Memory usage: {mem_info.rss / 1024 / 1024:.2f} MB"
 
 def get_file_hash(file_bytes):
     """Generate MD5 hash of file content."""
@@ -65,140 +76,218 @@ async def process_image(img_bytes, filename, idx):
     logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
     try:
         img = Image.open(io.BytesIO(img_bytes))
         img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
-        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
-        custom_config = r'--oem 1 --psm 6 -l eng+ara'
         page_text = pytesseract.image_to_string(img_pil, config=custom_config)
-        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
         return page_text + "\n"
     except Exception as e:
-        logger.error(f"OCR failed for {filename} image {idx}: {str(e)}, {log_memory_usage()}")
         return ""
 
 async def process_pdf_page(img, page_idx):
     """Process a single PDF page with OCR."""
     start_time = time.time()
-    logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")
     try:
         img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
-        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
-        custom_config = r'--oem 1 --psm 6 -l eng+ara'
         page_text = pytesseract.image_to_string(img_pil, config=custom_config)
-        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
         return page_text + "\n"
     except Exception as e:
-        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
         return ""
 
-async def process_with_bitnet(filename: str, raw_text: str):
-    """Process raw text with BitNet to extract structured data."""
     start_time = time.time()
-    logger.info(f"Starting BitNet processing for {filename}, {log_memory_usage()}")
 
     # Check structured data cache
     text_hash = get_text_hash(raw_text)
     if text_hash in structured_data_cache:
-        logger.info(f"Structured data cache hit for {filename}, {log_memory_usage()}")
         return structured_data_cache[text_hash]
 
-    # Truncate text for BitNet
-    if len(raw_text) > 10000:
-        raw_text = raw_text[:10000]
-        logger.info(f"Truncated raw text for {filename} to 10000 characters, {log_memory_usage()}")
 
     try:
-        prompt = f"""You are an intelligent invoice data extractor. Given raw text from an invoice (in English or other languages),
-extract key business fields into the specified JSON format. Return each field with an estimated accuracy score between 0 and 1.
-
-- Accuracy reflects confidence in the correctness of each field.
-- Handle synonyms (e.g., 'total' = 'net', 'tax' = 'GST'/'TDS').
-- Detect currency from symbols ($, ₹, €) or keywords (USD, INR, EUR); default to USD if unclear.
-- The 'items' list may have multiple entries, each with detailed attributes.
-- If a field is missing, return an empty value (`""` or `0`) and set `accuracy` to `0.0`.
-- Convert any date to YYYY-MM-DD.
-
-Raw text:
-{raw_text}
-
-Output JSON:
-{{
-  "invoice": {{
-    "invoice_number": {{"value": "", "accuracy": 0.0}},
-    "invoice_date": {{"value": "", "accuracy": 0.0}},
-    "due_date": {{"value": "", "accuracy": 0.0}},
-    "purchase_order_number": {{"value": "", "accuracy": 0.0}},
-    "vendor": {{
-      "vendor_id": {{"value": "", "accuracy": 0.0}},
-      "name": {{"value": "", "accuracy": 0.0}},
-      "address": {{
-        "line1": {{"value": "", "accuracy": 0.0}},
-        "line2": {{"value": "", "accuracy": 0.0}},
-        "city": {{"value": "", "accuracy": 0.0}},
-        "state": {{"value": "", "accuracy": 0.0}},
-        "postal_code": {{"value": "", "accuracy": 0.0}},
-        "country": {{"value": "", "accuracy": 0.0}}
-      }},
-      "contact": {{
-        "email": {{"value": "", "accuracy": 0.0}},
-        "phone": {{"value": "", "accuracy": 0.0}}
-      }},
-      "tax_id": {{"value": "", "accuracy": 0.0}}
-    }},
-    "buyer": {{
-      "buyer_id": {{"value": "", "accuracy": 0.0}},
-      "name": {{"value": "", "accuracy": 0.0}},
-      "address": {{
-        "line1": {{"value": "", "accuracy": 0.0}},
-        "line2": {{"value": "", "accuracy": 0.0}},
-        "city": {{"value": "", "accuracy": 0.0}},
-        "state": {{"value": "", "accuracy": 0.0}},
-        "postal_code": {{"value": "", "accuracy": 0.0}},
-        "country": {{"value": "", "accuracy": 0.0}}
-      }},
-      "contact": {{
-        "email": {{"value": "", "accuracy": 0.0}},
-        "phone": {{"value": "", "accuracy": 0.0}}
-      }},
-      "tax_id": {{"value": "", "accuracy": 0.0}}
-    }},
-    "items": [
-      {{
-        "item_id": {{"value": "", "accuracy": 0.0}},
-        "description": {{"value": "", "accuracy": 0.0}},
-        "quantity": {{"value": 0, "accuracy": 0.0}},
-        "unit_of_measure": {{"value": "", "accuracy": 0.0}},
-        "unit_price": {{"value": 0, "accuracy": 0.0}},
-        "total_price": {{"value": 0, "accuracy": 0.0}},
-        "tax_rate": {{"value": 0, "accuracy": 0.0}},
-        "tax_amount": {{"value": 0, "accuracy": 0.0}},
-        "discount": {{"value": 0, "accuracy": 0.0}},
-        "net_amount": {{"value": 0, "accuracy": 0.0}}
-      }}
-    ],
-    "sub_total": {{"value": 0, "accuracy": 0.0}},
-    "tax_total": {{"value": 0, "accuracy": 0.0}},
-    "discount_total": {{"value": 0, "accuracy": 0.0}},
-    "total_amount": {{"value": 0, "accuracy": 0.0}},
-    "currency": {{"value": "", "accuracy": 0.0}}
-  }}
-}}
-"""
-        outputs = llm.generate(prompts=[prompt])
-        json_str = outputs[0].outputs[0].text
-        json_start = json_str.find("{")
-        json_end = json_str.rfind("}") + 1
-        structured_data = json.loads(json_str[json_start:json_end])
         structured_data_cache[text_hash] = structured_data
-        logger.info(f"BitNet processing for {filename}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
         return structured_data
     except Exception as e:
-        logger.error(f"BitNet processing failed for {filename}: {str(e)}, {log_memory_usage()}")
-        return {"error": f"BitNet processing failed: {str(e)}"}
 
 @app.post("/ocr")
 async def extract_and_structure(files: List[UploadFile] = File(...)):
     output_json = {
         "success": True,
         "message": "",
@@ -207,15 +296,15 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
     success_count = 0
     fail_count = 0
 
-    logger.info(f"Starting processing for {len(files)} files, {log_memory_usage()}")
 
     for file in files:
         total_start_time = time.time()
-        logger.info(f"Processing file: {file.filename}, {log_memory_usage()}")
 
         # Validate file format
         valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
-        file_ext = os.path.splitext(file.filename.lower())[1]
         if file_ext not in valid_extensions:
             fail_count += 1
             output_json["data"].append({
@@ -232,7 +321,7 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
             file_bytes = await file.read()
             file_stream = io.BytesIO(file_bytes)
             file_hash = get_file_hash(file_bytes)
-            logger.info(f"Read file {file.filename}, took {time.time() - file_start_time:.2f} seconds, size: {len(file_bytes)/1024:.2f} KB, {log_memory_usage()}")
         except Exception as e:
             fail_count += 1
             output_json["data"].append({
@@ -240,17 +329,17 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
                 "structured_data": {"error": f"Failed to read file: {str(e)}"},
                 "error": f"Failed to read file: {str(e)}"
             })
-            logger.error(f"Failed to read file {file.filename}: {str(e)}, {log_memory_usage()}")
             continue
 
         # Check raw text cache
         raw_text = ""
         if file_hash in raw_text_cache:
             raw_text = raw_text_cache[file_hash]
-            logger.info(f"Raw text cache hit for {file.filename}, {log_memory_usage()}")
         else:
             if file_ext == '.pdf':
-                # Try extracting embedded text
                 try:
                     extract_start_time = time.time()
                     reader = PdfReader(file_stream)
@@ -258,16 +347,16 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
                         text = page.extract_text()
                         if text:
                             raw_text += text + "\n"
-                    logger.info(f"Embedded text extraction for {file.filename}, took {time.time() - extract_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                 except Exception as e:
-                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}, {log_memory_usage()}")
 
                 # If no embedded text, perform OCR
                 if not raw_text.strip():
                     try:
                         convert_start_time = time.time()
-                        images = convert_from_bytes(file_bytes, dpi=100)
-                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages, took {time.time() - convert_start_time:.2f} seconds, {log_memory_usage()}")
 
                         ocr_start_time = time.time()
                         page_texts = []
@@ -275,7 +364,7 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
                             page_text = await process_pdf_page(img, i)
                             page_texts.append(page_text)
                         raw_text = "".join(page_texts)
-                        logger.info(f"Total OCR for {file.filename}, took {time.time() - ocr_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                     except Exception as e:
                         fail_count += 1
                         output_json["data"].append({
@@ -283,13 +372,13 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
                             "structured_data": {"error": f"OCR failed: {str(e)}"},
                             "error": f"OCR failed: {str(e)}"
                         })
-                        logger.error(f"OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                         continue
             else:  # JPG/JPEG/PNG
                 try:
                     ocr_start_time = time.time()
                     raw_text = await process_image(file_bytes, file.filename, 0)
-                    logger.info(f"Image OCR for {file.filename}, took {time.time() - ocr_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                 except Exception as e:
                     fail_count += 1
                     output_json["data"].append({
@@ -297,33 +386,45 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
                     "structured_data": {"error": f"Image OCR failed: {str(e)}"},
                     "error": f"Image OCR failed: {str(e)}"
                 })
-                logger.error(f"Image OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                 continue
 
         # Normalize text
         try:
-            normalize_start_time = time.time()
             raw_text = unicodedata.normalize('NFKC', raw_text)
-            raw_text = raw_text.encode().decode('utf-8')
            raw_text_cache[file_hash] = raw_text
-            logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
         except Exception as e:
-            logger.warning(f"Text normalization failed for {file.filename}: {str(e)}, {log_memory_usage()}")
 
-        # Process with BitNet
-        structured_data = await process_with_bitnet(file.filename, raw_text)
-        success_count += 1
-        output_json["data"].append({
-            "filename": file.filename,
-            "structured_data": structured_data,
-            "error": ""
-        })
 
-        logger.info(f"Total processing for {file.filename}, took {time.time() - total_start_time:.2f} seconds, {log_memory_usage()}")
 
     output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
     if fail_count > 0 and success_count == 0:
         output_json["success"] = False
 
-    logger.info(f"Completed processing for {len(files)} files, {success_count} succeeded, {fail_count} failed, {log_memory_usage()}")
-    return output_json
app.py (updated file; added lines are marked "+"):

 import psutil
 import cachetools
 import hashlib
 
+app = FastAPI(title="Invoice OCR and Extraction API", version="1.0.0")
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Set Tesseract path
 pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
 
+# Initialize LLM with fallback handling
+llm = None
 try:
+    # Try to import and initialize vLLM
+    from vllm import LLM, SamplingParams
+
+    # For Hugging Face Spaces, use a smaller, more compatible model
+    model_name = "microsoft/DialoGPT-medium"  # Fallback model
+
     llm = LLM(
+        model=model_name,
         device="cpu",
+        enforce_eager=True,
+        tensor_parallel_size=1,
+        disable_custom_all_reduce=True,
+        max_model_len=1024,  # Reduced for compatibility
+        trust_remote_code=True
     )
+    logger.info("LLM model loaded successfully")
 except Exception as e:
+    logger.error(f"Failed to load vLLM: {str(e)}")
+    logger.info("Will use rule-based extraction as fallback")
 
 # In-memory caches (1-hour TTL)
 raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
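
The two TTLCache instances implement content-addressed memoization: OCR output is cached under an MD5 digest of the raw file bytes (get_file_hash) and structured output under a digest of the extracted text. A minimal standalone sketch of the pattern, in which run_ocr is a hypothetical stand-in for the expensive step:

import hashlib
import cachetools

cache = cachetools.TTLCache(maxsize=100, ttl=3600)  # at most 100 entries, evicted after 1 hour

def run_ocr(data: bytes) -> str:
    return "extracted text"  # hypothetical stand-in for the expensive OCR step

def get_cached_text(file_bytes: bytes) -> str:
    key = hashlib.md5(file_bytes).hexdigest()  # content digest, not a security use of MD5
    if key in cache:
        return cache[key]      # hit: identical bytes were processed within the TTL
    text = run_ocr(file_bytes)
    cache[key] = text          # miss: compute once, store until the TTL expires
    return text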
 
 
 def log_memory_usage():
     """Log current memory usage."""
+    try:
+        process = psutil.Process()
+        mem_info = process.memory_info()
+        return f"Memory usage: {mem_info.rss / 1024 / 1024:.2f} MB"
+    except Exception:
+        return "Memory usage: N/A"
 
 def get_file_hash(file_bytes):
     """Generate MD5 hash of file content."""
 
     logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
     try:
         img = Image.open(io.BytesIO(img_bytes))
+        # Convert to RGB if needed
+        if img.mode != 'RGB':
+            img = img.convert('RGB')
+
         img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
+
+        # Preprocess image for better OCR
+        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
+
+        img_pil = Image.fromarray(gray)
+        custom_config = r'--oem 3 --psm 6 -l eng'
         page_text = pytesseract.image_to_string(img_pil, config=custom_config)
+
+        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds")
         return page_text + "\n"
     except Exception as e:
+        logger.error(f"OCR failed for {filename} image {idx}: {str(e)}")
         return ""
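
The added preprocessing binarizes each page with Otsu thresholding before OCR; `--oem 3` selects Tesseract's default LSTM engine and `--psm 6` assumes a single uniform block of text. A standalone sketch of the same pipeline on a file from disk (the input path is a placeholder, and the eng language data is assumed to be installed):

import cv2
import pytesseract
from PIL import Image

img = cv2.imread("invoice.png")               # placeholder input image
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# Otsu derives the binarization threshold from the image histogram automatically
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
text = pytesseract.image_to_string(Image.fromarray(binary),
                                   config=r"--oem 3 --psm 6 -l eng")
print(text)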
 
 async def process_pdf_page(img, page_idx):
     """Process a single PDF page with OCR."""
     start_time = time.time()
+    logger.info(f"Starting OCR for PDF page {page_idx}")
     try:
         img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
+
+        # Preprocess image for better OCR
+        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
+
+        img_pil = Image.fromarray(gray)
+        custom_config = r'--oem 3 --psm 6 -l eng'
         page_text = pytesseract.image_to_string(img_pil, config=custom_config)
+
+        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds")
         return page_text + "\n"
     except Exception as e:
+        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}")
         return ""
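
Both OCR helpers are declared async but call pytesseract synchronously, so each page blocks the event loop while Tesseract runs. If that became a problem, one option (not part of this commit) would be to offload the blocking call to a worker thread with asyncio.to_thread:

import asyncio
import pytesseract

async def ocr_in_thread(img_pil, config=r"--oem 3 --psm 6 -l eng"):
    # image_to_string is CPU-bound; running it via asyncio.to_thread keeps the
    # FastAPI event loop free to serve other requests in the meantime
    return await asyncio.to_thread(pytesseract.image_to_string, img_pil, config=config)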
 
+def rule_based_extraction(raw_text: str):
+    """Rule-based fallback extraction when LLM is not available."""
+    import re
+
+    # Initialize the structure
+    structured_data = {
+        "invoice": {
+            "invoice_number": {"value": "", "accuracy": 0.0},
+            "invoice_date": {"value": "", "accuracy": 0.0},
+            "due_date": {"value": "", "accuracy": 0.0},
+            "purchase_order_number": {"value": "", "accuracy": 0.0},
+            "vendor": {
+                "vendor_id": {"value": "", "accuracy": 0.0},
+                "name": {"value": "", "accuracy": 0.0},
+                "address": {
+                    "line1": {"value": "", "accuracy": 0.0},
+                    "line2": {"value": "", "accuracy": 0.0},
+                    "city": {"value": "", "accuracy": 0.0},
+                    "state": {"value": "", "accuracy": 0.0},
+                    "postal_code": {"value": "", "accuracy": 0.0},
+                    "country": {"value": "", "accuracy": 0.0}
+                },
+                "contact": {
+                    "email": {"value": "", "accuracy": 0.0},
+                    "phone": {"value": "", "accuracy": 0.0}
+                },
+                "tax_id": {"value": "", "accuracy": 0.0}
+            },
+            "buyer": {
+                "buyer_id": {"value": "", "accuracy": 0.0},
+                "name": {"value": "", "accuracy": 0.0},
+                "address": {
+                    "line1": {"value": "", "accuracy": 0.0},
+                    "line2": {"value": "", "accuracy": 0.0},
+                    "city": {"value": "", "accuracy": 0.0},
+                    "state": {"value": "", "accuracy": 0.0},
+                    "postal_code": {"value": "", "accuracy": 0.0},
+                    "country": {"value": "", "accuracy": 0.0}
+                },
+                "contact": {
+                    "email": {"value": "", "accuracy": 0.0},
+                    "phone": {"value": "", "accuracy": 0.0}
+                },
+                "tax_id": {"value": "", "accuracy": 0.0}
+            },
+            "items": [{
+                "item_id": {"value": "", "accuracy": 0.0},
+                "description": {"value": "", "accuracy": 0.0},
+                "quantity": {"value": 0, "accuracy": 0.0},
+                "unit_of_measure": {"value": "", "accuracy": 0.0},
+                "unit_price": {"value": 0, "accuracy": 0.0},
+                "total_price": {"value": 0, "accuracy": 0.0},
+                "tax_rate": {"value": 0, "accuracy": 0.0},
+                "tax_amount": {"value": 0, "accuracy": 0.0},
+                "discount": {"value": 0, "accuracy": 0.0},
+                "net_amount": {"value": 0, "accuracy": 0.0}
+            }],
+            "sub_total": {"value": 0, "accuracy": 0.0},
+            "tax_total": {"value": 0, "accuracy": 0.0},
+            "discount_total": {"value": 0, "accuracy": 0.0},
+            "total_amount": {"value": 0, "accuracy": 0.0},
+            "currency": {"value": "USD", "accuracy": 0.5}
+        }
+    }
+
+    # Simple pattern matching
+    try:
+        # Invoice number
+        inv_pattern = r'(?:invoice|inv)(?:\s*#|\s*no\.?|\s*number)?\s*:?\s*([A-Z0-9\-/]+)'
+        inv_match = re.search(inv_pattern, raw_text, re.IGNORECASE)
+        if inv_match:
+            structured_data["invoice"]["invoice_number"]["value"] = inv_match.group(1)
+            structured_data["invoice"]["invoice_number"]["accuracy"] = 0.7
+
+        # Date patterns
+        date_pattern = r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})'
+        dates = re.findall(date_pattern, raw_text)
+        if dates:
+            structured_data["invoice"]["invoice_date"]["value"] = dates[0]
+            structured_data["invoice"]["invoice_date"]["accuracy"] = 0.6
+
+        # Total amount
+        amount_pattern = r'(?:total|amount|sum)\s*:?\s*\$?(\d+\.?\d*)'
+        amount_match = re.search(amount_pattern, raw_text, re.IGNORECASE)
+        if amount_match:
+            structured_data["invoice"]["total_amount"]["value"] = float(amount_match.group(1))
+            structured_data["invoice"]["total_amount"]["accuracy"] = 0.6
+
+        # Email
+        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
+        email_match = re.search(email_pattern, raw_text)
+        if email_match:
+            structured_data["invoice"]["vendor"]["contact"]["email"]["value"] = email_match.group()
+            structured_data["invoice"]["vendor"]["contact"]["email"]["accuracy"] = 0.8
+
+        # Phone
+        phone_pattern = r'(?:\+?1[-.\s]?)?\(?([0-9]{3})\)?[-.\s]?([0-9]{3})[-.\s]?([0-9]{4})'
+        phone_match = re.search(phone_pattern, raw_text)
+        if phone_match:
+            structured_data["invoice"]["vendor"]["contact"]["phone"]["value"] = phone_match.group()
+            structured_data["invoice"]["vendor"]["contact"]["phone"]["accuracy"] = 0.7
+
+    except Exception as e:
+        logger.error(f"Rule-based extraction error: {str(e)}")
+
+    return structured_data
+
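
The fallback patterns are easy to sanity-check against a synthetic invoice; the snippet below runs the same regexes on made-up text (all sample values are illustrative only):

import re

sample = """Invoice No: INV-2024-001
Date: 03/15/2024
Total: $1250.50"""

inv = re.search(r'(?:invoice|inv)(?:\s*#|\s*no\.?|\s*number)?\s*:?\s*([A-Z0-9\-/]+)',
                sample, re.IGNORECASE)
date = re.search(r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})', sample)
total = re.search(r'(?:total|amount|sum)\s*:?\s*\$?(\d+\.?\d*)', sample, re.IGNORECASE)

print(inv.group(1), date.group(1), total.group(1))
# -> INV-2024-001 03/15/2024 1250.50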
+async def process_with_model(filename: str, raw_text: str):
+    """Process raw text with available model or fallback to rule-based."""
     start_time = time.time()
+    logger.info(f"Starting text processing for {filename}")
 
     # Check structured data cache
     text_hash = get_text_hash(raw_text)
     if text_hash in structured_data_cache:
+        logger.info(f"Structured data cache hit for {filename}")
         return structured_data_cache[text_hash]
 
+    # Truncate text
+    if len(raw_text) > 5000:
+        raw_text = raw_text[:5000]
+        logger.info(f"Truncated raw text for {filename} to 5000 characters")
 
     try:
+        if llm is not None:
+            # Use LLM if available
+            prompt = f"""Extract invoice data from this text and return JSON:
+
+Text: {raw_text}
+
+Return structured JSON with invoice details including vendor, amounts, dates."""
+
+            sampling_params = SamplingParams(max_tokens=512, temperature=0.1)
+            outputs = llm.generate(prompts=[prompt], sampling_params=sampling_params)
+            response_text = outputs[0].outputs[0].text
+
+            # Try to parse JSON from response
+            try:
+                json_start = response_text.find("{")
+                json_end = response_text.rfind("}") + 1
+                if json_start >= 0 and json_end > json_start:
+                    structured_data = json.loads(response_text[json_start:json_end])
+                else:
+                    raise ValueError("No JSON found in response")
+            except Exception:
+                # Fallback to rule-based if JSON parsing fails
+                structured_data = rule_based_extraction(raw_text)
+        else:
+            # Use rule-based extraction
+            structured_data = rule_based_extraction(raw_text)
+
+        # Cache the result
         structured_data_cache[text_hash] = structured_data
+        logger.info(f"Text processing for {filename} completed in {time.time() - start_time:.2f} seconds")
         return structured_data
+
     except Exception as e:
+        logger.error(f"Text processing failed for {filename}: {str(e)}")
+        return rule_based_extraction(raw_text)
+
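
Note on the generate call: vLLM expects a SamplingParams object rather than a plain dict, which is why one is constructed above; generate returns one RequestOutput per prompt, each carrying a list of CompletionOutput objects. A minimal sketch of that call path, assuming vLLM is installed and using a small public model purely for illustration:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # small model, illustration only
params = SamplingParams(max_tokens=64, temperature=0.1)

outputs = llm.generate(["Return a JSON object with a 'total' field."], params)
print(outputs[0].outputs[0].text)  # first prompt, first completion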
+@app.get("/")
+async def root():
+    """Health check endpoint."""
+    return {
+        "message": "Invoice OCR and Extraction API",
+        "status": "active",
+        "llm_available": llm is not None
+    }
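
The root endpoint doubles as a readiness probe: llm_available reports whether responses will come from the model or from the regex fallback. Checking it from a client might look like this (host and port are placeholders for a local run):

import requests

info = requests.get("http://localhost:7860/").json()
print(info["status"], "- LLM loaded:", info["llm_available"])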
 
 @app.post("/ocr")
 async def extract_and_structure(files: List[UploadFile] = File(...)):
+    """Main endpoint for OCR and data extraction."""
     output_json = {
         "success": True,
         "message": "",
 
     success_count = 0
     fail_count = 0
 
+    logger.info(f"Starting processing for {len(files)} files")
 
     for file in files:
         total_start_time = time.time()
+        logger.info(f"Processing file: {file.filename}")
 
         # Validate file format
         valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
+        file_ext = os.path.splitext(file.filename.lower())[1] if file.filename else '.unknown'
         if file_ext not in valid_extensions:
             fail_count += 1
             output_json["data"].append({
 
             file_bytes = await file.read()
             file_stream = io.BytesIO(file_bytes)
             file_hash = get_file_hash(file_bytes)
+            logger.info(f"Read file {file.filename}, size: {len(file_bytes)/1024:.2f} KB")
         except Exception as e:
             fail_count += 1
             output_json["data"].append({
 
                 "structured_data": {"error": f"Failed to read file: {str(e)}"},
                 "error": f"Failed to read file: {str(e)}"
             })
+            logger.error(f"Failed to read file {file.filename}: {str(e)}")
             continue
 
         # Check raw text cache
         raw_text = ""
         if file_hash in raw_text_cache:
             raw_text = raw_text_cache[file_hash]
+            logger.info(f"Raw text cache hit for {file.filename}")
         else:
             if file_ext == '.pdf':
+                # Try extracting embedded text first
                 try:
                     extract_start_time = time.time()
                     reader = PdfReader(file_stream)
 
                         text = page.extract_text()
                         if text:
                             raw_text += text + "\n"
+                    logger.info(f"Embedded text extraction for {file.filename}, text length: {len(raw_text)}")
                 except Exception as e:
+                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}")
 
                 # If no embedded text, perform OCR
                 if not raw_text.strip():
                     try:
                         convert_start_time = time.time()
+                        images = convert_from_bytes(file_bytes, dpi=150, first_page=1, last_page=3)  # Limit pages
+                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages")
 
                         ocr_start_time = time.time()
                         page_texts = []
 
                             page_text = await process_pdf_page(img, i)
                             page_texts.append(page_text)
                         raw_text = "".join(page_texts)
+                        logger.info(f"Total OCR for {file.filename}, text length: {len(raw_text)}")
                     except Exception as e:
                         fail_count += 1
                         output_json["data"].append({
 
                             "structured_data": {"error": f"OCR failed: {str(e)}"},
                             "error": f"OCR failed: {str(e)}"
                         })
+                        logger.error(f"OCR failed for {file.filename}: {str(e)}")
                         continue
             else:  # JPG/JPEG/PNG
                 try:
                     ocr_start_time = time.time()
                     raw_text = await process_image(file_bytes, file.filename, 0)
+                    logger.info(f"Image OCR for {file.filename}, text length: {len(raw_text)}")
                 except Exception as e:
                     fail_count += 1
                     output_json["data"].append({
 
                         "structured_data": {"error": f"Image OCR failed: {str(e)}"},
                         "error": f"Image OCR failed: {str(e)}"
                     })
+                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}")
                     continue
 
         # Normalize text
         try:
             raw_text = unicodedata.normalize('NFKC', raw_text)
+            raw_text = raw_text.encode('utf-8', errors='ignore').decode('utf-8')
             raw_text_cache[file_hash] = raw_text
+            logger.info(f"Text normalization for {file.filename} completed")
         except Exception as e:
+            logger.warning(f"Text normalization failed for {file.filename}: {str(e)}")
 
+        # Process with model or rule-based extraction
+        if raw_text.strip():
+            structured_data = await process_with_model(file.filename, raw_text)
+            success_count += 1
+            output_json["data"].append({
+                "filename": file.filename,
+                "structured_data": structured_data,
+                "raw_text": raw_text[:500] + "..." if len(raw_text) > 500 else raw_text,  # Include snippet
+                "error": ""
+            })
+        else:
+            fail_count += 1
+            output_json["data"].append({
+                "filename": file.filename,
+                "structured_data": {"error": "No text extracted from file"},
+                "error": "No text extracted from file"
+            })
 
+        logger.info(f"Total processing for {file.filename} completed in {time.time() - total_start_time:.2f} seconds")
 
     output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
     if fail_count > 0 and success_count == 0:
         output_json["success"] = False
 
+    logger.info(f"Batch processing completed: {success_count} succeeded, {fail_count} failed")
+    return output_json
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
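
A client posts one or more files to /ocr as multipart form data under the files field; a sketch with requests (filename, host, and port are placeholders):

import requests

url = "http://localhost:7860/ocr"  # placeholder host/port
with open("invoice.pdf", "rb") as f:
    files = [("files", ("invoice.pdf", f, "application/pdf"))]  # repeat tuples for more files
    resp = requests.post(url, files=files)

result = resp.json()
print(result["message"])
for entry in result["data"]:
    print(entry["filename"], "->", entry["error"] or "ok")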