# AllAboutRAG / app.py
# Last change: commit 783a14e "Update app.py" by bainskarman (verified), 7.38 kB.
import streamlit as st
import os
import requests
import faiss
import numpy as np
from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer
from langdetect import detect
# Hugging Face API token, read from the "Key2" environment variable
# (None when the variable is unset).
huggingface_token = os.environ.get("Key2")


@st.cache_resource
def _load_embedder(model_name="all-MiniLM-L6-v2"):
    """Load the SentenceTransformer model once and share it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so a
    bare module-level constructor would reload the model each time;
    ``st.cache_resource`` keeps a single instance instead.
    """
    return SentenceTransformer(model_name)


# Sentence-embedding model used to encode both the PDF chunks and the queries.
embedder = _load_embedder()
# Prompt templates for each query-translation method, keyed by the method
# names offered in the sidebar. Each template contains a single ``{question}``
# placeholder that is filled in with ``str.format(question=...)``.
# Adjacent string literals are concatenated, so every fragment that continues
# a sentence must end with a space.
DEFAULT_SYSTEM_PROMPTS = {
    "Multi-Query": (
        "You are an AI language model assistant. Your task is to generate five "
        "different versions of the given user question to retrieve relevant "
        "documents from a vector database. By generating multiple perspectives "
        "on the user question, your goal is to help the user overcome some of "
        "the limitations of the distance-based similarity search. Provide these "
        "alternative questions separated by newlines. Original question: {question}"
    ),
    "RAG Fusion": (
        "You are an AI language model assistant. Your task is to combine multiple "
        "queries into a single, refined query to improve retrieval accuracy. "
        "Original question: {question}"
    ),
    "Decomposition": (
        "You are an AI language model assistant. Your task is to break down "
        "the given user question into simpler sub-questions. Provide these "
        "sub-questions separated by newlines. Original question: {question}"
    ),
    "Step Back": (
        "You are an AI language model assistant. Your task is to refine the given "
        "user question by taking a step back and asking a more general question. "
        "Original question: {question}"
    ),
    "HyDE": (
        "You are an AI language model assistant. Your task is to generate a hypothetical "
        "document that would be relevant to the given user question. "
        "Original question: {question}"
    ),
}
def query_huggingface_model(
    prompt,
    max_new_tokens=1000,
    temperature=0.7,
    top_k=50,
    model_name="HuggingFaceH4/zephyr-7b-alpha",
    timeout=120,
):
    """Send *prompt* to the Hugging Face Inference API and return the completion.

    Args:
        prompt: Text sent as the model input.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to the API.
        top_k: Top-k sampling cutoff forwarded to the API.
        model_name: Hub model id served by the Inference API.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The generated text (without the prompt echoed in front of it), or
        ``None`` if the request failed. Every failure is reported in the UI
        with ``st.error`` before ``None`` is returned.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    # Only authenticate when a token is configured; otherwise the header
    # would literally read "Bearer None".
    headers = {"Authorization": f"Bearer {huggingface_token}"} if huggingface_token else {}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            # The text-generation task returns prompt + completion by default;
            # callers split/display this text, so ask for the completion only.
            "return_full_text": False,
        },
    }
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"Error: request to {model_name} failed - {err}")
        return None
    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None
    try:
        return response.json()[0]["generated_text"]
    except (ValueError, KeyError, IndexError, TypeError):
        # A 200 with an unexpected body (non-JSON, or not a list of dicts).
        st.error(f"Error: unexpected response format - {response.text}")
        return None
def detect_language(text):
    """Return the language code ``langdetect.detect`` reports for *text*.

    Detection is best effort: blank/None input, or any failure inside
    langdetect, falls back to ``"en"``.
    """
    if not text or not text.strip():
        # Nothing to analyse (for instance, a PDF with no extractable text).
        return "en"
    try:
        return detect(text)
    except Exception:  # best effort, but no longer swallows KeyboardInterrupt/SystemExit
        return "en"
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of a PDF and split it into lines.

    No page or line numbers are attached; the result is just pdfminer's text
    output broken on newline characters.

    Args:
        pdf_file: The PDF to read, passed unchanged to pdfminer's
            ``extract_text`` (this app passes the Streamlit upload object).

    Returns:
        list[str]: The extracted text split on newline characters; blank lines
        appear as empty strings.
    """
    full_text = extract_text(pdf_file)
    lines = full_text.split("\n")
    return lines
