# test_RAG / app.py
import os
import streamlit as st
import pandas as pd
import numpy as np
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
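# Dependencies (assumed pip names): streamlit, pandas, numpy, faiss-cpu
# (or faiss-gpu), datasets, sentence-transformers, groq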
# Constants for saving/loading index
INDEX_FILE = "faiss_index.index"
QUESTIONS_FILE = "questions.npy"
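# Both files are written next to app.py on first run; deleting them forces a
# full re-embed and index rebuild on the next start.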
# Load dataset
@st.cache_data
def load_data():
    dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    df = pd.DataFrame(dataset)
    return df[["question", "answer"]]
# Build or load FAISS index
@st.cache_resource
def setup_faiss(data):
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    if os.path.exists(INDEX_FILE) and os.path.exists(QUESTIONS_FILE):
        st.info("🔍 Loading FAISS index from disk...")
        index = faiss.read_index(INDEX_FILE)
        questions = np.load(QUESTIONS_FILE, allow_pickle=True)
    else:
        st.info("⚙️ FAISS index not found. Building new index...")
        questions = data["question"].tolist()
        embeddings = []
        progress_bar = st.progress(0, text="Embedding questions...")
        # Embed in 10 chunks so the progress bar can advance in 10% steps.
        for i, chunk in enumerate(np.array_split(questions, 10)):
            emb = model.encode(list(chunk))
            embeddings.extend(emb)
            progress_bar.progress((i + 1) / 10, text=f"Embedding... {int((i + 1) * 10)}%")
        embeddings = np.array(embeddings)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, INDEX_FILE)
        np.save(QUESTIONS_FILE, np.array(questions, dtype=object))
        progress_bar.empty()
        st.success("✅ FAISS index built and saved!")
    return model, index, questions
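# Note: IndexFlatL2 is an exact (brute-force) L2 index, so retrieval quality
# depends only on the embeddings. For a much larger corpus, an approximate
# index such as faiss.IndexIVFFlat would be the usual swap-in.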
# Retrieve relevant context
def retrieve_context(query, model, index, questions, data, top_k=1):
    query_vec = model.encode([query])
    distances, indices = index.search(np.array(query_vec), top_k)
    results = [questions[i] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
    return "\n\n".join(results)
# Call Groq LLM
def query_groq(context, query):
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])  # key name must be a quoted string
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-70b-8192",
    )
    return response.choices[0].message.content
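# The API key is read from Streamlit secrets. Locally that means a
# .streamlit/secrets.toml containing, e.g. (placeholder value):
#   GROQ_API_KEY = "gsk_..."
# On a hosted deployment, the same key must be provided via the platform's
# secrets configuration.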
# Streamlit UI
st.set_page_config(page_title="RAG App with Groq", layout="wide")
st.title("🔍 RAG App using Groq API + RAG-Instruct Dataset")
# Load data and setup
data = load_data()
model, index, questions = setup_faiss(data)
st.markdown("Ask a question based on the QA knowledge base.")
# Optional queries
optional_queries = [
"What is retrieval-augmented generation?",
"How can I fine-tune a language model?",
"What are the components of a RAG system?",
"Explain prompt engineering basics.",
"How does FAISS indexing help in RAG?"
]
query = st.text_input("Enter your question:", value=optional_queries[0])
if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, questions, data)
        answer = query_groq(context, query)
    st.subheader("📄 Retrieved Context")
    st.write(context)
    st.subheader("💬 Answer from Groq LLM")
    st.write(answer)
st.markdown("### 💡 Optional Queries to Try:")
for q in optional_queries:
    st.markdown(f"- {q}")