# RAG demo: Streamlit app answering questions with Groq + FAISS retrieval
# over the FreedomIntelligence/RAG-Instruct dataset.
import streamlit as st
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from groq import Groq
# Load dataset
@st.cache_data
def load_data():
    """Load the RAG-Instruct training split and keep only its QA columns.

    Returns:
        pandas.DataFrame with exactly two columns: "question" and "answer".
    """
    qa_frame = pd.DataFrame(
        load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    )
    return qa_frame[["question", "answer"]]
# Generate embeddings and index
@st.cache_resource
def setup_faiss(data):
    """Embed every question and build an exact L2 FAISS index over them.

    Args:
        data: DataFrame with a "question" column.

    Returns:
        (encoder, index, embeddings) — the sentence-transformer model,
        the populated FAISS index, and the raw embedding matrix.
    """
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    vectors = encoder.encode(data["question"].tolist())
    # IndexFlatL2 does brute-force search; dimensionality comes from the encoder.
    qa_index = faiss.IndexFlatL2(vectors.shape[1])
    qa_index.add(np.array(vectors))
    return encoder, qa_index, vectors
# Retrieve relevant context
def retrieve_context(query, model, index, data, top_k=1):
    """Return the top_k nearest QA pairs to *query* as one context string.

    Args:
        query: user question (str).
        model: encoder exposing .encode(list[str]) -> 2-D array.
        index: FAISS-style index exposing .search(vectors, k).
        data: DataFrame with "question" and "answer" columns.
        top_k: number of nearest neighbours to include.

    Returns:
        The matched question/answer pairs, each joined with a blank line.
    """
    query_vec = model.encode([query])
    # search() returns (distances, indices); distances are not needed here.
    _, hit_ids = index.search(np.array(query_vec), top_k)
    snippets = []
    for hit in hit_ids[0]:
        row = data.iloc[hit]
        snippets.append(row["question"] + "\n\n" + row["answer"])
    return "\n\n".join(snippets)
# Call Groq LLM
def query_groq(context, query):
    """Ask the Groq chat API to answer *query* grounded in *context*.

    Args:
        context: retrieved QA text to ground the answer in.
        query: the user's question.

    Returns:
        The model's answer text (str).

    Raises:
        Whatever the Groq SDK raises on auth, network, or API errors;
        KeyError if GROQ_API_KEY is missing from Streamlit secrets.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        # Fix: Groq's Llama-3-70B model id is "llama3-70b-8192";
        # "llama-3-70b-8192" is not a valid id and the API rejects it.
        model="llama3-70b-8192",
    )
    return response.choices[0].message.content
# Streamlit UI
st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")

# Build (or fetch from cache) the dataset, encoder, and FAISS index once.
data = load_data()
model, index, _ = setup_faiss(data)

st.markdown("Ask a question based on the QA knowledge base.")

# Optional queries
optional_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?",
]

query = st.text_input("Enter your question:", value=optional_queries[0])

if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, data)
        answer = query_groq(context, query)
    st.subheader("📄 Retrieved Context")
    st.write(context)
    st.subheader("💬 Answer from Groq LLM")
    st.write(answer)

st.markdown("### Optional Queries to Try:")
st.write(", ".join(optional_queries))