import streamlit as st
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image
import io
import json
import pandas as pd
import plotly.express as px
import numpy as np
from typing import Dict, Any
import logging
import pytesseract
import re
from openai import OpenAI
import os
from dotenv import load_dotenv
from chatbot_utils import ask_receipt_chatbot
import time
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()

# Initialize OpenAI client for Perplexity
api_key = os.getenv('PERPLEXITY_API_KEY')
if not api_key:
    st.error("""
    ⚠️ Perplexity API key not found! Please add your API key to the Space's secrets:
    1. Go to Space Settings
    2. Click on 'Repository secrets'
    3. Add a new secret with name 'PERPLEXITY_API_KEY'
    4. Add your Perplexity API key as the value
    """)
    st.stop()

client = OpenAI(
    api_key=api_key,
    base_url="https://api.perplexity.ai"
)
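
# Optional connectivity check, sketched here for illustration and not called
# anywhere in the app. It assumes the same "sonar-pro" model used for
# extraction below; any model the key can access would do.
def check_perplexity_connection() -> bool:
    try:
        client.chat.completions.create(
            model="sonar-pro",
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
        )
        return True
    except Exception:
        return False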

# Initialize LayoutLM model
def load_model():
    """Load the LayoutLMv3 processor and token-classification model.

    Note: this is the base checkpoint; its classification head is not
    fine-tuned for receipt fields, so predictions are illustrative only.
    """
    model_name = "microsoft/layoutlmv3-base"
    processor = LayoutLMv3Processor.from_pretrained(model_name)
    model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
    return processor, model
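
# A minimal sketch of how the loaded model could be applied to an image,
# relying on the processor's built-in Tesseract OCR (apply_ocr defaults to
# True). Because the head above is not fine-tuned, the returned label ids
# are placeholders; a real deployment would load a fine-tuned checkpoint.
def run_layoutlm_inference(image: Image.Image):
    processor, model = load_model()
    encoding = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    # One predicted label id per input token
    return outputs.logits.argmax(-1).squeeze().tolist()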

def extract_json_from_llm_output(llm_result):
    """Pull the first {...} block out of an LLM reply that may contain prose."""
    match = re.search(r'\{.*\}', llm_result, re.DOTALL)
    if match:
        return match.group(0)
    return None
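
# Hypothetical example of what this isolates from a chatty reply:
#
#   reply = 'Here is the data:\n{"name": "Mrs Jane Doe", "amount_paid": "42.50"}'
#   extract_json_from_llm_output(reply)
#   # -> '{"name": "Mrs Jane Doe", "amount_paid": "42.50"}'
#
# The result is still a string; parse it with json.loads before use.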

def extract_fields(image_path):
    # OCR
    text = pytesseract.image_to_string(Image.open(image_path))

    # Display OCR output for debugging
    st.subheader("Raw OCR Output")
    st.code(text)

    # Regex patterns for fields; these are tuned to one receipt layout
    # and will need adjusting for other formats
    patterns = {
        "name": r"Mrs\s+\w+\s+\w+",
        "date": r"Date[:\s]+([\d/]+)",
        "product": r"\d+\s+\w+.*Style\s+\d+",
        "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
        "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
    }

    results = {}
    for field, pattern in patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            # Use the first capture group when the pattern has one,
            # otherwise keep the whole match
            results[field] = match.group(1) if match.groups() else match.group(0)
        else:
            results[field] = None
    return results
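
# Hypothetical example of the mapping this returns for a matching receipt:
#
#   {
#       "name": "Mrs Jane Doe",           # whole match (no capture group)
#       "date": "12/05/2024",             # capture group 1
#       "product": "1 Dress Style 1234",
#       "amount_paid": "42.50",
#       "receipt_no": "100045",
#   }
#
# Fields whose pattern finds no match come back as None.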

def extract_with_perplexity_llm(ocr_text):
    prompt = f"""
    Extract the following fields from this receipt text:
    - name
    - date
    - product
    - amount_paid
    - receipt_no

    Text:
    \"\"\"{ocr_text}\"\"\"

    Return the result as a JSON object with those fields.
    """
    messages = [
        {
            "role": "system",
            "content": "You are an AI assistant that extracts structured information from text."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    response = client.chat.completions.create(
        model="sonar-pro",
        messages=messages
    )
    return response.choices[0].message.content
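
# A small glue helper, sketched as one way to combine the two steps above:
# isolate the JSON block from the LLM reply and parse it into a dict,
# returning None if either step fails. The name is ours, not part of the
# original app, and main() below does not call it.
def parse_llm_fields(llm_result: str):
    json_str = extract_json_from_llm_output(llm_result)
    if json_str is None:
        return None
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return None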

def main():
    st.set_page_config(
        page_title="FormIQ - Intelligent Document Parser",
        page_icon="📄",
        layout="wide"
    )

    st.title("FormIQ: Intelligent Document Parser")
    st.markdown("""
    Upload your documents to extract and validate information using advanced AI models.
    """)

    # Sidebar
    with st.sidebar:
        st.header("Settings")
        document_type = st.selectbox(
            "Document Type",
            options=["invoice", "receipt", "form"],
            index=0
        )
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05
        )
        st.markdown("---")
        st.markdown("### About")
        st.markdown("""
        FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
        """)

        # Receipt Chatbot in sidebar
        st.markdown("---")
        st.header("💬 Receipt Chatbot")
        st.write("Ask questions about your receipts stored in DynamoDB.")
        user_question = st.text_input("Enter your question:", "What is the total amount paid?")
        if st.button("Ask Chatbot", key="sidebar_chatbot"):
            with st.spinner("Getting answer from Perplexity LLM..."):
                answer = ask_receipt_chatbot(user_question)
            st.success(answer)
    # Main content
    uploaded_file = st.file_uploader(
        "Upload Document",
        type=["png", "jpg", "jpeg"],  # PIL cannot open PDFs directly
        help="Upload a document image to process"
    )

    if uploaded_file is not None:
        # Display uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Document", width=600)

        # Process button
        if st.button("Process Document"):
            with st.spinner("Processing document..."):
                try:
                    # Save the uploaded file to a temporary location,
                    # converting to RGB so PNGs with alpha can be saved as JPEG
                    temp_path = "temp_uploaded_image.jpg"
                    image.convert("RGB").save(temp_path)

                    # Extract fields using OCR + regex and display them first,
                    # so they still appear if the LLM call fails
                    fields = extract_fields(temp_path)
                    st.subheader("Extracted Fields")
                    fields_df = pd.DataFrame([fields])
                    st.dataframe(fields_df)

                    # Extract with Perplexity LLM
                    with st.spinner("Extracting structured data with Perplexity LLM..."):
                        try:
                            llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
                            st.subheader("Structured Data (Perplexity LLM)")
                            st.code(llm_result, language="json")
                        except Exception as e:
                            st.error(f"LLM extraction failed: {e}")
                except Exception as e:
                    logger.error(f"Error processing document: {str(e)}")
                    st.error(f"Error processing document: {str(e)}")
st.header("Model Training & Evaluation Demo") | |
if st.button("Start Training"): | |
epochs = 10 | |
num_classes = 3 # Example: 3 classes for confusion matrix | |
losses = [] | |
val_losses = [] | |
accuracies = [] | |
progress = st.progress(0) | |
chart = st.line_chart({"Loss": [], "Val Loss": [], "Accuracy": []}) | |
writer = SummaryWriter("logs") | |
for epoch in range(epochs): | |
# Simulate training | |
loss = np.exp(-epoch/5) + np.random.rand() * 0.05 | |
val_loss = loss + np.random.rand() * 0.02 | |
acc = 1 - loss + np.random.rand() * 0.02 | |
losses.append(loss) | |
val_losses.append(val_loss) | |
accuracies.append(acc) | |
chart.add_rows({"Loss": [loss], "Val Loss": [val_loss], "Accuracy": [acc]}) | |
progress.progress((epoch+1)/epochs) | |
st.write(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Loss={val_loss:.4f}, Accuracy={acc:.4f}") | |
# Log to TensorBoard | |
writer.add_scalar("loss", loss, epoch) | |
writer.add_scalar("val_loss", val_loss, epoch) | |
writer.add_scalar("accuracy", acc, epoch) | |
# Simulate predictions and labels for confusion matrix | |
y_true = np.random.randint(0, num_classes, 100) | |
y_pred = y_true.copy() | |
# Add some noise to predictions | |
y_pred[np.random.choice(100, 10, replace=False)] = np.random.randint(0, num_classes, 10) | |
cm = confusion_matrix(y_true, y_pred, labels=range(num_classes)) | |
# Only log confusion matrix in the last epoch | |
if epoch == epochs - 1: | |
fig, ax = plt.subplots() | |
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f"Class {i}" for i in range(num_classes)]) | |
disp.plot(ax=ax) | |
plt.close(fig) | |
writer.add_figure("confusion_matrix", fig, epoch) | |
time.sleep(0.5) | |
writer.close() | |
st.success("Training complete!") | |
# Wait a moment to ensure logs are written | |
time.sleep(1) | |
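
# A sketch of reading the simulated metrics back out of the "logs" directory
# written above, using the event_accumulator import. It is illustrative and
# not wired into the UI; the tag names match the add_scalar calls in main().
def load_training_scalars(log_dir: str = "logs"):
    ea = event_accumulator.EventAccumulator(log_dir)
    ea.Reload()  # read event files from disk
    scalars = {}
    for tag in ("loss", "val_loss", "accuracy"):
        if tag in ea.Tags().get("scalars", []):
            scalars[tag] = [(e.step, e.value) for e in ea.Scalars(tag)]
    return scalars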

if __name__ == "__main__":
    main()