File size: 3,467 Bytes
8b835fd
319855f
 
8b835fd
 
319855f
 
 
 
8b835fd
 
 
 
319855f
 
 
 
 
8d7ab91
319855f
8b835fd
319855f
 
 
8b835fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319855f
 
8b835fd
319855f
 
8b835fd
319855f
 
 
 
 
10d3bb1
319855f
 
 
 
 
 
 
8b835fd
 
319855f
8b835fd
319855f
8b835fd
319855f
8d7ab91
319855f
 
 
8d7ab91
 
 
 
 
319855f
 
 
 
 
8b835fd
319855f
 
 
 
 
 
8b835fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import streamlit as st
import pandas as pd
import numpy as np
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq

# Filenames used to persist the FAISS index and its question corpus between
# app runs, so question embeddings only have to be computed once.
INDEX_FILE = "faiss_index.index"
QUESTIONS_FILE = "questions.npy"

# Load the QA knowledge base (cached across reruns by Streamlit).
@st.cache_data
def load_data():
    """Fetch the RAG-Instruct train split and return a DataFrame with only
    the "question" and "answer" columns."""
    frame = pd.DataFrame(load_dataset("FreedomIntelligence/RAG-Instruct", split="train"))
    return frame.loc[:, ["question", "answer"]]

# Build or load FAISS index
@st.cache_resource
def setup_faiss(data):
    """Prepare the retrieval components.

    Returns a ``(SentenceTransformer, faiss index, questions)`` triple. A
    previously persisted index is reloaded from disk when present; otherwise
    the questions are embedded in 10 batches (with a progress bar), indexed
    with L2 distance, and persisted for the next run.
    """
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # Fast path: reuse the index and question list saved by a previous run.
    if os.path.exists(INDEX_FILE) and os.path.exists(QUESTIONS_FILE):
        st.info("πŸ” Loading FAISS index from disk...")
        index = faiss.read_index(INDEX_FILE)
        stored = np.load(QUESTIONS_FILE, allow_pickle=True)
        return encoder, index, stored

    st.info("βš™οΈ FAISS index not found. Building new index...")
    questions = data["question"].tolist()

    # Embed in 10 fixed batches so the progress bar advances in 10% steps.
    bar = st.progress(0, text="Embedding questions...")
    batches = []
    for step, batch in enumerate(np.array_split(questions, 10), start=1):
        batches.append(encoder.encode(batch))
        bar.progress(step / 10, text=f"Embedding... {int(step * 10)}%")

    vectors = np.vstack(batches)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    # Persist both artifacts so the next session skips the embedding pass.
    faiss.write_index(index, INDEX_FILE)
    np.save(QUESTIONS_FILE, np.array(questions, dtype=object))

    bar.empty()
    st.success("βœ… FAISS index built and saved!")
    return encoder, index, questions


# Retrieve relevant context
def retrieve_context(query, model, index, questions, data, top_k=1):
    """Return the `top_k` nearest QA pairs as one context string.

    Args:
        query: User question to embed and search with.
        model: Encoder exposing ``encode(list[str]) -> 2D array``.
        index: FAISS index built from the embeddings of `questions`.
        questions: Stored question strings, row-aligned with `data`.
        data: DataFrame with an "answer" column, row-aligned with the index.
        top_k: Number of nearest neighbours to retrieve.

    Returns:
        The matched question/answer snippets joined by blank lines
        (empty string if nothing valid is returned).
    """
    # FAISS requires float32 queries; encode already returns an array, so
    # asarray is a no-copy cast in the common case.
    query_vec = np.asarray(model.encode([query]), dtype="float32")
    _, indices = index.search(query_vec, top_k)
    # FAISS pads the result with -1 when fewer than top_k vectors exist;
    # those must be skipped or they would silently index the LAST row.
    hits = [i for i in indices[0] if i >= 0]
    results = [f"{questions[i]}\n\n{data.iloc[i]['answer']}" for i in hits]
    return "\n\n".join(results)

# Call Groq LLM
def query_groq(context, query):
    """Answer `query` with the Groq chat API, grounded on `context`.

    Builds a plain RAG prompt (context + question) and sends it as a single
    user message; returns the generated answer text.

    Raises:
        KeyError: if "GROQ_API_KEY" is not set in Streamlit secrets.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    # Fix: the secret must be looked up by its string key — the original
    # `st.secrets[GROQ_API_KEY]` referenced an undefined name (NameError).
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        # Fix: Groq's Llama 3 70B model id is "llama3-70b-8192";
        # "llama-3-70b-8192" is rejected by the API.
        model="llama3-70b-8192",
    )
    return response.choices[0].message.content

# ---------------- Streamlit UI ----------------
st.set_page_config(page_title="RAG App with Groq", layout="wide")
st.title("πŸ” RAG App using Groq API + RAG-Instruct Dataset")

# Prepare the knowledge base and retrieval components up front
# (both calls are cached, so reruns are cheap).
kb = load_data()
embedder, faiss_index, stored_questions = setup_faiss(kb)

st.markdown("Ask a question based on the QA knowledge base.")

# Example prompts: the first one pre-fills the input box, and all of
# them are listed at the bottom of the page.
sample_queries = [
    "What is retrieval-augmented generation?",
    "How can I fine-tune a language model?",
    "What are the components of a RAG system?",
    "Explain prompt engineering basics.",
    "How does FAISS indexing help in RAG?"
]

user_query = st.text_input("Enter your question:", value=sample_queries[0])
if st.button("Ask"):
    with st.spinner("Retrieving and generating response..."):
        retrieved = retrieve_context(user_query, embedder, faiss_index, stored_questions, kb)
        reply = query_groq(retrieved, user_query)
    st.subheader("πŸ“„ Retrieved Context")
    st.write(retrieved)
    st.subheader("πŸ’¬ Answer from Groq LLM")
    st.write(reply)

st.markdown("### πŸ’‘ Optional Queries to Try:")
for sample in sample_queries:
    st.markdown(f"- {sample}")