import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification, TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import numpy as np
import logging
import pytesseract
import re
from openai import OpenAI
import os
from pdf2image import convert_from_bytes
from dotenv import load_dotenv
from chatbot_utils import ask_receipt_chatbot
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import boto3
from decimal import Decimal
import uuid
from paddleocr import PaddleOCR
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Initialize OpenAI client for Perplexity
api_key = os.getenv('PERPLEXITY_API_KEY')
if not api_key:
    st.error("""
    ⚠️ Perplexity API key not found! Please add your API key to the Space's secrets:
    1. Go to Space Settings
    2. Click on 'Repository secrets'
    3. Add a new secret with name 'PERPLEXITY_API_KEY'
    4. Add your Perplexity API key as the value
    """)
    st.stop()

client = OpenAI(
    api_key=api_key,
    base_url="https://api.perplexity.ai"
)

# Initialize LayoutLM model
def load_model():
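    """Load the LayoutLMv3 processor and token-classification model.

    Not called by the main OCR + LLM pipeline below; kept for layout-aware
    extraction experiments. Wrapping this in @st.cache_resource would avoid
    reloading the weights on every Streamlit rerun.
    """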
    model_name = "microsoft/layoutlmv3-base"
    processor = LayoutLMv3Processor.from_pretrained(model_name)
    model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
    return processor, model

def extract_json_from_llm_output(llm_result):
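    """Pull the first JSON object out of an LLM response string.

    Prefers a fenced ```json ... ``` (or plain ```) code block; falls back
    to the first {...} span. Returns the JSON substring, or None if absent.
    """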
    # Try to extract JSON from a code block first (```json ... ``` or ``` ... ```)
    code_block_match = re.search(r"```(?:json)?\s*({[\s\S]*?})\s*```", llm_result, re.IGNORECASE)
    if code_block_match:
        return code_block_match.group(1)
    # Fallback: extract first {...} block
    match = re.search(r'\{[\s\S]*\}', llm_result)
    if match:
        return match.group(0)
    return None

def extract_fields(image_path):
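    """Regex-based extraction over raw Tesseract OCR text.

    The patterns are heuristics tuned to one receipt layout (a title before
    the name, a "Total Paid" line, etc.) and will miss fields on other
    formats; the LLM extraction is the primary path in main().
    """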
    text = pytesseract.image_to_string(Image.open(image_path))
    st.subheader("Raw OCR Output")
    st.code(text)

    # Regex patterns for fields (layout-specific heuristics; see docstring)
    patterns = {
        "name": r"(?:Mr|Mrs|Ms)\.?\s+\w+\s+\w+",
        "date": r"Date[:\s]+([\d/]+)",
        "product": r"\d+\s+\w+.*Style\s+\d+",
        "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
        "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
    }

    results = {}
    for field, pattern in patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            results[field] = match.group(1) if match.groups() else match.group(0)
        else:
            results[field] = None

    # Extract all products
    results["products"] = extract_products(text)
    return results

def extract_products(text):
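    """Scan OCR text for product lines like "2 PISTACHIO 14.49".

    An optional leading integer is treated as the quantity (default 1),
    the uppercase run as the name, and the trailing decimal as the price.
    """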
    # Pattern to match product lines with quantity, name, and price
    # Example: "2 PISTACHIO 14.49" or "1076903 PISTACHIO 14.49"
    product_pattern = r"(?:(\d+)\s+)?([A-Z0-9 ]+)\s+(\d+\.\d{2})"
    matches = re.findall(product_pattern, text)

    products = []
    for match in matches:
        quantity, name, price = match
        product = {
            "name": name.strip(),
            "price": float(price),
            "quantity": int(quantity) if quantity else 1,
            "total": float(price) * (int(quantity) if quantity else 1)
        }
        products.append(product)
    return products

def extract_with_perplexity_llm(ocr_text):
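    """Ask the Perplexity chat API (sonar-pro) to structure the OCR text.

    Returns the raw completion string, which typically wraps the JSON in
    prose or a code fence; parse it with extract_json_from_llm_output().
    """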
    prompt = f"""
You are an expert at extracting structured data from receipts.
From the following OCR text, extract these fields and return them as a JSON object with exactly these keys:
- name (customer name)
- date (date of purchase)
- amount_paid (total amount paid)
- receipt_no (receipt number)
- products (a list of all products, each with name, price, and quantity if available)

Example output:
{{
    "name": "Mrs. Genevieve Lopez",
    "date": "12/13/2024",
    "amount_paid": 29.69,
    "receipt_no": "042085",
    "products": [
        {{"name": "Orange Juice", "price": 2.15, "quantity": 1}},
        {{"name": "Apples", "price": 3.50, "quantity": 1}}
    ]
}}

Text:
\"\"\"{ocr_text}\"\"\"
"""
    messages = [
        {
            "role": "system",
            "content": "You are an AI assistant that extracts structured information from text."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    response = client.chat.completions.create(
        model="sonar-pro",
        messages=messages
    )
    return response.choices[0].message.content

def convert_floats_to_decimal(obj):
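    """Recursively convert floats to Decimal (DynamoDB rejects Python floats)."""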
    if isinstance(obj, float):
        return Decimal(str(obj))
    elif isinstance(obj, dict):
        return {k: convert_floats_to_decimal(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_floats_to_decimal(i) for i in obj]
    else:
        return obj

def save_to_dynamodb(data, table_name="Receipts"):
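    """Persist an extracted receipt to DynamoDB.

    Fills in amount_paid from the product list when missing, converts floats
    to Decimal, and generates a UUID receipt_no if none was extracted.
    Assumes AWS credentials and region are configured in the environment.
    """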
    dynamodb = boto3.resource('dynamodb')
    table = dynamodb.Table(table_name)

    # Calculate total amount if not provided. LLM-extracted products may lack
    # a "total" key, so fall back to price * quantity.
    if data.get("products") and not data.get("amount_paid"):
        data["amount_paid"] = sum(
            p.get("total", p.get("price", 0) * p.get("quantity", 1))
            for p in data["products"]
        )

    # Convert all float values to Decimal for DynamoDB
    data = convert_floats_to_decimal(data)

    # Generate receipt number if not present
    if not data.get("receipt_no"):
        data["receipt_no"] = str(uuid.uuid4())

    table.put_item(Item=data)

def merge_extractions(regex_fields, llm_fields):
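    """Merge LLM and regex extractions, preferring the LLM value per field."""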
    merged = {}
    for key in ["name", "date", "amount_paid", "receipt_no"]:
        merged[key] = llm_fields.get(key) or regex_fields.get(key)
    merged["products"] = llm_fields.get("products") or regex_fields.get("products")
    return merged

def extract_handwritten_text(image):
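    """OCR handwritten text with TrOCR (microsoft/trocr-base-handwritten).

    Reloads the model on every call; cache it (e.g. @st.cache_resource) if
    used repeatedly. TrOCR is trained on single text lines, so full-page
    receipts should be segmented into lines first.
    """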
    processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
    model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

def get_paddle_ocr():
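    """Construct a PaddleOCR reader (English, with angle classification)."""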
    return PaddleOCR(use_angle_cls=True, lang='en', show_log=False)

def extract_handwritten_text_paddle(image):
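    """Run PaddleOCR over a PIL image and return the recognized lines.

    The image is first written to a temporary JPEG, since PaddleOCR's .ocr()
    takes a file path (or ndarray) rather than a PIL image.
    """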
    ocr = get_paddle_ocr()
    # Save PIL image to a temporary file
    temp_path = 'temp_uploaded_image_paddle.jpg'
    image.save(temp_path)
    result = ocr.ocr(temp_path, cls=True)
    # Guard against empty results: result[0] is None when no text is detected
    if not result or not result[0]:
        return ""
    lines = [line[1][0] for line in result[0]]
    return '\n'.join(lines)

def main():
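    """Streamlit entry point: upload, OCR, LLM extraction, and a training demo."""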
    st.set_page_config(
        page_title="FormIQ - Intelligent Receipt Parser",
        page_icon="📄",
        layout="wide"
    )

    st.title("FormIQ: Intelligent Receipt Parser")
    st.markdown("""
    Upload your documents to extract and validate information using advanced AI models.
    """)
    # Sidebar
    with st.sidebar:
        st.header("Settings")
        document_type = st.selectbox(
            "Document Type",
            options=["invoice", "receipt", "form"],
            index=0
        )
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05
        )

        st.markdown("---")
        st.markdown("### About")
        st.markdown("""
        FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
        """)

        # Receipt Chatbot in sidebar
        st.markdown("---")
        st.header("💬 Receipt Chatbot")
        st.write("Ask questions about your receipts stored in DynamoDB.")
        user_question = st.text_input("Enter your question:", "What is the total amount paid?")
        if st.button("Ask Chatbot", key="sidebar_chatbot"):
            with st.spinner("Getting answer from Perplexity LLM..."):
                answer = ask_receipt_chatbot(user_question)
            st.success(answer)
    # Main content
    uploaded_file = st.file_uploader(
        "Upload Document",
        type=["png", "jpg", "jpeg", "pdf"],
        help="Upload a document image to process"
    )

    if uploaded_file is not None:
        # PDFs can't be opened by PIL directly; render the first page instead
        if uploaded_file.name.lower().endswith(".pdf"):
            image = convert_from_bytes(uploaded_file.read())[0].convert("RGB")
        else:
            image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Document", width=600)

        handwritten_text = None
        # Option to extract handwritten text with PaddleOCR
        if st.checkbox("Extract handwritten text (PaddleOCR)?"):
            with st.spinner("Extracting handwritten text with PaddleOCR..."):
                handwritten_text = extract_handwritten_text_paddle(image)
            st.subheader("Handwritten Text Extracted (PaddleOCR)")
            st.write(handwritten_text)
        # Process button
        if st.button("Process Document"):
            with st.spinner("Processing document..."):
                try:
                    temp_path = "temp_uploaded_image.jpg"
                    image.save(temp_path)

                    # Use handwritten text if available, else fall back to pytesseract
                    if handwritten_text:
                        llm_input_text = handwritten_text
                    else:
                        llm_input_text = pytesseract.image_to_string(Image.open(temp_path))

                    llm_result = extract_with_perplexity_llm(llm_input_text)
                    llm_json = extract_json_from_llm_output(llm_result)
                    st.subheader("Structured Data (Perplexity LLM)")
                    if llm_json:
                        try:
                            llm_data = json.loads(llm_json)
                            st.json(llm_data)
                            save_to_dynamodb(llm_data)
                            st.success("Saved to DynamoDB!")
                        except Exception as e:
                            st.error(f"Failed to parse LLM output as JSON: {e}")
                    else:
                        st.warning("No valid JSON found in LLM output.")
                except Exception as e:
                    logger.error(f"Error processing document: {str(e)}")
                    st.error(f"Error processing document: {str(e)}")
st.header("Model Training & Evaluation Demo") | |
if st.button("Start Training"): | |
epochs = 10 | |
num_classes = 3 # Example: 3 classes for confusion matrix | |
losses = [] | |
val_losses = [] | |
accuracies = [] | |
progress = st.progress(0) | |
chart = st.line_chart({"Loss": [], "Val Loss": [], "Accuracy": []}) | |
writer = SummaryWriter("logs") | |
for epoch in range(epochs): | |
# Simulate training | |
loss = np.exp(-epoch/5) + np.random.rand() * 0.05 | |
val_loss = loss + np.random.rand() * 0.02 | |
acc = 1 - loss + np.random.rand() * 0.02 | |
losses.append(loss) | |
val_losses.append(val_loss) | |
accuracies.append(acc) | |
chart.add_rows({"Loss": [loss], "Val Loss": [val_loss], "Accuracy": [acc]}) | |
progress.progress((epoch+1)/epochs) | |
st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Loss={val_loss:.4f}, Accuracy={acc:.4f}") | |
# Log to TensorBoard | |
writer.add_scalar("loss", loss, epoch) | |
writer.add_scalar("val_loss", val_loss, epoch) | |
writer.add_scalar("accuracy", acc, epoch) | |
# Simulate predictions and labels for confusion matrix | |
y_true = np.random.randint(0, num_classes, 100) | |
y_pred = y_true.copy() | |
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10) | |
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes)) | |
# Only log confusion matrix in the last epoch | |
if epoch == epochs - 1: | |
fig, ax = plt.subplots() | |
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)]) | |
disp.plot(ax=ax) | |
plt.close(fig) | |
writer.add_figure("confusion_matrix", fig, epoch) | |
writer.close() | |
st.success("Training complete!") | |
        # Show last confusion matrix in Streamlit
        if 'cm' in locals():
            st.subheader("Confusion Matrix (Last Epoch)")
            fig, ax = plt.subplots()
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
            disp.plot(ax=ax)
            st.pyplot(fig)
        else:
            st.info("Confusion matrix not found.")

if __name__ == "__main__":
    main()