# AllAboutRAG / app.py
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
import torch
from transformers import pipeline
from langdetect import detect

# Load a smaller LLM with customizable parameters
def load_llm(temperature, top_k, max_new_tokens):
    model_name = "HuggingFaceH4/zephyr-7b-alpha"  # Replace with your preferred model
    pipe = pipeline(
        "text-generation",
        model=model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        temperature=temperature,
        top_k=top_k,
        max_new_tokens=max_new_tokens,  # Use max_new_tokens instead of max_length
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
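
# Example usage (a sketch, not called by the app): the object returned by load_llm
# behaves like any LangChain LLM and can be invoked directly. Note that loading
# zephyr-7b-alpha in float16 assumes a GPU with enough memory; on CPU-only hardware
# a smaller model would be needed.
#   llm = load_llm(temperature=0.7, top_k=50, max_new_tokens=200)
#   print(llm("Summarize retrieval-augmented generation in one sentence."))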

# Extract text from PDF
def extract_text_from_pdf(file):
    reader = PdfReader(file)
    text = ""
    for page in reader.pages:
        text += page.extract_text() or ""  # extract_text() may return None for image-only pages
    return text

# Split text into chunks
def split_text(text, chunk_size=1000, chunk_overlap=200):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_text(text)
    return chunks

# Create embeddings and vector store
def create_vector_store(chunks, indexing_method="multi-representation", **kwargs):
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    if indexing_method == "multi-representation":
        vector_store = FAISS.from_texts(chunks, embeddings)
    elif indexing_method == "raptors":
        # Implement RAPTORS logic here (e.g., hierarchical chunking)
        vector_store = FAISS.from_texts(chunks, embeddings)
    elif indexing_method == "colbert":
        # Implement ColBERT logic here (e.g., contextualized embeddings)
        vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store
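
# A minimal sketch (not wired into the app) of how the "raptors" branch above could be
# fleshed out with hierarchical chunking: large parent chunks are split into small child
# chunks, the children are embedded for retrieval, and each child's metadata points back
# to its parent so the larger context can be recovered. The chunk sizes and the
# "parent_id" metadata key are illustrative assumptions, not part of the original app.
def create_hierarchical_vector_store(text, parent_size=2000, child_size=400):
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_size, chunk_overlap=200)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_size, chunk_overlap=50)
    parents = parent_splitter.split_text(text)
    children, metadatas = [], []
    for parent_id, parent in enumerate(parents):
        for child in child_splitter.split_text(parent):
            children.append(child)
            metadatas.append({"parent_id": parent_id})
    # Search runs over the small chunks; the parent text can be looked up afterwards
    # via parents[doc.metadata["parent_id"]].
    vector_store = FAISS.from_texts(children, embeddings, metadatas=metadatas)
    return vector_store, parents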

# Query the PDF
def query_pdf(vector_store, query, llm, query_method="multi-query", **kwargs):
    retriever = vector_store.as_retriever()
    if query_method == "multi-query":
        # Implement Multi-Query logic here
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    elif query_method == "rag-fusion":
        # Implement RAG Fusion logic here
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    elif query_method == "decomposition":
        # Implement Decomposition logic here
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    elif query_method == "step-back":
        # Implement Step Back logic here
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    elif query_method == "hyde":
        # Implement HyDE logic here
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    # qa.run() only supports single-output chains; calling the chain directly returns
    # both the answer ("result") and the retrieved chunks ("source_documents").
    result = qa({"query": query})
    return result
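
# A minimal sketch (not wired into the UI) of one way the "multi-query" branch of
# query_pdf could be implemented by hand: the LLM rewrites the question a few times,
# documents are retrieved for each rewrite, deduplicated, and the question is answered
# over the combined context. The prompt wording, n_rewrites, and k are illustrative
# assumptions, not part of the original app.
def multi_query_answer(vector_store, query, llm, n_rewrites=3, k=3):
    rewrite_prompt = (
        f"Rewrite the following question in {n_rewrites} different ways, one per line:\n{query}"
    )
    rewrites = [q.strip() for q in llm(rewrite_prompt).split("\n") if q.strip()]
    # Retrieve for the original question plus each rewrite, de-duplicating by content
    seen, docs = set(), []
    for q in [query] + rewrites[:n_rewrites]:
        for doc in vector_store.similarity_search(q, k=k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                docs.append(doc)
    context = "\n\n".join(doc.page_content for doc in docs)
    answer_prompt = (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )
    return llm(answer_prompt)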

# Detect language of the text
def detect_language(text):
    try:
        return detect(text)
    except Exception:
        return "en"  # Default to English if detection fails

# Streamlit App
def main():
    st.title("Chat with PDF")
    st.write("Upload a PDF and ask questions about it!")

    # File uploader
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded_file is None:
        st.info("Using default PDF.")
        uploaded_file = "default.pdf"  # Add a default PDF

    # Step 1: Extract text and split into chunks
    if "text" not in st.session_state:
        st.session_state.text = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None
    if st.button("Extract Text and Split into Chunks"):
        st.session_state.text = extract_text_from_pdf(uploaded_file)
        st.session_state.chunks = split_text(st.session_state.text)
        st.success("Text extracted and split into chunks!")

    # Step 2: Create vector store
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None
    if st.session_state.chunks:
        st.subheader("Indexing Options")
        indexing_method = st.selectbox(
            "Indexing Method",
            ["multi-representation", "raptors", "colbert"],
            help="Choose how to index the PDF text."
        )
        if st.button("Create Vector Store"):
            st.session_state.vector_store = create_vector_store(st.session_state.chunks, indexing_method=indexing_method)
            st.success("Vector store created!")

    # Step 3: Load LLM with user-defined parameters
    if "llm" not in st.session_state:
        st.session_state.llm = None
    if st.session_state.vector_store:
        st.subheader("LLM Parameters")
        temperature = st.slider("Temperature", 0.1, 1.0, 0.7, help="Controls randomness in the output.")
        top_k = st.slider("Top-k", 1, 100, 50, help="Limits sampling to the top-k tokens.")
        max_new_tokens = st.slider("Max New Tokens", 50, 500, 200, help="Maximum number of tokens to generate.")
        if st.button("Load LLM"):
            st.session_state.llm = load_llm(temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens)
            st.success("LLM loaded!")

    # Step 4: Query the PDF
    if st.session_state.llm:
        st.subheader("Query Translation Options")
        query_method = st.selectbox(
            "Query Translation Method",
            ["multi-query", "rag-fusion", "decomposition", "step-back", "hyde"],
            help="Choose a method to improve query retrieval."
        )
        query = st.text_input("Ask a question about the PDF:")
        if query:
            result = query_pdf(st.session_state.vector_store, query, st.session_state.llm, query_method=query_method)
            # query_pdf returns the RetrievalQA output dict: "result" holds the answer,
            # "source_documents" the retrieved chunks.
            st.write("**Answer:**", result["result"])
            st.write("**Source Text:**")
            for doc in result["source_documents"]:
                st.write(doc.page_content)

if __name__ == "__main__":
    main()