import streamlit as st
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image
import json
import pandas as pd
import numpy as np
import logging
import pytesseract
import re
from openai import OpenAI
import os
from pdf2image import convert_from_bytes
from dotenv import load_dotenv
from chatbot_utils import ask_receipt_chatbot
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')  # Headless backend: render figures off-screen for Streamlit/TensorBoard
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Initialize OpenAI client for Perplexity
api_key = os.getenv('PERPLEXITY_API_KEY')
if not api_key:
    st.error("""
    ⚠️ Perplexity API key not found! Please add your API key to the Space's secrets:
    1. Go to Space Settings
    2. Click on 'Repository secrets'
    3. Add a new secret with the name 'PERPLEXITY_API_KEY'
    4. Add your Perplexity API key as the value
    """)
    st.stop()
client = OpenAI(
    api_key=api_key,
    base_url="https://api.perplexity.ai"
)
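# Note: the standard OpenAI SDK works here because Perplexity exposes an
# OpenAI-compatible chat endpoint at api.perplexity.ai; the "sonar-pro"
# model name used below is Perplexity's, not OpenAI's.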
# Initialize LayoutLM model
@st.cache_resource  # Cache so the model is downloaded/loaded once per process, not on every rerun
def load_model():
    model_name = "microsoft/layoutlmv3-base"
    processor = LayoutLMv3Processor.from_pretrained(model_name)
    model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
    return processor, model
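# A minimal sketch of how the processor/model pair above could be used for
# token-classification inference; the helper name is hypothetical. Note the
# base checkpoint is not fine-tuned on receipts, so the predicted label ids
# are illustrative only.
def layoutlm_token_predictions(image):
    processor, model = load_model()
    # The processor runs OCR internally (via tesseract) and builds input_ids,
    # token bounding boxes, and pixel values in a single call.
    encoding = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    # One predicted class id per input token.
    return outputs.logits.argmax(-1).squeeze().tolist()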
def extract_json_from_llm_output(llm_result):
    # Grab the first '{' through the last '}' so surrounding prose is dropped
    match = re.search(r'\{.*\}', llm_result, re.DOTALL)
    if match:
        return match.group(0)
    return None
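# Example: extract_json_from_llm_output('Sure! {"name": "A", "total": 1.0}')
# returns '{"name": "A", "total": 1.0}'. The greedy match spans the first '{'
# to the last '}', so nested objects survive, but two separate top-level JSON
# objects in one reply would be captured as a single invalid blob.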
def extract_fields(image_path):
    text = pytesseract.image_to_string(Image.open(image_path))
    st.subheader("Raw OCR Output")
    st.code(text)

    # Regex patterns for the individual receipt fields
    patterns = {
        "name": r"Mrs\s+\w+\s+\w+",
        "date": r"Date[:\s]+([\d/]+)",
        "product": r"\d+\s+\w+.*Style\s+\d+",
        "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
        "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
    }
    results = {}
    for field, pattern in patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            # Prefer the capture group when the pattern has one, else the whole match
            results[field] = match.group(1) if match.groups() else match.group(0)
        else:
            results[field] = None

    # Extract all products
    results["products"] = extract_products(text)
    return results
def extract_products(text):
    # This pattern matches lines like: "1076903 PISTACHIO 14.49"
    product_pattern = r"\d{6,} ([A-Z0-9 ]+) (\d+\.\d{2})"
    matches = re.findall(product_pattern, text)
    products = [{"name": name.strip(), "price": float(price)} for name, price in matches]
    return products
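# Example: given OCR text containing "1076903 PISTACHIO 14.49" and
# "1204135 KS WATER 3.99", extract_products returns
# [{"name": "PISTACHIO", "price": 14.49}, {"name": "KS WATER", "price": 3.99}].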
def extract_with_perplexity_llm(ocr_text):
    prompt = f"""
You are an expert at extracting structured data from receipts.
From the following OCR text, extract these fields and return them as a flat JSON object with exactly these keys:
- name (customer name)
- date (date of purchase)
- amount_paid (total amount paid, or the price if there is only one product)
- receipt_no (receipt number)
- product (the main product name, as a string; if there are multiple products, pick the most expensive or the only one)

**Note:** If the receipt has only one product, set 'product' to its name and 'amount_paid' to its price. If there is a 'price' and an 'amount paid', treat them as the same if they are equal.

Example output:
{{
    "name": "Mrs. Genevieve Lopez",
    "date": "12/13/2024",
    "amount_paid": 579.18,
    "receipt_no": "042085",
    "product": "Wireless Airpods"
}}

Text:
\"\"\"{ocr_text}\"\"\"
"""
    messages = [
        {
            "role": "system",
            "content": "You are an AI assistant that extracts structured information from text."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    response = client.chat.completions.create(
        model="sonar-pro",
        messages=messages
    )
    return response.choices[0].message.content
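# The model's reply is returned verbatim; it usually contains the JSON object
# but may wrap it in prose, which is why extract_json_from_llm_output() is
# applied before json.loads() in main() below.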
def save_to_dynamodb(data, table_name="Receipts"):
    # ... existing code ...
    # `table` is expected to come from the elided boto3 setup above;
    # data["products"] is a list of dicts
    table.put_item(Item=data)
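# A minimal sketch of what the elided setup could look like, assuming standard
# boto3 with AWS credentials in the environment; the helper name is
# hypothetical. DynamoDB rejects Python floats, so numeric values (e.g.
# product prices) are round-tripped through Decimal first.
def save_to_dynamodb_sketch(data, table_name="Receipts"):
    import boto3
    from decimal import Decimal
    table = boto3.resource("dynamodb").Table(table_name)
    # Convert all floats in the nested structure to Decimal for DynamoDB
    item = json.loads(json.dumps(data), parse_float=Decimal)
    table.put_item(Item=item)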
def merge_extractions(regex_fields, llm_fields):
    merged = {}
    for key in ["name", "date", "amount_paid", "receipt_no"]:
        # Prefer the LLM value; fall back to the regex extraction
        merged[key] = llm_fields.get(key) or regex_fields.get(key)
    merged["products"] = llm_fields.get("products") or regex_fields.get("products")
    return merged
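# Example: merge_extractions({"name": "Mrs Jane Doe", "products": [...]},
# {"name": None, "receipt_no": "042085"}) keeps "042085" from the LLM and
# falls back to the regex name and products. Note the LLM prompt above asks
# for a singular "product" key, so merged["products"] effectively always
# comes from the regex pass.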
def main():
    st.set_page_config(
        page_title="FormIQ - Intelligent Document Parser",
        page_icon="📄",
        layout="wide"
    )
    st.title("FormIQ: Intelligent Document Parser")
    st.markdown("""
    Upload your documents to extract and validate information using advanced AI models.
    """)

    # Sidebar
    with st.sidebar:
        st.header("Settings")
        document_type = st.selectbox(
            "Document Type",
            options=["invoice", "receipt", "form"],
            index=0
        )
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05
        )
        st.markdown("---")
        st.markdown("### About")
        st.markdown("""
        FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
        """)

        # Receipt Chatbot in sidebar
        st.markdown("---")
        st.header("💬 Receipt Chatbot")
        st.write("Ask questions about your receipts stored in DynamoDB.")
        user_question = st.text_input("Enter your question:", "What is the total amount paid?")
        if st.button("Ask Chatbot", key="sidebar_chatbot"):
            with st.spinner("Getting answer from Perplexity LLM..."):
                answer = ask_receipt_chatbot(user_question)
                st.success(answer)

    # Main content
    uploaded_file = st.file_uploader(
        "Upload Document",
        type=["png", "jpg", "jpeg", "pdf"],
        help="Upload a document image to process"
    )
    if uploaded_file is not None:
        # Display uploaded image (first page only for PDFs)
        if uploaded_file.type == "application/pdf":
            images = convert_from_bytes(uploaded_file.read())
            image = images[0]  # Use the first page
        else:
            image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Document", width=600)

        # Process button
        if st.button("Process Document"):
            with st.spinner("Processing document..."):
                try:
                    # Save the uploaded file to a temporary location; convert
                    # to RGB first, since PNG uploads may carry an alpha
                    # channel that JPEG cannot store
                    temp_path = "temp_uploaded_image.jpg"
                    image.convert("RGB").save(temp_path)

                    # Extract fields using OCR + regex
                    fields = extract_fields(temp_path)

                    # Extract with Perplexity LLM
                    with st.spinner("Extracting structured data with Perplexity LLM..."):
                        llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
                    st.subheader("Structured Data (Perplexity LLM)")
                    # Show the raw reply as text; it may wrap the JSON in prose
                    st.code(llm_result, language="json")

                    # Try to parse the JSON from the LLM output
                    llm_data = {}
                    try:
                        llm_json = extract_json_from_llm_output(llm_result)
                        if llm_json:
                            llm_data = json.loads(llm_json)
                            st.json(llm_data)
                            # Save to DynamoDB
                            try:
                                save_to_dynamodb(llm_data)
                                st.success("Saved to DynamoDB!")
                            except Exception as e:
                                st.error(f"Failed to save to DynamoDB: {e}")
                    except Exception as e:
                        st.error(f"Failed to parse LLM output as JSON: {e}")

                    # Display extracted products if present
                    if "products" in llm_data and llm_data["products"]:
                        st.subheader("Products (LLM Extracted)")
                        st.dataframe(pd.DataFrame(llm_data["products"]))
                except Exception as e:
                    logger.error(f"Error processing document: {str(e)}")
                    st.error(f"Error processing document: {str(e)}")
st.header("Model Training & Evaluation Demo") | |
if st.button("Start Training"): | |
epochs = 10 | |
num_classes = 3 # Example: 3 classes for confusion matrix | |
losses = [] | |
val_losses = [] | |
accuracies = [] | |
progress = st.progress(0) | |
chart = st.line_chart({"Loss": [], "Val Loss": [], "Accuracy": []}) | |
writer = SummaryWriter("logs") | |
for epoch in range(epochs): | |
# Simulate training | |
loss = np.exp(-epoch/5) + np.random.rand() * 0.05 | |
val_loss = loss + np.random.rand() * 0.02 | |
acc = 1 - loss + np.random.rand() * 0.02 | |
losses.append(loss) | |
val_losses.append(val_loss) | |
accuracies.append(acc) | |
chart.add_rows({"Loss": [loss], "Val Loss": [val_loss], "Accuracy": [acc]}) | |
progress.progress((epoch+1)/epochs) | |
st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Loss={val_loss:.4f}, Accuracy={acc:.4f}") | |
# Log to TensorBoard | |
writer.add_scalar("loss", loss, epoch) | |
writer.add_scalar("val_loss", val_loss, epoch) | |
writer.add_scalar("accuracy", acc, epoch) | |
# Simulate predictions and labels for confusion matrix | |
y_true = np.random.randint(0, num_classes, 100) | |
y_pred = y_true.copy() | |
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10) | |
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes)) | |
# Only log confusion matrix in the last epoch | |
if epoch == epochs - 1: | |
fig, ax = plt.subplots() | |
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)]) | |
disp.plot(ax=ax) | |
plt.close(fig) | |
writer.add_figure("confusion_matrix", fig, epoch) | |
writer.close() | |
st.success("Training complete!") | |
        # Show the last confusion matrix in Streamlit
        if 'cm' in locals():
            st.subheader("Confusion Matrix (Last Epoch)")
            fig, ax = plt.subplots()
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)])
            disp.plot(ax=ax)
            st.pyplot(fig)
        else:
            st.info("Confusion matrix not found.")
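# The simulated run above writes scalars and the confusion-matrix figure to
# ./logs; assuming TensorBoard is installed, they can be inspected locally
# with: tensorboard --logdir logs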
if __name__ == "__main__":
    main()