import asyncio
import hashlib
import io
import json
import logging
import os
import time
import unicodedata
from typing import List

import cachetools
import cv2
import numpy as np
import psutil
import pytesseract
from fastapi import FastAPI, File, HTTPException, UploadFile
from pdf2image import convert_from_bytes
from PIL import Image
from pypdf import PdfReader
from vllm import LLM, SamplingParams
|
|
|
app = FastAPI()

# Configure root logging once at import time.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Path to the Tesseract binary. Defaults to /usr/bin/tesseract; set the
# TESSERACT_CMD environment variable when it is installed elsewhere.
pytesseract.pytesseract.tesseract_cmd = os.environ.get("TESSERACT_CMD", "/usr/bin/tesseract")
|
|
|
|
|
# Load the fine-tuned BitNet model once at import time; every request
# handled by this module uses this single `llm` instance.
try:
    llm = LLM(
        model="username/bitnet-finetuned-invoice",
        device="cpu",
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        # Context window (prompt + generated tokens) per request.
        max_model_len=2048,
    )
except Exception as e:
    logger.exception(f"Failed to load BitNet model: {str(e)}")
    # HTTPException only has meaning inside a request handler; at import time
    # it would just abort startup with a misleading exception type.
    raise RuntimeError("BitNet model initialization failed") from e
|
|
|
|
|
# Both caches keep at most 100 entries; each entry expires 3600 s after insertion.
_CACHE_MAX_ENTRIES = 100
_CACHE_TTL_SECONDS = 60 * 60

# Hash of uploaded file bytes -> extracted (normalized) raw text.
raw_text_cache = cachetools.TTLCache(maxsize=_CACHE_MAX_ENTRIES, ttl=_CACHE_TTL_SECONDS)
# Hash of raw text -> structured data parsed from the model output.
structured_data_cache = cachetools.TTLCache(maxsize=_CACHE_MAX_ENTRIES, ttl=_CACHE_TTL_SECONDS)
|
|
|
def log_memory_usage():
    """Return a short description of this process's resident memory size.

    Despite the name, nothing is logged here: the returned string (for
    example ``"Memory usage: 123.45 MB"``) is meant to be embedded in log
    messages by the caller.
    """
    rss_bytes = psutil.Process().memory_info().rss
    rss_megabytes = rss_bytes / (1024 * 1024)
    return f"Memory usage: {rss_megabytes:.2f} MB"
|
|
|
def get_file_hash(file_bytes):
    """Return a hex SHA-256 digest of ``file_bytes`` for use as a cache key.

    SHA-256 is used rather than MD5 because this key indexes a cache shared
    by every client: with practical MD5 collisions, one upload could be
    served the cached extraction result of a different file.

    Args:
        file_bytes: Raw upload contents (any bytes-like object).

    Returns:
        A 64-character lowercase hexadecimal string.
    """
    return hashlib.sha256(file_bytes).hexdigest()
|
|
|
def get_text_hash(raw_text):
    """Return a hex SHA-256 digest of ``raw_text`` for use as a cache key.

    The text is UTF-8 encoded with ``surrogatepass`` so that lone surrogate
    code points (rejected by strict UTF-8 encoding) still hash instead of
    raising ``UnicodeEncodeError``. SHA-256 replaces MD5 because the key
    indexes a cache shared by all clients, where collisions could serve one
    request another's result.

    Args:
        raw_text: The text to fingerprint.

    Returns:
        A 64-character lowercase hexadecimal string.
    """
    return hashlib.sha256(raw_text.encode('utf-8', errors='surrogatepass')).hexdigest()
