import os

import streamlit as st
import pandas as pd
import numpy as np
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
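
# Files used to persist the FAISS index and the question list between runs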
INDEX_FILE = "faiss_index.index"
QUESTIONS_FILE = "questions.npy"


@st.cache_data
def load_data():
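    # Load the RAG-Instruct dataset from the Hugging Face Hub and keep only
    # the question/answer columns used as the retrieval knowledge base.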
    dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    df = pd.DataFrame(dataset)
    return df[["question", "answer"]]


@st.cache_resource
def setup_faiss(data):
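    # Build (or load) the FAISS index over question embeddings. Cached as a
    # resource so the model and index are created once per session.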
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
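
    # Reuse the persisted index when both artifacts exist; otherwise rebuild
    # the embeddings from scratch and save them for the next run.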
    if os.path.exists(INDEX_FILE) and os.path.exists(QUESTIONS_FILE):
        st.info("📂 Loading FAISS index from disk...")
        index = faiss.read_index(INDEX_FILE)
        questions = np.load(QUESTIONS_FILE, allow_pickle=True)
    else:
        st.info("⚙️ FAISS index not found. Building a new index...")

        questions = data["question"].tolist()
        embeddings = []
        progress_bar = st.progress(0, text="Embedding questions...")
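
        # Encode in 10 chunks so the progress bar can advance between batches;
        # np.array_split tolerates a total that doesn't divide evenly by 10.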
        for i, chunk in enumerate(np.array_split(questions, 10)):
            emb = model.encode(chunk.tolist())
            embeddings.extend(emb)
            progress_bar.progress((i + 1) / 10, text=f"Embedding... {int((i + 1) * 10)}%")
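
        # Stack the chunk embeddings and index them for exact L2 search;
        # IndexFlatL2 brute-forces distances, which is fine at this scale.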
        embeddings = np.array(embeddings, dtype="float32")  # FAISS expects float32
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        faiss.write_index(index, INDEX_FILE)
        np.save(QUESTIONS_FILE, np.array(questions, dtype=object))

        progress_bar.empty()
        st.success("✅ FAISS index built and saved!")

    return model, index, questions


def retrieve_context(query, model, index, questions, data, top_k=1):
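    # Embed the query and return the top_k nearest stored questions, each
    # joined with its answer; FAISS row positions line up with `data` rows.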
    query_vec = model.encode([query])
    distances, indices = index.search(np.array(query_vec), top_k)
    results = [questions[i] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
    return "\n\n".join(results)


def query_groq(context, query):
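    # Send the retrieved context plus the question to Groq's chat completions
    # endpoint; expects a GROQ_API_KEY entry in .streamlit/secrets.toml.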
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-70b-8192",
    )
    return response.choices[0].message.content


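# --- Streamlit UI ---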
st.set_page_config(page_title="RAG App with Groq", layout="wide")
st.title("🔎 RAG App using Groq API + RAG-Instruct Dataset")

data = load_data()
model, index, questions = setup_faiss(data)

st.markdown("Ask a question based on the QA knowledge base.")

optional_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?",
]

query = st.text_input("Enter your question:", value=optional_queries[0])
if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, questions, data)
        answer = query_groq(context, query)
    st.subheader("📖 Retrieved Context")
    st.write(context)
    st.subheader("💬 Answer from Groq LLM")
    st.write(answer)

st.markdown("### 💡 Optional Queries to Try:")
for q in optional_queries:
    st.markdown(f"- {q}")