import streamlit as st
import pandas as pd
import numpy as np
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq


@st.cache_data
def load_data():
    # Load the RAG-Instruct QA pairs once and cache the resulting DataFrame.
    dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    df = pd.DataFrame(dataset)
    return df[["question", "answer"]]

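# Caching note: st.cache_data holds the DataFrame (serializable data), while
# st.cache_resource below holds the model and FAISS index, which are
# unpicklable objects that should be created once per process.
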
@st.cache_resource
def setup_faiss(data):
    # Embed every stored question and index the vectors for similarity search.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = model.encode(data["question"].tolist())
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(np.asarray(embeddings, dtype="float32"))  # FAISS expects float32
    return model, index, embeddings
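
# A note on the index choice: IndexFlatL2 does exact (brute-force) L2 search
# over all stored vectors, so it needs no training step and always returns the
# true nearest neighbors. For a corpus of this size that is plenty fast; an
# approximate index such as faiss.IndexIVFFlat only pays off at much larger scale.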


def retrieve_context(query, model, index, data, top_k=1):
    # Embed the query and fetch the top_k nearest stored questions.
    query_vec = model.encode([query])
    distances, indices = index.search(np.asarray(query_vec, dtype="float32"), top_k)
    # Concatenate each retrieved question with its answer to form the context.
    results = [data.iloc[i]["question"] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
    return "\n\n".join(results)
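
# Example call (a sketch; "What is RAG?" is an arbitrary test query):
#   context = retrieve_context("What is RAG?", model, index, data, top_k=2)
# With top_k=2 this returns the two stored questions closest in embedding
# space, each followed by its answer, joined into one context string.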


def query_groq(context, query):
    # Stuff the retrieved context into a plain QA prompt.
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-70b-8192",  # Groq's Llama 3 70B model ID
    )
    return response.choices[0].message.content
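
# Design note: the Groq client is re-created on every request, which keeps the
# function self-contained; for a demo this is fine, but the client could also
# be built once at module level or cached with st.cache_resource.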


st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")

data = load_data()
model, index, _ = setup_faiss(data)

st.markdown("Ask a question based on the QA knowledge base.")

optional_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?",
]

query = st.text_input("Enter your question:", value=optional_queries[0])
if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, data)
        answer = query_groq(context, query)
    st.subheader("📄 Retrieved Context")
    st.write(context)
    st.subheader("💬 Answer from Groq LLM")
    st.write(answer)

st.markdown("### Optional Queries to Try:")
st.write(", ".join(optional_queries))
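
# To run the app (assuming this script is saved as app.py):
#   1. Put the API key in .streamlit/secrets.toml:  GROQ_API_KEY = "<your-key>"
#   2. Start Streamlit:  streamlit run app.py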