|
|
|
async def process_image(img_bytes, filename, idx):
    """OCR a single uploaded image (JPG/JPEG/PNG) and return its text.

    The image is converted to grayscale and passed to Tesseract with the
    English and Arabic models (``-l eng+ara``). Decoding and OCR are blocking,
    so they run in the event loop's default executor to keep other requests
    responsive.

    Args:
        img_bytes: Raw bytes of the uploaded image file.
        filename: Original file name, used only for logging.
        idx: Image index, used only for logging.

    Returns:
        The recognized text followed by a newline, or ``""`` if decoding or
        OCR failed (the error is logged, not raised).
    """
    start_time = time.time()
    logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")

    def _ocr_blocking():
        # Decode, grayscale and OCR the image; runs on a worker thread.
        with Image.open(io.BytesIO(img_bytes)) as img:
            # OpenCV's RGB conversions need 3- or 4-channel input, so grayscale
            # (L), palette (P), CMYK and similar modes must be normalized first.
            rgb = np.array(img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
        custom_config = r'--oem 1 --psm 6 -l eng+ara'
        return pytesseract.image_to_string(img_pil, config=custom_config)

    try:
        page_text = await asyncio.get_running_loop().run_in_executor(None, _ocr_blocking)
        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for {filename} image {idx}: {str(e)}, {log_memory_usage()}")
        return ""
|
|
|
async def process_pdf_page(img, page_idx):
    """OCR one rendered PDF page image and return its text.

    The page is converted to grayscale and passed to Tesseract with the
    English and Arabic models (``-l eng+ara``). OCR is blocking, so it runs in
    the event loop's default executor to keep other requests responsive.

    Args:
        img: Rendered page image; must be convertible with ``np.array`` to an
            RGB array (assumes a PIL RGB image such as pdf2image produces by
            default — confirm for other callers).
        page_idx: Page index, used only for logging.

    Returns:
        The recognized text followed by a newline, or ``""`` if OCR failed
        (the error is logged, not raised).
    """
    start_time = time.time()
    logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")

    def _ocr_blocking():
        # Grayscale and OCR the page; runs on a worker thread.
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
        custom_config = r'--oem 1 --psm 6 -l eng+ara'
        return pytesseract.image_to_string(img_pil, config=custom_config)

    try:
        page_text = await asyncio.get_running_loop().run_in_executor(None, _ocr_blocking)
        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
        return ""
|
|
|
async def process_with_bitnet(filename: str, raw_text: str):
    """Extract structured invoice fields from OCR text with the BitNet model.

    Results are memoized in ``structured_data_cache`` under a hash of the
    full (untruncated) ``raw_text``; only successfully parsed results are
    cached.

    Args:
        filename: Source file name, used only for logging.
        raw_text: Text extracted from the invoice. Only the first 10000
            characters are sent to the model.

    Returns:
        The JSON value parsed from the model output, or
        ``{"error": "BitNet processing failed: ..."}`` if generation or
        parsing failed (errors are logged, not raised).
    """
    start_time = time.time()
    logger.info(f"Starting BitNet processing for {filename}, {log_memory_usage()}")

    text_hash = get_text_hash(raw_text)
    if text_hash in structured_data_cache:
        logger.info(f"Structured data cache hit for {filename}, {log_memory_usage()}")
        return structured_data_cache[text_hash]

    # NOTE(review): the model is loaded with max_model_len=2048 tokens and the
    # prompt template below is already long, so 10000 characters of text
    # probably does not fit the context window. Confirm this budget against
    # the tokenizer.
    if len(raw_text) > 10000:
        raw_text = raw_text[:10000]
        logger.info(f"Truncated raw text for {filename} to 10000 characters, {log_memory_usage()}")

    try:
        prompt = f"""You are an intelligent invoice data extractor. Given raw text from an invoice (in English or other languages),
extract key business fields into the specified JSON format. Return each field with an estimated accuracy score between 0 and 1.

- Accuracy reflects confidence in the correctness of each field.
- Handle synonyms (e.g., 'total' = 'net', 'tax' = 'GST'/'TDS').
- Detect currency from symbols ($, ₹, €) or keywords (USD, INR, EUR); default to USD if unclear.
- The 'items' list may have multiple entries, each with detailed attributes.
- If a field is missing, return an empty value (`""` or `0`) and set `accuracy` to `0.0`.
- Convert any date to YYYY-MM-DD.

Raw text:
{raw_text}

Output JSON:
{{
  "invoice": {{
    "invoice_number": {{"value": "", "accuracy": 0.0}},
    "invoice_date": {{"value": "", "accuracy": 0.0}},
    "due_date": {{"value": "", "accuracy": 0.0}},
    "purchase_order_number": {{"value": "", "accuracy": 0.0}},
    "vendor": {{
      "vendor_id": {{"value": "", "accuracy": 0.0}},
      "name": {{"value": "", "accuracy": 0.0}},
      "address": {{
        "line1": {{"value": "", "accuracy": 0.0}},
        "line2": {{"value": "", "accuracy": 0.0}},
        "city": {{"value": "", "accuracy": 0.0}},
        "state": {{"value": "", "accuracy": 0.0}},
        "postal_code": {{"value": "", "accuracy": 0.0}},
        "country": {{"value": "", "accuracy": 0.0}}
      }},
      "contact": {{
        "email": {{"value": "", "accuracy": 0.0}},
        "phone": {{"value": "", "accuracy": 0.0}}
      }},
      "tax_id": {{"value": "", "accuracy": 0.0}}
    }},
    "buyer": {{
      "buyer_id": {{"value": "", "accuracy": 0.0}},
      "name": {{"value": "", "accuracy": 0.0}},
      "address": {{
        "line1": {{"value": "", "accuracy": 0.0}},
        "line2": {{"value": "", "accuracy": 0.0}},
        "city": {{"value": "", "accuracy": 0.0}},
        "state": {{"value": "", "accuracy": 0.0}},
        "postal_code": {{"value": "", "accuracy": 0.0}},
        "country": {{"value": "", "accuracy": 0.0}}
      }},
      "contact": {{
        "email": {{"value": "", "accuracy": 0.0}},
        "phone": {{"value": "", "accuracy": 0.0}}
      }},
      "tax_id": {{"value": "", "accuracy": 0.0}}
    }},
    "items": [
      {{
        "item_id": {{"value": "", "accuracy": 0.0}},
        "description": {{"value": "", "accuracy": 0.0}},
        "quantity": {{"value": 0, "accuracy": 0.0}},
        "unit_of_measure": {{"value": "", "accuracy": 0.0}},
        "unit_price": {{"value": 0, "accuracy": 0.0}},
        "total_price": {{"value": 0, "accuracy": 0.0}},
        "tax_rate": {{"value": 0, "accuracy": 0.0}},
        "tax_amount": {{"value": 0, "accuracy": 0.0}},
        "discount": {{"value": 0, "accuracy": 0.0}},
        "net_amount": {{"value": 0, "accuracy": 0.0}}
      }}
    ],
    "sub_total": {{"value": 0, "accuracy": 0.0}},
    "tax_total": {{"value": 0, "accuracy": 0.0}},
    "discount_total": {{"value": 0, "accuracy": 0.0}},
    "total_amount": {{"value": 0, "accuracy": 0.0}},
    "currency": {{"value": "", "accuracy": 0.0}}
  }}
}}
"""
        # Without explicit SamplingParams vLLM stops after 16 generated tokens,
        # which truncates the JSON every time. Greedy decoding (temperature 0)
        # also keeps the cached result reproducible for identical text.
        sampling_params = SamplingParams(temperature=0.0, max_tokens=1024)
        # NOTE(review): llm.generate is synchronous and blocks the event loop
        # while it runs; moving it to a thread would also require serializing
        # access to the shared `llm` instance.
        outputs = llm.generate(prompts=[prompt], sampling_params=sampling_params)
        json_str = outputs[0].outputs[0].text

        # Take the outermost {...} span so any text around it is ignored.
        json_start = json_str.find("{")
        json_end = json_str.rfind("}") + 1
        if json_start == -1 or json_end <= json_start:
            raise ValueError("model output does not contain a JSON object")
        structured_data = json.loads(json_str[json_start:json_end])

        structured_data_cache[text_hash] = structured_data
        logger.info(f"BitNet processing for {filename}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return structured_data
    except Exception as e:
        logger.error(f"BitNet processing failed for {filename}: {str(e)}, {log_memory_usage()}")
        return {"error": f"BitNet processing failed: {str(e)}"}
|
|
|
# Upload extensions accepted by /ocr (compared against the lower-cased name).
_VALID_EXTENSIONS = frozenset({'.pdf', '.jpg', '.jpeg', '.png'})


def _failure_entry(filename, message):
    """Build the per-file response entry reported for a failed upload."""
    return {
        "filename": filename,
        "structured_data": {"error": message},
        "error": message,
    }


@app.post("/ocr")
async def extract_and_structure(files: List[UploadFile] = File(...)):
    """Extract text from uploaded invoices and structure it with BitNet.

    Each file is handled independently:

    - PDFs use their embedded text layer when it contains any non-whitespace
      text, and are otherwise rasterized and OCR'd page by page.
    - Images are OCR'd directly.

    The extracted text is NFKC-normalized, cached by file hash and passed to
    ``process_with_bitnet``. Per-file failures (unsupported type, unreadable
    upload, no extractable text, model failure) are reported in the response
    instead of being raised.

    Returns:
        ``{"success": bool, "message": str, "data": [...]}``. Each ``data``
        entry has ``filename``, ``structured_data`` and ``error`` (``""`` on
        success). ``success`` is False only when at least one file failed and
        none succeeded.
    """
    output_json = {
        "success": True,
        "message": "",
        "data": [],
    }
    success_count = 0
    fail_count = 0

    logger.info(f"Starting processing for {len(files)} files, {log_memory_usage()}")

    for file in files:
        total_start_time = time.time()
        logger.info(f"Processing file: {file.filename}, {log_memory_usage()}")

        # Check the extension before reading the upload into memory; guard
        # against a missing filename.
        file_ext = os.path.splitext((file.filename or "").lower())[1]
        if file_ext not in _VALID_EXTENSIONS:
            fail_count += 1
            output_json["data"].append(_failure_entry(file.filename, f"Unsupported file format: {file_ext}"))
            logger.error(f"Unsupported file format for {file.filename}: {file_ext}")
            continue

        try:
            file_start_time = time.time()
            file_bytes = await file.read()
            file_hash = get_file_hash(file_bytes)
            logger.info(f"Read file {file.filename}, took {time.time() - file_start_time:.2f} seconds, size: {len(file_bytes)/1024:.2f} KB, {log_memory_usage()}")
        except Exception as e:
            fail_count += 1
            output_json["data"].append(_failure_entry(file.filename, f"Failed to read file: {str(e)}"))
            logger.error(f"Failed to read file {file.filename}: {str(e)}, {log_memory_usage()}")
            continue

        if file_hash in raw_text_cache:
            raw_text = raw_text_cache[file_hash]
            logger.info(f"Raw text cache hit for {file.filename}, {log_memory_usage()}")
        else:
            raw_text = ""
            if file_ext == '.pdf':
                # Prefer the embedded text layer; fall back to OCR only when it
                # yields no text at all.
                # NOTE(review): if only some pages have embedded text, the
                # scanned pages are skipped entirely.
                try:
                    extract_start_time = time.time()
                    reader = PdfReader(io.BytesIO(file_bytes))
                    for page in reader.pages:
                        text = page.extract_text()
                        if text:
                            raw_text += text + "\n"
                    logger.info(f"Embedded text extraction for {file.filename}, took {time.time() - extract_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                except Exception as e:
                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}, {log_memory_usage()}")

                if not raw_text.strip():
                    try:
                        convert_start_time = time.time()
                        # convert_from_bytes blocks until rendering finishes;
                        # run it off the event loop.
                        images = await asyncio.get_running_loop().run_in_executor(
                            None, lambda: convert_from_bytes(file_bytes, dpi=100)
                        )
                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages, took {time.time() - convert_start_time:.2f} seconds, {log_memory_usage()}")

                        ocr_start_time = time.time()
                        page_texts = []
                        for i, img in enumerate(images):
                            page_texts.append(await process_pdf_page(img, i))
                        raw_text = "".join(page_texts)
                        logger.info(f"Total OCR for {file.filename}, took {time.time() - ocr_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                    except Exception as e:
                        fail_count += 1
                        output_json["data"].append(_failure_entry(file.filename, f"OCR failed: {str(e)}"))
                        logger.error(f"OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                        continue
            else:
                try:
                    ocr_start_time = time.time()
                    raw_text = await process_image(file_bytes, file.filename, 0)
                    logger.info(f"Image OCR for {file.filename}, took {time.time() - ocr_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                except Exception as e:
                    fail_count += 1
                    output_json["data"].append(_failure_entry(file.filename, f"Image OCR failed: {str(e)}"))
                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                    continue

            try:
                normalize_start_time = time.time()
                raw_text = unicodedata.normalize('NFKC', raw_text)
                # Replace lone surrogates (not encodable as UTF-8) with '?' so
                # later encoding/hashing of the text cannot fail.
                raw_text = raw_text.encode('utf-8', errors='replace').decode('utf-8')
                logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
            except Exception as e:
                logger.warning(f"Text normalization failed for {file.filename}: {str(e)}, {log_memory_usage()}")

            # Cache only usable text so a transient OCR failure is not replayed
            # for the cache's whole TTL.
            if raw_text.strip():
                raw_text_cache[file_hash] = raw_text

        # process_image / process_pdf_page log and swallow OCR errors and
        # return "", so empty text means extraction failed (or the file has no
        # text); either way there is nothing for the model to structure.
        if not raw_text.strip():
            fail_count += 1
            output_json["data"].append(_failure_entry(file.filename, "No text could be extracted from the file"))
            logger.error(f"No text extracted from {file.filename}, {log_memory_usage()}")
            continue

        structured_data = await process_with_bitnet(file.filename, raw_text)
        if isinstance(structured_data, dict) and "error" in structured_data:
            # process_with_bitnet reports failures as {"error": ...} rather
            # than raising; count them as failures, not successes.
            fail_count += 1
            error_message = str(structured_data["error"])
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": structured_data,
                "error": error_message,
            })
            logger.error(f"Structuring failed for {file.filename}: {error_message}, {log_memory_usage()}")
        else:
            success_count += 1
            output_json["data"].append({
                "filename": file.filename,
                "structured_data": structured_data,
                "error": "",
            })

        logger.info(f"Total processing for {file.filename}, took {time.time() - total_start_time:.2f} seconds, {log_memory_usage()}")

    output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
    if fail_count > 0 and success_count == 0:
        output_json["success"] = False

    logger.info(f"Completed processing for {len(files)} files, {success_count} succeeded, {fail_count} failed, {log_memory_usage()}")
    return output_json