# ToyShop RAG-based assistant — Streamlit app (originally hosted on Hugging Face Spaces).
import streamlit as st
import os
import json
import tempfile
import pdfplumber
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables from a local .env file (expects GROQ_API_KEY).
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Groq exposes an OpenAI-compatible endpoint, so the standard OpenAI client
# works with a custom base_url.
client = OpenAI(api_key=GROQ_API_KEY, base_url="https://api.groq.com/openai/v1")

# Constants
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "llama3-8b-8192"

# Loaded once at module import; reused for both document and query embeddings.
embedder = SentenceTransformer(EMBEDDING_MODEL)

# Streamlit app setup
st.set_page_config(page_title="🧸 ToyShop Assistant", layout="wide")
st.title("🧸 ToyShop RAG-Based Assistant")
# --- Helper functions ---
def extract_pdf_text(file):
    """Extract plain text from every page of a PDF file-like object.

    Pages with no extractable text (pdfplumber returns None for
    image-only pages) are skipped. Returns the page texts joined by
    newlines, with surrounding whitespace stripped.
    """
    pages = []
    with pdfplumber.open(file) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:
                pages.append(page_text)
    # join() instead of repeated += avoids quadratic string concatenation.
    return "\n".join(pages).strip()
def load_json_orders(json_file):
    """Parse an uploaded JSON file into a list of order records.

    Accepts either a JSON array of orders or a JSON object whose values
    are orders. Entries that cannot be re-serialized with json.dumps are
    skipped with a Streamlit warning; a top-level parse failure reports
    an error and yields an empty list.
    """
    valid_orders = []
    try:
        data = json.load(json_file)
        if isinstance(data, list):
            for i, order in enumerate(data):
                try:
                    json.dumps(order)  # test serialization
                    valid_orders.append(order)
                except Exception as e:
                    st.warning(f"⚠️ Skipping invalid order at index {i}: {e}")
        elif isinstance(data, dict):
            for k, order in data.items():
                try:
                    json.dumps(order)
                    valid_orders.append(order)
                except Exception as e:
                    st.warning(f"⚠️ Skipping invalid order with key '{k}': {e}")
    except Exception as e:
        st.error(f"❌ Error parsing JSON file: {e}")
    return valid_orders
def build_index(text_chunks):
    """Embed text_chunks and build an exact (brute-force) L2 FAISS index.

    Returns (index, text_chunks) so callers can map search hits back to
    the original text by row position.
    """
    vectors = embedder.encode(text_chunks)
    index = faiss.IndexFlatL2(vectors.shape[1])
    # FAISS requires float32 input; asarray is a no-copy pass-through when
    # the encoder already produced a float32 ndarray.
    index.add(np.asarray(vectors, dtype=np.float32))
    return index, text_chunks
def ask_llm(context, query):
    """Send the retrieved context plus the user's question to the Groq LLM.

    Returns the model's answer with surrounding whitespace stripped.
    API errors propagate to the caller (handled there with st.error).
    """
    prompt = (
        f"You are a helpful assistant for an online toy shop.\n\n"
        f"Knowledge base:\n{context}\n\n"
        f"Question: {query}"
    )
    # For debugging: show the prompt being sent.
    st.expander("Prompt to LLM").code(prompt)
    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}]
    )
    # Log full response for inspection (can be commented out in production).
    # The OpenAI v1 SDK returns a pydantic model, which st.json cannot
    # serialize directly — dump it to plain dicts/lists first.
    st.expander("Raw LLM API Response").json(response.model_dump())
    return response.choices[0].message.content.strip()
# --- File upload section ---
st.subheader("📦 Upload Customer Orders (JSON)")
orders_file = st.file_uploader("Upload JSON file", type="json")

st.subheader("📄 Upload FAQs / Product Info / Return Policy (PDFs)")
pdf_files = st.file_uploader("Upload one or more PDFs", type="pdf", accept_multiple_files=True)

order_chunks, pdf_chunks = [], []

# --- Process JSON ---
if orders_file:
    orders = load_json_orders(orders_file)
    if orders:
        # One retrievable chunk per order, serialized back to JSON text so
        # it can be embedded alongside the PDF paragraphs.
        order_chunks = [json.dumps(order, ensure_ascii=False) for order in orders]
        st.success(f"✅ Loaded {len(order_chunks)} customer order records.")
        # Attempt to flatten for viewing
        try:
            df = pd.json_normalize(orders)
            st.dataframe(df, use_container_width=True)
        except Exception:
            st.warning("⚠️ Nested JSON detected. Showing raw JSON preview instead.")
            st.json(orders)
    else:
        st.error("No valid orders found in the JSON file.")

# --- Process PDFs ---
if pdf_files:
    for pdf_file in pdf_files:
        try:
            text = extract_pdf_text(pdf_file)
            # Split into paragraphs (blank-line separated, non-empty)
            paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
            pdf_chunks.extend(paragraphs)
            st.success(f"📄 Processed {pdf_file.name}")
        except Exception as e:
            st.error(f"❌ Failed to read {pdf_file.name}: {e}")

# Everything retrievable: order records first, then PDF paragraphs.
combined_chunks = order_chunks + pdf_chunks
# --- Question Answering Section ---
if combined_chunks:
    index, sources = build_index(combined_chunks)
    st.subheader("❓ Ask a Question")
    user_query = st.text_input("What would you like to know?", placeholder="e.g. What is the status of order 123?")
    if user_query:
        query_vector = embedder.encode([user_query])
        # Never ask FAISS for more neighbors than there are documents:
        # with k > ntotal the extra result slots come back as index -1,
        # and sources[-1] would silently inject the wrong chunk.
        top_k = min(5, len(sources))
        D, I = index.search(query_vector, k=top_k)
        # Prepare context from the top-K results (defensively skip -1):
        context = "\n---\n".join(sources[i] for i in I[0] if i >= 0)
        st.expander("Combined Context").code(context)
        with st.spinner("🤔 Thinking..."):
            try:
                answer = ask_llm(context, user_query)
                st.markdown("### 🧠 Answer")
                # Use st.write() to render the answer as text.
                st.write(answer)
            except Exception as e:
                st.error(f"❌ GROQ API Error: {e}")
else:
    st.info("📂 Please upload JSON orders and/or PDFs to begin.")