# RAG-PDF / app.py
# Origin: Hugging Face Space (user: masadonline), commit 3cdab77.
import streamlit as st
import os
import json
import tempfile
import pdfplumber
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from dotenv import load_dotenv
# --- Module-level configuration (runs once per Streamlit script execution) ---

# Load environment variables from a local .env file, if present.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Setup GROQ client
# Groq exposes an OpenAI-compatible endpoint, so the standard OpenAI client
# is pointed at Groq's base URL. NOTE(review): no explicit check that
# GROQ_API_KEY is set — a missing key only surfaces later as an API error.
client = OpenAI(api_key=GROQ_API_KEY, base_url="https://api.groq.com/openai/v1")
# Constants
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # local embedding model id
LLM_MODEL = "llama3-8b-8192"  # Groq-hosted chat model id
# Embedding model is loaded eagerly at import time (downloads weights on first run).
embedder = SentenceTransformer(EMBEDDING_MODEL)
# Streamlit app setup
st.set_page_config(page_title="🧸 ToyShop Assistant", layout="wide")
st.title("🧸 ToyShop RAG-Based Assistant")
# --- Helper functions ---
def extract_pdf_text(file):
    """Return the concatenated text of every page in *file* (a PDF).

    Pages with no extractable text are skipped; pages are separated by
    newlines and the final result has surrounding whitespace stripped.
    """
    pages = []
    with pdfplumber.open(file) as pdf:
        for page in pdf.pages:
            content = page.extract_text()
            if content:
                pages.append(content)
    return "\n".join(pages).strip()
def load_json_orders(json_file):
    """Parse an uploaded JSON file into a list of order records.

    Accepts either a top-level list of orders or a dict mapping keys to
    orders. Records that cannot be re-serialized to JSON are skipped with
    a UI warning; if the file itself fails to parse, a UI error is shown
    and an empty list is returned.
    """
    orders = []
    try:
        payload = json.load(json_file)
    except Exception as e:
        st.error(f"❌ Error parsing JSON file: {e}")
        return orders

    if isinstance(payload, list):
        for i, record in enumerate(payload):
            try:
                json.dumps(record)  # probe: record must round-trip as JSON
            except Exception as e:
                st.warning(f"⚠️ Skipping invalid order at index {i}: {e}")
            else:
                orders.append(record)
    elif isinstance(payload, dict):
        for k, record in payload.items():
            try:
                json.dumps(record)  # probe: record must round-trip as JSON
            except Exception as e:
                st.warning(f"⚠️ Skipping invalid order with key '{k}': {e}")
            else:
                orders.append(record)
    return orders
def build_index(text_chunks):
    """Embed *text_chunks* and return (FAISS L2 index, the chunks themselves).

    The chunks are returned unchanged so callers can map search result
    indices back to the original text.
    """
    embeddings = np.array(embedder.encode(text_chunks))
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, text_chunks
def ask_llm(context, query):
    """Ask the Groq-hosted LLM *query*, grounded in *context*; return the answer text.

    Side effects: renders two debug expanders (the prompt sent and the raw
    API response). Propagates any exception raised by the API client —
    the caller is expected to catch and display it.
    """
    prompt = (
        f"You are a helpful assistant for an online toy shop.\n\n"
        f"Knowledge base:\n{context}\n\n"
        f"Question: {query}"
    )
    # For debugging: show the prompt being sent.
    st.expander("Prompt to LLM").code(prompt)
    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}]
    )
    # Log full response for inspection (can be commented out in production).
    # BUG FIX: `response` is a pydantic ChatCompletion object, which is not
    # JSON-serializable — st.json() on it raises. Convert to a dict first.
    st.expander("Raw LLM API Response").json(response.model_dump())
    return response.choices[0].message.content.strip()
# --- File upload section ---
st.subheader("πŸ“ Upload Customer Orders (JSON)")
orders_file = st.file_uploader("Upload JSON file", type="json")

st.subheader("πŸ“š Upload FAQs / Product Info / Return Policy (PDFs)")
pdf_files = st.file_uploader("Upload one or more PDFs", type="pdf", accept_multiple_files=True)

order_chunks = []
pdf_chunks = []

# --- Process JSON ---
if orders_file:
    orders = load_json_orders(orders_file)
    if not orders:
        st.error("No valid orders found in the JSON file.")
    else:
        # One retrieval chunk per order record, serialized as JSON text.
        order_chunks = [json.dumps(order, ensure_ascii=False) for order in orders]
        st.success(f"βœ… Loaded {len(order_chunks)} customer order records.")
        # Attempt to flatten for tabular viewing; fall back to raw JSON.
        try:
            df = pd.json_normalize(orders)
            st.dataframe(df, use_container_width=True)
        except Exception:
            st.warning("⚠️ Nested JSON detected. Showing raw JSON preview instead.")
            st.json(orders)

# --- Process PDFs ---
if pdf_files:
    for pdf_file in pdf_files:
        try:
            text = extract_pdf_text(pdf_file)
            # Split into paragraphs (blank-line separated, non-empty).
            paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        except Exception as e:
            st.error(f"❌ Failed to read {pdf_file.name}: {e}")
        else:
            pdf_chunks.extend(paragraphs)
            st.success(f"πŸ“„ Processed {pdf_file.name}")

# Orders and PDF paragraphs share one retrieval corpus.
combined_chunks = order_chunks + pdf_chunks
# --- Question Answering Section ---
if combined_chunks:
    index, sources = build_index(combined_chunks)

    st.subheader("❓ Ask a Question")
    user_query = st.text_input("What would you like to know?", placeholder="e.g. What is the status of order 123?")

    if user_query:
        query_vector = embedder.encode([user_query])
        # BUG FIX: with fewer than 5 chunks, FAISS pads results with -1 and
        # Python's negative indexing (`sources[-1]`) silently duplicates the
        # last chunk into the context. Clamp k and drop any -1 sentinel.
        k = min(5, len(sources))
        D, I = index.search(query_vector, k=k)
        # Prepare context from the top-k results:
        context = "\n---\n".join(sources[i] for i in I[0] if i >= 0)
        st.expander("Combined Context").code(context)

        with st.spinner("πŸ€” Thinking..."):
            try:
                answer = ask_llm(context, user_query)
                st.markdown("### 🧠 Answer")
                # Use st.write() to render the answer as text.
                st.write(answer)
            except Exception as e:
                st.error(f"❌ GROQ API Error: {e}")
else:
    st.info("πŸ“‚ Please upload both JSON orders and relevant PDFs to begin.")