from fastapi import FastAPI, File, UploadFile
import pytesseract
import cv2
import os
import re
from PIL import Image
import unicodedata
from pdf2image import convert_from_bytes
from pypdf import PdfReader
import numpy as np
from typing import List
import io
import logging
import time
import asyncio
import psutil
import cachetools
import hashlib

app = FastAPI(title="Invoice OCR and Extraction API", version="1.0.0")

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set Tesseract path
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# Initialize LLM with fallback handling
llm = None
try:
    # Try to initialize a lightweight text-generation model via transformers;
    # generation length is bounded per call via max_new_tokens
    from transformers import pipeline

    llm = pipeline(
        "text-generation",
        model="microsoft/DialoGPT-small",
        device=-1,  # CPU only
        return_full_text=False,
    )
    logger.info("Lightweight text generation model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load text generation model: {str(e)}")
    logger.info("Will use rule-based extraction only")

# In-memory caches (1-hour TTL)
raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)


def log_memory_usage():
    """Return a human-readable string with the current RSS memory usage."""
    try:
        process = psutil.Process()
        mem_info = process.memory_info()
        return f"Memory usage: {mem_info.rss / 1024 / 1024:.2f} MB"
    except Exception:
        return "Memory usage: N/A"


def get_file_hash(file_bytes):
    """Generate an MD5 hash of the file content (cache key only, not security)."""
    return hashlib.md5(file_bytes).hexdigest()


def get_text_hash(raw_text):
    """Generate an MD5 hash of the raw text (cache key only, not security)."""
    return hashlib.md5(raw_text.encode('utf-8')).hexdigest()


async def process_image(img_bytes, filename, idx):
    """Run OCR on a single image (JPG/JPEG/PNG)."""
    start_time = time.time()
    logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
    try:
        img = Image.open(io.BytesIO(img_bytes))
        # Convert to RGB if needed
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
        # Preprocess for better OCR: Otsu binarization
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        img_pil = Image.fromarray(gray)
        custom_config = r'--oem 3 --psm 6 -l eng'
        # Run the blocking Tesseract call in a worker thread so it does not
        # stall the event loop
        page_text = await asyncio.to_thread(pytesseract.image_to_string, img_pil, config=custom_config)
        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for {filename} image {idx}: {str(e)}")
        return ""


async def process_pdf_page(img, page_idx):
    """Run OCR on a single rendered PDF page."""
    start_time = time.time()
    logger.info(f"Starting OCR for PDF page {page_idx}")
    try:
        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
        # Preprocess for better OCR: Otsu binarization
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        img_pil = Image.fromarray(gray)
        custom_config = r'--oem 3 --psm 6 -l eng'
        # Run the blocking Tesseract call in a worker thread
        page_text = await asyncio.to_thread(pytesseract.image_to_string, img_pil, config=custom_config)
        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}")
        return ""
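
# Note: both OCR paths share the same preprocessing (grayscale + Otsu
# binarization) and Tesseract config: "--oem 3" selects the default LSTM
# engine, "--psm 6" assumes a single uniform block of text, and "-l eng"
# loads the English model. For noticeably skewed scans, a deskew step could
# be inserted before thresholding; a minimal sketch (not wired in, and angle
# sign conventions vary by OpenCV version):
#
#   coords = np.column_stack(np.where(gray > 0))
#   angle = cv2.minAreaRect(coords)[-1]
#   (h, w) = gray.shape
#   M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
#   gray = cv2.warpAffine(gray, M, (w, h), flags=cv2.INTER_CUBIC)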

def rule_based_extraction(raw_text: str):
    """Rule-based fallback extraction used when the LLM is unavailable."""
    # Initialize the output structure; every field carries a value plus a
    # heuristic accuracy score
    structured_data = {
        "invoice": {
            "invoice_number": {"value": "", "accuracy": 0.0},
            "invoice_date": {"value": "", "accuracy": 0.0},
            "due_date": {"value": "", "accuracy": 0.0},
            "purchase_order_number": {"value": "", "accuracy": 0.0},
            "vendor": {
                "vendor_id": {"value": "", "accuracy": 0.0},
                "name": {"value": "", "accuracy": 0.0},
                "address": {
                    "line1": {"value": "", "accuracy": 0.0},
                    "line2": {"value": "", "accuracy": 0.0},
                    "city": {"value": "", "accuracy": 0.0},
                    "state": {"value": "", "accuracy": 0.0},
                    "postal_code": {"value": "", "accuracy": 0.0},
                    "country": {"value": "", "accuracy": 0.0}
                },
                "contact": {
                    "email": {"value": "", "accuracy": 0.0},
                    "phone": {"value": "", "accuracy": 0.0}
                },
                "tax_id": {"value": "", "accuracy": 0.0}
            },
            "buyer": {
                "buyer_id": {"value": "", "accuracy": 0.0},
                "name": {"value": "", "accuracy": 0.0},
                "address": {
                    "line1": {"value": "", "accuracy": 0.0},
                    "line2": {"value": "", "accuracy": 0.0},
                    "city": {"value": "", "accuracy": 0.0},
                    "state": {"value": "", "accuracy": 0.0},
                    "postal_code": {"value": "", "accuracy": 0.0},
                    "country": {"value": "", "accuracy": 0.0}
                },
                "contact": {
                    "email": {"value": "", "accuracy": 0.0},
                    "phone": {"value": "", "accuracy": 0.0}
                },
                "tax_id": {"value": "", "accuracy": 0.0}
            },
            "items": [{
                "item_id": {"value": "", "accuracy": 0.0},
                "description": {"value": "", "accuracy": 0.0},
                "quantity": {"value": 0, "accuracy": 0.0},
                "unit_of_measure": {"value": "", "accuracy": 0.0},
                "unit_price": {"value": 0, "accuracy": 0.0},
                "total_price": {"value": 0, "accuracy": 0.0},
                "tax_rate": {"value": 0, "accuracy": 0.0},
                "tax_amount": {"value": 0, "accuracy": 0.0},
                "discount": {"value": 0, "accuracy": 0.0},
                "net_amount": {"value": 0, "accuracy": 0.0}
            }],
            "sub_total": {"value": 0, "accuracy": 0.0},
            "tax_total": {"value": 0, "accuracy": 0.0},
            "discount_total": {"value": 0, "accuracy": 0.0},
            "total_amount": {"value": 0, "accuracy": 0.0},
            "currency": {"value": "USD", "accuracy": 0.5}
        }
    }

    # Simple pattern matching
    try:
        # Invoice number
        inv_pattern = r'(?:invoice|inv)(?:\s*#|\s*no\.?|\s*number)?\s*:?\s*([A-Z0-9\-/]+)'
        inv_match = re.search(inv_pattern, raw_text, re.IGNORECASE)
        if inv_match:
            structured_data["invoice"]["invoice_number"]["value"] = inv_match.group(1)
            structured_data["invoice"]["invoice_number"]["accuracy"] = 0.7

        # Date patterns (numeric forms only, e.g. 03/15/2024 or 2024-03-15)
        date_pattern = r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})'
        dates = re.findall(date_pattern, raw_text)
        if dates:
            structured_data["invoice"]["invoice_date"]["value"] = dates[0]
            structured_data["invoice"]["invoice_date"]["accuracy"] = 0.6

        # Total amount
        amount_pattern = r'(?:total|amount|sum)\s*:?\s*\$?(\d+\.?\d*)'
        amount_match = re.search(amount_pattern, raw_text, re.IGNORECASE)
        if amount_match:
            structured_data["invoice"]["total_amount"]["value"] = float(amount_match.group(1))
            structured_data["invoice"]["total_amount"]["accuracy"] = 0.6

        # Email (note: a literal "|" inside the original character class
        # [A-Z|a-z] would have matched pipe characters; fixed to [A-Za-z])
        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        email_match = re.search(email_pattern, raw_text)
        if email_match:
            structured_data["invoice"]["vendor"]["contact"]["email"]["value"] = email_match.group()
            structured_data["invoice"]["vendor"]["contact"]["email"]["accuracy"] = 0.8

        # Phone (North American formats)
        phone_pattern = r'(?:\+?1[-.\s]?)?\(?([0-9]{3})\)?[-.\s]?([0-9]{3})[-.\s]?([0-9]{4})'
        phone_match = re.search(phone_pattern, raw_text)
        if phone_match:
            structured_data["invoice"]["vendor"]["contact"]["phone"]["value"] = phone_match.group()
            structured_data["invoice"]["vendor"]["contact"]["phone"]["accuracy"] = 0.7
    except Exception as e:
        logger.error(f"Rule-based extraction error: {str(e)}")

    return structured_data
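
# Illustrative example of the rule-based path (hypothetical input; expected
# values follow from the patterns above):
#
#   sample = "Invoice No: INV-1001\nDate: 03/15/2024\nTotal: $249.99\nbilling@acme.com"
#   result = rule_based_extraction(sample)
#   result["invoice"]["invoice_number"]  # {"value": "INV-1001", "accuracy": 0.7}
#   result["invoice"]["invoice_date"]    # {"value": "03/15/2024", "accuracy": 0.6}
#   result["invoice"]["total_amount"]    # {"value": 249.99, "accuracy": 0.6}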

async def process_with_model(filename: str, raw_text: str):
    """Process raw text with the available model, or fall back to rule-based extraction."""
    start_time = time.time()
    logger.info(f"Starting text processing for {filename}")

    # Check structured data cache
    text_hash = get_text_hash(raw_text)
    if text_hash in structured_data_cache:
        logger.info(f"Structured data cache hit for {filename}")
        return structured_data_cache[text_hash]

    # Truncate very long text to keep processing bounded
    if len(raw_text) > 5000:
        raw_text = raw_text[:5000]
        logger.info(f"Truncated raw text for {filename} to 5000 characters")

    try:
        if llm is not None:
            # Use the transformers pipeline if available
            prompt = f"""Extract key information from this invoice text and format as JSON:

Invoice Text: {raw_text[:1000]}

Please extract: invoice number, date, vendor name, total amount, email, phone number."""
            try:
                # max_new_tokens (rather than max_length) so the prompt does
                # not consume the generation budget; do_sample enables the
                # temperature setting
                response = llm(prompt, max_new_tokens=200, num_return_sequences=1,
                               do_sample=True, temperature=0.7)
                response_text = response[0]['generated_text'] if response else ""

                # Simplified approach for a general-purpose model: start from
                # the rule-based result and only nudge confidences based on
                # the model output
                structured_data = rule_based_extraction(raw_text)

                if "invoice" in response_text.lower():
                    # The model produced invoice-related text; keep the
                    # rule-based values but bump the confidence of fields
                    # that already matched
                    for key in structured_data["invoice"]:
                        field = structured_data["invoice"][key]
                        if isinstance(field, dict) and "accuracy" in field and field["accuracy"] > 0:
                            field["accuracy"] = min(0.8, field["accuracy"] + 0.1)
            except Exception as model_error:
                logger.warning(f"Model processing failed, using rule-based: {str(model_error)}")
                structured_data = rule_based_extraction(raw_text)
        else:
            # Use rule-based extraction
            structured_data = rule_based_extraction(raw_text)

        # Cache the result
        structured_data_cache[text_hash] = structured_data
        logger.info(f"Text processing for {filename} completed in {time.time() - start_time:.2f} seconds")
        return structured_data
    except Exception as e:
        logger.error(f"Text processing failed for {filename}: {str(e)}")
        return rule_based_extraction(raw_text)


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "message": "Invoice OCR and Extraction API",
        "status": "active",
        "llm_available": llm is not None
    }
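
# Quick smoke test for the health endpoint (a sketch; FastAPI's TestClient
# requires the optional httpx dependency):
#
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   assert client.get("/").json()["status"] == "active"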

@app.post("/ocr")
async def extract_and_structure(files: List[UploadFile] = File(...)):
    """Main endpoint: OCR each uploaded file and extract structured invoice data."""
    output_json = {
        "success": True,
        "message": "",
        "data": []
    }
    success_count = 0
    fail_count = 0
    logger.info(f"Starting processing for {len(files)} files")

    for file in files:
        total_start_time = time.time()
        logger.info(f"Processing file: {file.filename}")

        # Validate file format
        valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
        file_ext = os.path.splitext(file.filename.lower())[1] if file.filename else '.unknown'
        if file_ext not in valid_extensions:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": f"Unsupported file format: {file_ext}"},
                "error": f"Unsupported file format: {file_ext}"
            })
            logger.error(f"Unsupported file format for {file.filename}: {file_ext}")
            continue

        # Read file into memory
        try:
            file_bytes = await file.read()
            file_hash = get_file_hash(file_bytes)
            logger.info(f"Read file {file.filename}, size: {len(file_bytes)/1024:.2f} KB")
        except Exception as e:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": f"Failed to read file: {str(e)}"},
                "error": f"Failed to read file: {str(e)}"
            })
            logger.error(f"Failed to read file {file.filename}: {str(e)}")
            continue

        # Check raw text cache
        raw_text = ""
        if file_hash in raw_text_cache:
            raw_text = raw_text_cache[file_hash]
            logger.info(f"Raw text cache hit for {file.filename}")
        else:
            if file_ext == '.pdf':
                # Try extracting embedded text first
                try:
                    reader = PdfReader(io.BytesIO(file_bytes))
                    for page in reader.pages:
                        text = page.extract_text()
                        if text:
                            raw_text += text + "\n"
                    logger.info(f"Embedded text extraction for {file.filename}, text length: {len(raw_text)}")
                except Exception as e:
                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}")

                # If no embedded text, fall back to OCR
                if not raw_text.strip():
                    try:
                        # Limit conversion to the first three pages
                        images = convert_from_bytes(file_bytes, dpi=150, first_page=1, last_page=3)
                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages")
                        page_texts = []
                        for i, img in enumerate(images):
                            page_text = await process_pdf_page(img, i)
                            page_texts.append(page_text)
                        raw_text = "".join(page_texts)
                        logger.info(f"Total OCR for {file.filename}, text length: {len(raw_text)}")
                    except Exception as e:
                        fail_count += 1
                        output_json["data"].append({
                            "filename": file.filename,
                            "structured_data": {"error": f"OCR failed: {str(e)}"},
                            "error": f"OCR failed: {str(e)}"
                        })
                        logger.error(f"OCR failed for {file.filename}: {str(e)}")
                        continue
            else:
                # JPG/JPEG/PNG
                try:
                    raw_text = await process_image(file_bytes, file.filename, 0)
                    logger.info(f"Image OCR for {file.filename}, text length: {len(raw_text)}")
                except Exception as e:
                    fail_count += 1
                    output_json["data"].append({
                        "filename": file.filename,
                        "structured_data": {"error": f"Image OCR failed: {str(e)}"},
                        "error": f"Image OCR failed: {str(e)}"
                    })
                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}")
                    continue

            # Normalize text and store it in the cache
            try:
                raw_text = unicodedata.normalize('NFKC', raw_text)
                raw_text = raw_text.encode('utf-8', errors='ignore').decode('utf-8')
                raw_text_cache[file_hash] = raw_text
                logger.info(f"Text normalization for {file.filename} completed")
            except Exception as e:
                logger.warning(f"Text normalization failed for {file.filename}: {str(e)}")

        # Process with model or rule-based extraction
        if raw_text.strip():
            structured_data = await process_with_model(file.filename, raw_text)
            success_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": structured_data,
                # Include a snippet of the raw text for reference
                "raw_text": raw_text[:500] + "..." if len(raw_text) > 500 else raw_text,
                "error": ""
            })
        else:
            fail_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": {"error": "No text extracted from file"},
                "error": "No text extracted from file"
            })

        logger.info(f"Total processing for {file.filename} completed in {time.time() - total_start_time:.2f} seconds")

    output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
    if fail_count > 0 and success_count == 0:
        output_json["success"] = False
    logger.info(f"Batch processing completed: {success_count} succeeded, {fail_count} failed")
    return output_json


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
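
# Example request against a running instance (illustrative; file names are
# placeholders):
#
#   curl -X POST http://localhost:7860/ocr \
#        -F "files=@invoice.pdf" \
#        -F "files=@receipt.png"
#
# The response follows output_json above: top-level "success" and "message"
# fields plus one entry per uploaded file under "data".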