# test_RAG / app.py
# Streamlit RAG demo: embeds questions from the RAG-Instruct dataset,
# retrieves the closest QA pair with FAISS, and answers via the Groq API.
import streamlit as st
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from groq import Groq
# Load dataset
@st.cache_data
def load_data():
    """Fetch the RAG-Instruct train split and keep only the QA columns.

    Returns a DataFrame with exactly two columns: "question" and "answer".
    Cached by Streamlit so the download happens once per session.
    """
    raw = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    frame = pd.DataFrame(raw)
    return frame.loc[:, ["question", "answer"]]
# Generate embeddings and index
@st.cache_resource
def setup_faiss(data):
    """Embed every question with MiniLM and build a flat L2 FAISS index.

    Returns (encoder, index, embeddings) so callers can embed new queries
    with the same model. Cached as a resource: built once per process.
    """
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    vectors = encoder.encode(data["question"].tolist())
    qa_index = faiss.IndexFlatL2(vectors.shape[1])
    qa_index.add(np.array(vectors))
    return encoder, qa_index, vectors
# Retrieve relevant context
def retrieve_context(query, model, index, data, top_k=1):
    """Return the top_k nearest QA pairs as one blank-line-separated string.

    The query is embedded with `model`, searched in `index`, and each hit's
    question and answer are looked up by row position in `data`.
    """
    encoded = np.array(model.encode([query]))
    _, hits = index.search(encoded, top_k)
    snippets = []
    for row_id in hits[0]:
        row = data.iloc[row_id]
        snippets.append(row["question"] + "\n\n" + row["answer"])
    return "\n\n".join(snippets)
# Call Groq LLM
def query_groq(context, query):
    """Ask the Groq chat API to answer `query` grounded in `context`.

    Builds a single-user-message prompt, reads the API key from Streamlit
    secrets (raises KeyError if GROQ_API_KEY is missing), and returns the
    model's answer text. Groq API errors propagate to the caller.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        # Fix: Groq's model id is "llama3-70b-8192"; the original
        # "llama-3-70b-8192" is not a valid id and the API rejects it.
        model="llama3-70b-8192",
    )
    return response.choices[0].message.content
# Streamlit UI
st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")

# Build the knowledge base once (both helpers are Streamlit-cached).
data = load_data()
model, index, _ = setup_faiss(data)

st.markdown("Ask a question based on the QA knowledge base.")

# Optional queries
optional_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?",
]

query = st.text_input("Enter your question:", value=optional_queries[0])

if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, data)
        answer = query_groq(context, query)
        # Fix: the original emoji literals were mojibake (mis-encoded
        # UTF-8, e.g. "πŸ“„"); restored to the intended 📄 / 💬 characters.
        st.subheader("📄 Retrieved Context")
        st.write(context)
        st.subheader("💬 Answer from Groq LLM")
        st.write(answer)

st.markdown("### Optional Queries to Try:")
st.write(", ".join(optional_queries))