from fastapi import FastAPI, File, UploadFile

import hashlib
import io
import logging
import os
import re
import time
import unicodedata
from typing import List

import cachetools
import cv2
import numpy as np
import psutil
import pytesseract
from pdf2image import convert_from_bytes
from PIL import Image
from pypdf import PdfReader

app = FastAPI(title="Invoice OCR and Extraction API", version="1.0.0")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Tesseract binary path; hardcoded for a typical Debian/Ubuntu install.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
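
# A more portable alternative (a sketch only, assuming tesseract is on PATH)
# would resolve the binary at startup instead of hardcoding it:
#
#   import shutil
#   pytesseract.pytesseract.tesseract_cmd = shutil.which("tesseract") or "/usr/bin/tesseract"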

# Optional text-generation model. DialoGPT-small is a conversational model, not
# an extraction model; it is used only as a weak signal on top of the rule-based
# extractor, and the API degrades gracefully when it fails to load.
llm = None
try:
    from transformers import pipeline

    llm = pipeline(
        "text-generation",
        model="microsoft/DialoGPT-small",
        device=-1,  # CPU only
        return_full_text=False,
        max_length=512,
    )
    logger.info("Lightweight text generation model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load text generation model: {e}")
    logger.info("Will use rule-based extraction only")

# In-memory caches: up to 100 entries each, expiring one hour (ttl=3600 s) after
# insertion. raw_text_cache is keyed by file hash, structured_data_cache by text hash.
raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
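
# Cache interplay (illustrative): re-uploading byte-identical files reuses the
# OCR result via raw_text_cache, and distinct files whose extracted text happens
# to match still share a single structured_data_cache entry.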


def log_memory_usage():
    """Return a human-readable string with the current RSS memory usage."""
    try:
        process = psutil.Process()
        mem_info = process.memory_info()
        return f"Memory usage: {mem_info.rss / 1024 / 1024:.2f} MB"
    except Exception:  # never let diagnostics break a request
        return "Memory usage: N/A"


def get_file_hash(file_bytes):
    """Generate an MD5 hash of file content (cache key only, not for security)."""
    return hashlib.md5(file_bytes).hexdigest()


def get_text_hash(raw_text):
    """Generate an MD5 hash of raw text (cache key only, not for security)."""
    return hashlib.md5(raw_text.encode('utf-8')).hexdigest()


async def process_image(img_bytes, filename, idx):
    """Process a single image (JPG/JPEG/PNG) with OCR."""
    start_time = time.time()
    logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
    try:
        img = Image.open(io.BytesIO(img_bytes))
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Grayscale + Otsu binarization to improve OCR on scanned invoices.
        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        img_pil = Image.fromarray(gray)
        # --oem 3: default LSTM engine; --psm 6: assume a single uniform block of text.
        custom_config = r'--oem 3 --psm 6 -l eng'
        page_text = pytesseract.image_to_string(img_pil, config=custom_config)

        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for {filename} image {idx}: {e}")
        return ""


async def process_pdf_page(img, page_idx):
    """Process a single PDF page (already rendered to a PIL image) with OCR."""
    start_time = time.time()
    logger.info(f"Starting OCR for PDF page {page_idx}")
    try:
        # Same preprocessing as process_image: grayscale + Otsu binarization.
        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        img_pil = Image.fromarray(gray)
        custom_config = r'--oem 3 --psm 6 -l eng'
        page_text = pytesseract.image_to_string(img_pil, config=custom_config)

        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for PDF page {page_idx}: {e}")
        return ""


def rule_based_extraction(raw_text: str):
    """Rule-based fallback extraction when the LLM is not available."""
    structured_data = {
        "invoice": {
            "invoice_number": {"value": "", "accuracy": 0.0},
            "invoice_date": {"value": "", "accuracy": 0.0},
            "due_date": {"value": "", "accuracy": 0.0},
            "purchase_order_number": {"value": "", "accuracy": 0.0},
            "vendor": {
                "vendor_id": {"value": "", "accuracy": 0.0},
                "name": {"value": "", "accuracy": 0.0},
                "address": {
                    "line1": {"value": "", "accuracy": 0.0},
                    "line2": {"value": "", "accuracy": 0.0},
                    "city": {"value": "", "accuracy": 0.0},
                    "state": {"value": "", "accuracy": 0.0},
                    "postal_code": {"value": "", "accuracy": 0.0},
                    "country": {"value": "", "accuracy": 0.0}
                },
                "contact": {
                    "email": {"value": "", "accuracy": 0.0},
                    "phone": {"value": "", "accuracy": 0.0}
                },
                "tax_id": {"value": "", "accuracy": 0.0}
            },
            "buyer": {
                "buyer_id": {"value": "", "accuracy": 0.0},
                "name": {"value": "", "accuracy": 0.0},
                "address": {
                    "line1": {"value": "", "accuracy": 0.0},
                    "line2": {"value": "", "accuracy": 0.0},
                    "city": {"value": "", "accuracy": 0.0},
                    "state": {"value": "", "accuracy": 0.0},
                    "postal_code": {"value": "", "accuracy": 0.0},
                    "country": {"value": "", "accuracy": 0.0}
                },
                "contact": {
                    "email": {"value": "", "accuracy": 0.0},
                    "phone": {"value": "", "accuracy": 0.0}
                },
                "tax_id": {"value": "", "accuracy": 0.0}
            },
            "items": [{
                "item_id": {"value": "", "accuracy": 0.0},
                "description": {"value": "", "accuracy": 0.0},
                "quantity": {"value": 0, "accuracy": 0.0},
                "unit_of_measure": {"value": "", "accuracy": 0.0},
                "unit_price": {"value": 0, "accuracy": 0.0},
                "total_price": {"value": 0, "accuracy": 0.0},
                "tax_rate": {"value": 0, "accuracy": 0.0},
                "tax_amount": {"value": 0, "accuracy": 0.0},
                "discount": {"value": 0, "accuracy": 0.0},
                "net_amount": {"value": 0, "accuracy": 0.0}
            }],
            "sub_total": {"value": 0, "accuracy": 0.0},
            "tax_total": {"value": 0, "accuracy": 0.0},
            "discount_total": {"value": 0, "accuracy": 0.0},
            "total_amount": {"value": 0, "accuracy": 0.0},
            "currency": {"value": "USD", "accuracy": 0.5}
        }
    }

    try:
        # Invoice number, e.g. "Invoice #INV-123" or "INV NO: 2024/001".
        inv_pattern = r'(?:invoice|inv)(?:\s*#|\s*no\.?|\s*number)?\s*:?\s*([A-Z0-9\-/]+)'
        inv_match = re.search(inv_pattern, raw_text, re.IGNORECASE)
        if inv_match:
            structured_data["invoice"]["invoice_number"]["value"] = inv_match.group(1)
            structured_data["invoice"]["invoice_number"]["accuracy"] = 0.7

        # The first date-like token is assumed to be the invoice date.
        date_pattern = r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})'
        dates = re.findall(date_pattern, raw_text)
        if dates:
            structured_data["invoice"]["invoice_date"]["value"] = dates[0]
            structured_data["invoice"]["invoice_date"]["accuracy"] = 0.6

        # Total amount; thousands separators are accepted and stripped.
        amount_pattern = r'(?:total|amount|sum)\s*:?\s*\$?\s*([\d,]+\.?\d*)'
        amount_match = re.search(amount_pattern, raw_text, re.IGNORECASE)
        if amount_match:
            structured_data["invoice"]["total_amount"]["value"] = float(amount_match.group(1).replace(',', ''))
            structured_data["invoice"]["total_amount"]["accuracy"] = 0.6

        # Email address.
        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        email_match = re.search(email_pattern, raw_text)
        if email_match:
            structured_data["invoice"]["vendor"]["contact"]["email"]["value"] = email_match.group()
            structured_data["invoice"]["vendor"]["contact"]["email"]["accuracy"] = 0.8

        # North American phone numbers, with an optional +1 prefix.
        phone_pattern = r'(?:\+?1[-.\s]?)?\(?([0-9]{3})\)?[-.\s]?([0-9]{3})[-.\s]?([0-9]{4})'
        phone_match = re.search(phone_pattern, raw_text)
        if phone_match:
            structured_data["invoice"]["vendor"]["contact"]["phone"]["value"] = phone_match.group()
            structured_data["invoice"]["vendor"]["contact"]["phone"]["accuracy"] = 0.7

    except Exception as e:
        logger.error(f"Rule-based extraction error: {e}")

    return structured_data
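
# Illustration on a synthetic snippet (hypothetical input, not real data):
#
#   sample = "Invoice No: INV-2024-001\nDate: 01/15/2024\nTotal: $1,234.50\nbilling@acme.example"
#   rule_based_extraction(sample)["invoice"]["invoice_number"]["value"]  # -> "INV-2024-001"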


async def process_with_model(filename: str, raw_text: str):
    """Process raw text with the available model, falling back to rule-based extraction."""
    start_time = time.time()
    logger.info(f"Starting text processing for {filename}")

    # Serve repeated texts from the cache.
    text_hash = get_text_hash(raw_text)
    if text_hash in structured_data_cache:
        logger.info(f"Structured data cache hit for {filename}")
        return structured_data_cache[text_hash]

    # Bound the work done per document.
    if len(raw_text) > 5000:
        raw_text = raw_text[:5000]
        logger.info(f"Truncated raw text for {filename} to 5000 characters")

    try:
        if llm is not None:
            prompt = f"""Extract key information from this invoice text and format as JSON:

Invoice Text: {raw_text[:1000]}

Please extract: invoice number, date, vendor name, total amount, email, phone number."""

            try:
                response = llm(prompt, max_length=200, num_return_sequences=1, temperature=0.7)
                response_text = response[0]['generated_text'] if response else ""

                # The rule-based pass always produces the structure; the model
                # output is only used as a weak confirmation signal below.
                structured_data = rule_based_extraction(raw_text)

                if "invoice" in response_text.lower():
                    # Nudge the confidence of fields the rules already filled, capped at 0.8.
                    for key in structured_data["invoice"]:
                        field = structured_data["invoice"][key]
                        if isinstance(field, dict) and "accuracy" in field and field["accuracy"] > 0:
                            field["accuracy"] = min(0.8, field["accuracy"] + 0.1)

            except Exception as model_error:
                logger.warning(f"Model processing failed, using rule-based: {model_error}")
                structured_data = rule_based_extraction(raw_text)
        else:
            structured_data = rule_based_extraction(raw_text)

        structured_data_cache[text_hash] = structured_data
        logger.info(f"Text processing for {filename} completed in {time.time() - start_time:.2f} seconds")
        return structured_data

    except Exception as e:
        logger.error(f"Text processing failed for {filename}: {e}")
        return rule_based_extraction(raw_text)
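
# Design note: a conversational model like DialoGPT-small is unlikely to emit
# valid JSON, so the generated text is only scanned for the keyword "invoice"
# rather than parsed. A model fine-tuned for structured extraction could replace
# that check with a real json.loads of the generated text.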


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "message": "Invoice OCR and Extraction API",
        "status": "active",
        "llm_available": llm is not None
    }
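
# Example health-check response (shape only; llm_available depends on startup):
#
#   {"message": "Invoice OCR and Extraction API", "status": "active", "llm_available": false}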


@app.post("/ocr")
async def extract_and_structure(files: List[UploadFile] = File(...)):
    """Main endpoint: OCR each uploaded file, then extract structured invoice data."""
    output_json = {
        "success": True,
        "message": "",
        "data": []
    }
    success_count = 0
    fail_count = 0

    logger.info(f"Starting processing for {len(files)} files")

    for file in files:
        total_start_time = time.time()
        logger.info(f"Processing file: {file.filename}")

        # Validate the file extension before reading any bytes.
        valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
        file_ext = os.path.splitext(file.filename.lower())[1] if file.filename else '.unknown'
        if file_ext not in valid_extensions:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": f"Unsupported file format: {file_ext}"},
                "error": f"Unsupported file format: {file_ext}"
            })
            logger.error(f"Unsupported file format for {file.filename}: {file_ext}")
            continue

        try:
            file_bytes = await file.read()
            file_stream = io.BytesIO(file_bytes)
            file_hash = get_file_hash(file_bytes)
            logger.info(f"Read file {file.filename}, size: {len(file_bytes)/1024:.2f} KB")
        except Exception as e:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": f"Failed to read file: {str(e)}"},
                "error": f"Failed to read file: {str(e)}"
            })
            logger.error(f"Failed to read file {file.filename}: {str(e)}")
            continue

        raw_text = ""
        if file_hash in raw_text_cache:
            raw_text = raw_text_cache[file_hash]
            logger.info(f"Raw text cache hit for {file.filename}")
        else:
            if file_ext == '.pdf':
                # Prefer the embedded text layer; it is faster and more accurate than OCR.
                try:
                    reader = PdfReader(file_stream)
                    for page in reader.pages:
                        text = page.extract_text()
                        if text:
                            raw_text += text + "\n"
                    logger.info(f"Embedded text extraction for {file.filename}, text length: {len(raw_text)}")
                except Exception as e:
                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}")

                # Fall back to OCR for scanned PDFs; only the first 3 pages are
                # rendered, at 150 dpi, to bound processing time.
                if not raw_text.strip():
                    try:
                        images = convert_from_bytes(file_bytes, dpi=150, first_page=1, last_page=3)
                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages")

                        page_texts = []
                        for i, img in enumerate(images):
                            page_text = await process_pdf_page(img, i)
                            page_texts.append(page_text)
                        raw_text = "".join(page_texts)
                        logger.info(f"Total OCR for {file.filename}, text length: {len(raw_text)}")
                    except Exception as e:
                        fail_count += 1
                        output_json["data"].append({
                            "filename": file.filename,
                            "structured_data": {"error": f"OCR failed: {str(e)}"},
                            "error": f"OCR failed: {str(e)}"
                        })
                        logger.error(f"OCR failed for {file.filename}: {str(e)}")
                        continue
            else:
                try:
                    raw_text = await process_image(file_bytes, file.filename, 0)
                    logger.info(f"Image OCR for {file.filename}, text length: {len(raw_text)}")
                except Exception as e:
                    fail_count += 1
                    output_json["data"].append({
                        "filename": file.filename,
                        "structured_data": {"error": f"Image OCR failed: {str(e)}"},
                        "error": f"Image OCR failed: {str(e)}"
                    })
                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}")
                    continue

            # Normalize the text and cache it before structuring.
            try:
                raw_text = unicodedata.normalize('NFKC', raw_text)
                raw_text = raw_text.encode('utf-8', errors='ignore').decode('utf-8')
                raw_text_cache[file_hash] = raw_text
                logger.info(f"Text normalization for {file.filename} completed")
            except Exception as e:
                logger.warning(f"Text normalization failed for {file.filename}: {str(e)}")

        if raw_text.strip():
            structured_data = await process_with_model(file.filename, raw_text)
            success_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": structured_data,
                "raw_text": (raw_text[:500] + "...") if len(raw_text) > 500 else raw_text,
                "error": ""
            })
        else:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": "No text extracted from file"},
                "error": "No text extracted from file"
            })

        logger.info(f"Total processing for {file.filename} completed in {time.time() - total_start_time:.2f} seconds")

    output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
    if fail_count > 0 and success_count == 0:
        output_json["success"] = False

    logger.info(f"Batch processing completed: {success_count} succeeded, {fail_count} failed")
    return output_json
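
# Example request against a locally running instance (hypothetical file names):
#
#   curl -X POST http://localhost:7860/ocr \
#        -F "files=@invoice1.pdf" -F "files=@receipt.jpg"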


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)