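"""RAG demo: Streamlit UI, FAISS retrieval over the RAG-Instruct dataset, and a Groq-hosted LLM."""
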
import streamlit as st
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from groq import Groq
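# Assumed dependencies (faiss-cpu is one common FAISS build; adjust to your environment):
#   pip install streamlit pandas datasets sentence-transformers faiss-cpu numpy groq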

# Load the RAG-Instruct dataset and keep only the question/answer columns
@st.cache_data
def load_data():
    dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
    df = pd.DataFrame(dataset)
    return df[["question", "answer"]]

# Embed all questions and build an exact L2 FAISS index over them
@st.cache_resource
def setup_faiss(data):
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = model.encode(data["question"].tolist())
    index = faiss.IndexFlatL2(embeddings.shape[1])  # exact (brute-force) L2 search
    index.add(np.asarray(embeddings, dtype="float32"))  # FAISS expects float32
    return model, index, embeddings

# Retrieve the top_k most similar QA pairs to use as context
def retrieve_context(query, model, index, data, top_k=1):
    query_vec = model.encode([query])
    distances, indices = index.search(np.asarray(query_vec, dtype="float32"), top_k)
    # indices[0] holds the row positions of the nearest questions in `data`
    results = [data.iloc[i]["question"] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
    return "\n\n".join(results)

# Send the retrieved context plus the user question to the Groq LLM
def query_groq(context, query):
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    # st.secrets reads GROQ_API_KEY from .streamlit/secrets.toml (or the app's deployed secrets)
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-70b-8192",  # valid Groq model ID
    )
    return response.choices[0].message.content

# Streamlit UI
st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")

data = load_data()
model, index, _ = setup_faiss(data)
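# Both loaders are cached (st.cache_data / st.cache_resource), so the dataset
# and FAISS index are built once per session rather than on every rerun.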

st.markdown("Ask a question based on the QA knowledge base.")

# Example queries the user can try
example_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?"
]

query = st.text_input("Enter your question:", value=example_queries[0])
if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        context = retrieve_context(query, model, index, data)
        answer = query_groq(context, query)
    st.subheader("📄 Retrieved Context")
    st.write(context)
    st.subheader("💬 Answer from Groq LLM")
    st.write(answer)

st.markdown("### Optional Queries to Try:")
st.write(", ".join(optional_queries))