def split_text_into_chunks(text_lines, chunk_size=500, overlap=0):
    """Group the words of *text_lines* into chunks of at most *chunk_size* words.

    Lines are joined with spaces and re-split on any whitespace, so line
    breaks and runs of blanks never survive into a chunk.

    Args:
        text_lines: Iterable of text lines (e.g. from ``extract_text_from_pdf``).
        chunk_size: Maximum number of words per chunk; must be >= 1.
        overlap: Number of words each chunk shares with the previous one;
            must satisfy ``0 <= overlap < chunk_size``. The default of 0 gives
            disjoint, consecutive chunks.

    Returns:
        list[str]: Space-joined chunks in document order; ``[]`` when the
        input contains no words.

    Raises:
        ValueError: If *chunk_size* or *overlap* is out of range.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")
    words = " ".join(text_lines).split()
    if not words:
        return []
    step = chunk_size - overlap
    # Stop once the remaining words are already covered by the previous chunk's
    # overlap, so no chunk is merely a tail of its predecessor. With overlap=0
    # this is simply len(words).
    last_start = max(len(words) - overlap, 1)
    return [" ".join(words[i:i + chunk_size]) for i in range(0, last_start, step)]
def build_faiss_index(embeddings):
    """Build an exact L2 FAISS index (``IndexFlatL2``) over *embeddings*.

    Args:
        embeddings: 2-D array-like of shape (n_vectors, dim).

    Returns:
        A ``faiss.IndexFlatL2`` of dimension ``dim`` containing every row.

    Raises:
        ValueError: If *embeddings* is not two-dimensional.
    """
    # FAISS only accepts C-contiguous float32 matrices; coerce defensively
    # instead of relying on the encoder's output dtype.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"expected a 2-D embedding matrix, got shape {vectors.shape}")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index
def search_faiss_index(query_embedding, index, top_k=5):
    """Return the nearest stored vectors for the first query in *query_embedding*.

    Args:
        query_embedding: Query vectors of shape (n_queries, dim); a single
            1-D vector is also accepted. Only the first query's results are
            returned.
        index: A FAISS index, e.g. one built by ``build_faiss_index``.
        top_k: Maximum number of neighbours to return.

    Returns:
        ``(indices, distances)`` for the first query, nearest first. Never more
        than the index holds; slots FAISS leaves unfilled (id -1) are dropped,
        so every returned id is a valid row position.
    """
    queries = np.ascontiguousarray(query_embedding, dtype=np.float32)
    if queries.ndim == 1:
        queries = queries.reshape(1, -1)
    # Asking for more neighbours than stored vectors makes FAISS pad the
    # result with -1, which Python list indexing would wrap to the last item.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)
    distances, indices = index.search(queries, k)
    found = indices[0] >= 0
    return indices[0][found], distances[0][found]
def _init_session_state():
    """Create the per-session slots shared by the four pipeline steps, if missing."""
    defaults = {
        "embeddings": None,
        "chunks": [],
        "faiss_index": None,
        "relevant_chunks": [],
        "translated_queries": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _strip_prompt_echo(prompt_text, generated):
    """Remove a leading copy of *prompt_text* from *generated*, if present.

    Text-generation endpoints may echo the prompt in front of the completion;
    stripping it keeps the prompt from being treated as model output.
    """
    if generated.startswith(prompt_text):
        return generated[len(prompt_text):]
    return generated


def _embed_pdf(pdf_file):
    """Step 1: extract, chunk and embed *pdf_file*, then index the chunks."""
    text_lines = extract_text_from_pdf(pdf_file)
    st.session_state.lang = detect_language(" ".join(text_lines))
    st.write(f"**Detected Language:** {st.session_state.lang}")
    st.session_state.chunks = split_text_into_chunks(text_lines)
    # Anything retrieved from a previously embedded document is now stale.
    st.session_state.relevant_chunks = []
    if not st.session_state.chunks:
        # Nothing to embed (for instance, an image-only PDF); clear the old
        # index so later steps cannot search the wrong document.
        st.session_state.faiss_index = None
        st.error("No extractable text was found in this PDF.")
        return
    chunk_embeddings = embedder.encode(st.session_state.chunks, convert_to_tensor=False)
    st.session_state.faiss_index = build_faiss_index(np.array(chunk_embeddings))
    st.success("PDF Embedded Successfully")


def _translate_query(prompt, method, max_new_tokens, temperature, top_k):
    """Step 2: rewrite *prompt* with the selected method into retrieval queries."""
    formatted_prompt = DEFAULT_SYSTEM_PROMPTS[method].format(question=prompt)
    response = query_huggingface_model(formatted_prompt, max_new_tokens, temperature, top_k)
    if response is None:
        # The model call has already shown an error; keep the previous state.
        return
    response = _strip_prompt_echo(formatted_prompt, response)
    # One query per non-blank line; fall back to the raw question when the
    # model returned nothing usable, so retrieval still has something to search.
    queries = [line.strip() for line in response.split("\n") if line.strip()] or [prompt]
    st.session_state.translated_queries = queries
    # Chunks retrieved for the previous set of queries no longer apply.
    st.session_state.relevant_chunks = []
    st.write("**Generated Queries:**")
    st.write(queries)


def _retrieve_documents(retrieval_k):
    """Step 3: fetch the *retrieval_k* nearest chunks for every translated query."""
    index = st.session_state.faiss_index
    chunks = st.session_state.chunks
    if index is None or not chunks:
        st.warning("Please embed a PDF first (step 1).")
        return
    st.session_state.relevant_chunks = []
    for query in st.session_state.translated_queries:
        query_embedding = embedder.encode([query], convert_to_tensor=False)
        top_k_indices, _ = search_faiss_index(np.array(query_embedding), index, top_k=retrieval_k)
        # FAISS marks unfilled result slots with -1, and a negative id would
        # silently wrap to the end of the list, so keep in-range positions only.
        st.session_state.relevant_chunks.append(
            [chunks[i] for i in top_k_indices if 0 <= i < len(chunks)]
        )
    st.write("**Retrieved Documents (for each query):**")
    pairs = zip(st.session_state.translated_queries, st.session_state.relevant_chunks)
    for number, (query, relevant_chunks) in enumerate(pairs, start=1):
        st.write(f"**Query {number}: {query}**")
        for chunk in relevant_chunks:
            st.write(f"{chunk[:100]}...")


def _final_response(prompt, max_new_tokens, temperature, top_k):
    """Step 4: answer *prompt* using the chunks retrieved in step 3."""
    # The same chunk is often retrieved for several queries; include it once,
    # in first-seen order, to keep the prompt short.
    unique_chunks = dict.fromkeys(
        chunk for group in st.session_state.relevant_chunks for chunk in group
    )
    context = "\n".join(unique_chunks)
    # The query-translation template is deliberately NOT included: it tells the
    # model to rewrite/decompose the question, not to answer it.
    llm_input = (
        "You are an AI language model assistant. Answer the question using the "
        "context below.\n\n"
        f"Context: {context}\n\nAnswer this question: {prompt}"
    )
    final_response = query_huggingface_model(llm_input, max_new_tokens, temperature, top_k)
    if final_response is None:
        # The model call has already shown an error.
        return
    st.subheader("Final Response:")
    st.write(_strip_prompt_echo(llm_input, final_response))


def main():
    """Render the four-step RAG app: embed PDF, translate query, retrieve, answer."""
    st.title("Enhanced RAG Model with FAISS Indexing")
    _init_session_state()

    # Sidebar options
    st.sidebar.header("Upload PDF")
    pdf_file = st.sidebar.file_uploader("Upload a PDF file", type="pdf")
    st.sidebar.header("Query Translation")
    query_translation = st.sidebar.selectbox(
        "Select Query Translation Method",
        list(DEFAULT_SYSTEM_PROMPTS),
    )
    st.sidebar.header("Similarity Search")
    similarity_method = st.sidebar.selectbox(
        "Select Similarity Search Method", ["Cosine Similarity", "KNN"]
    )
    # Both methods search the same L2 FAISS index; KNN only makes the number
    # of chunks retrieved per query configurable (5 otherwise).
    retrieval_k = 5
    if similarity_method == "KNN":
        retrieval_k = st.sidebar.slider("Select K Value", 1, 10, 5)

    # LLM parameters
    max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
    temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
    top_k = st.sidebar.slider("Top K", 1, 100, 50)

    prompt = st.text_input("Enter your query:")

    if st.button("1. Embed PDF"):
        if pdf_file is None:
            st.warning("Please upload a PDF file first.")
        else:
            _embed_pdf(pdf_file)

    if st.button("2. Query Translation"):
        if not prompt:
            st.warning("Please enter a query first.")
        else:
            _translate_query(prompt, query_translation, max_new_tokens, temperature, top_k)

    if st.button("3. Retrieve Documents"):
        if not st.session_state.translated_queries:
            st.warning("Please run query translation first (step 2).")
        else:
            _retrieve_documents(retrieval_k)

    if st.button("4. Final Response"):
        if not st.session_state.relevant_chunks:
            st.warning("Please retrieve documents first (step 3).")
        elif not prompt:
            st.warning("Please enter a query first.")
        else:
            _final_response(prompt, max_new_tokens, temperature, top_k)
if __name__ == "__main__":
    # Script entry point: build and render the Streamlit UI.
    main()