AllAboutRAG / app.py
bainskarman's picture
Update app.py
9d72b0b verified
raw
history blame
4.83 kB
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
import torch
from transformers import pipeline
from langdetect import detect
# Load a smaller LLM (e.g., Zephyr-7B or Mistral-7B)
@st.cache_resource
def load_llm(model_name="HuggingFaceH4/zephyr-7b-alpha", max_new_tokens=256):
    """Build a LangChain LLM backed by a local Hugging Face text-generation pipeline.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the model weights would be reloaded on each
    rerun. The cache keeps one loaded pipeline per distinct argument set.

    Args:
        model_name: Hugging Face Hub id of the model to load. Replace it with
            your preferred model.
        max_new_tokens: upper bound on tokens generated per answer. Without an
            explicit limit, generation falls back to the model/pipeline default
            length, which is often too short for a full answer.

    Returns:
        A ``HuggingFacePipeline`` wrapping the text-generation pipeline.
    """
    pipe = pipeline(
        "text-generation",
        model=model_name,
        torch_dtype=torch.float16,  # 2 bytes per weight instead of 4 (float32)
        device_map="auto",  # let accelerate place weights on the available devices
        max_new_tokens=max_new_tokens,
    )
    return HuggingFacePipeline(pipeline=pipe)
# Extract text from PDF
def extract_text_from_pdf(file):
    """Return the text of every page of *file*, one page per line group.

    Args:
        file: anything ``PdfReader`` accepts. ``main`` passes either the
            path ``"default.pdf"`` or the object returned by
            ``st.file_uploader``.

    Returns:
        The page texts joined with newlines. Pages without extractable text
        (e.g. scanned images) contribute an empty string.
    """
    reader = PdfReader(file)
    # `or ""` guards against a falsy/None result for pages without a text
    # layer. Joining with "\n" stops the last word of one page from fusing
    # with the first word of the next. str.join also avoids repeated `+=`
    # string copies.
    return "\n".join(page.extract_text() or "" for page in reader.pages)
# Split text into chunks
def split_text(text, chunk_size=1000, chunk_overlap=200):
    """Cut *text* into overlapping chunks suitable for embedding.

    Args:
        text: the full document text.
        chunk_size: maximum chunk length handed to the splitter.
        chunk_overlap: amount of text shared by consecutive chunks.

    Returns:
        The list of chunk strings produced by ``RecursiveCharacterTextSplitter``.
    """
    return RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ).split_text(text)
# Create embeddings and vector store
def create_vector_store(chunks, indexing_method="multi-representation",
                        embedding_model="all-MiniLM-L6-v2"):
    """Embed *chunks* and index them in a FAISS vector store.

    Args:
        chunks: list of text chunks to index. Must be non-empty.
        indexing_method: one of "multi-representation", "raptors" or
            "colbert". The name is validated, but every method currently builds
            the same flat FAISS index (see the TODO below).
        embedding_model: sentence-embedding model name passed to
            ``HuggingFaceEmbeddings``.

    Returns:
        A ``FAISS`` vector store over *chunks*.

    Raises:
        ValueError: if *indexing_method* is unsupported or *chunks* is empty.
    """
    supported = ("multi-representation", "raptors", "colbert")
    # Previously an unknown name fell through every branch and crashed with
    # UnboundLocalError on the return.
    if indexing_method not in supported:
        raise ValueError(
            f"Unknown indexing_method {indexing_method!r}; "
            f"expected one of: {', '.join(supported)}"
        )
    if not chunks:
        raise ValueError("Cannot build a vector store from zero text chunks.")

    # TODO: implement RAPTOR (hierarchical chunking) and ColBERT (contextualized
    # embeddings). All methods currently share this flat index.
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    return FAISS.from_texts(chunks, embeddings)
# Query the PDF
def query_pdf(vector_store, query, llm, query_method="multi-query"):
    """Answer *query* about the indexed PDF with retrieval-augmented generation.

    Args:
        vector_store: LangChain vector store holding the PDF chunks. Only
            ``as_retriever()`` is used.
        query: the user's question.
        llm: LangChain LLM that writes the answer.
        query_method: one of "multi-query", "rag-fusion", "decomposition",
            "step-back" or "hyde". The name is validated, but all methods
            currently share the same plain retrieval chain.

    Returns:
        dict with "answer" (the LLM's answer) and "source_text" (the retrieved
        chunks joined by blank lines). These are the keys ``main`` reads.

    Raises:
        ValueError: if *query_method* is not a supported name.
    """
    supported = ("multi-query", "rag-fusion", "decomposition", "step-back", "hyde")
    # Previously an unknown name left `qa` unbound, causing UnboundLocalError.
    if query_method not in supported:
        raise ValueError(
            f"Unknown query_method {query_method!r}; "
            f"expected one of: {', '.join(supported)}"
        )

    # TODO: implement Multi-Query, RAG Fusion, Decomposition, Step Back and
    # HyDE. Every method currently builds this identical "stuff" chain.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
        # Keep the retrieved chunks so they can be shown next to the answer.
        return_source_documents=True,
    )
    # .run() only supports chains with a single output key. With source
    # documents enabled, the chain returns a dict, so use invoke().
    output = qa.invoke({"query": query})
    sources = output.get("source_documents", [])
    return {
        "answer": output["result"],
        "source_text": "\n\n".join(doc.page_content for doc in sources),
    }
# Detect language of the text
def detect_language(text, default="en"):
    """Return the language code langdetect assigns to *text* (e.g. "en").

    Detection is best-effort: blank input, or input langdetect cannot
    classify, yields *default* instead of an error.

    Args:
        text: the text to classify.
        default: code returned when detection is impossible. The default is
            English, as before.

    Returns:
        A language code string.
    """
    if not text or not text.strip():
        # Nothing to analyse, so skip the detector entirely.
        return default
    try:
        # NOTE: langdetect is non-deterministic unless DetectorFactory.seed is
        # set, so ambiguous or short text may be classified differently
        # between runs.
        return detect(text)
    except Exception:
        # langdetect raises LangDetectException on featureless input (e.g.
        # only digits or punctuation). Catching Exception rather than using a
        # bare `except:` lets KeyboardInterrupt and SystemExit propagate.
        return default
# Streamlit App
def main():
    """Render the "Chat with PDF" page.

    Lets the user upload a PDF (falling back to ``default.pdf``), shows its
    detected language, and exposes chunking, indexing and query-method
    controls. It then answers a typed question with retrieval-augmented
    generation.
    """
    st.title("Chat with PDF")
    st.write("Upload a PDF and ask questions about it!")

    # File uploader
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded_file is None:
        st.info("Using default PDF.")
        uploaded_file = "default.pdf"  # Add a default PDF next to this script

    # Extract text. A missing default.pdf would otherwise surface as a raw
    # traceback instead of an actionable message.
    try:
        text = extract_text_from_pdf(uploaded_file)
    except FileNotFoundError:
        st.error("No PDF uploaded and default.pdf was not found. Please upload a PDF.")
        return
    # Image-only PDFs have no text layer, so there is nothing to index or
    # answer from.
    if not text or not text.strip():
        st.warning("No extractable text found in this PDF (is it a scanned image?).")
        return

    # Detect language
    language = detect_language(text)
    st.write(f"Detected Language: {language}")

    # Split text into chunks
    chunk_size = st.slider("Chunk Size", 500, 2000, 1000)
    chunk_overlap = st.slider("Chunk Overlap", 0, 500, 200)
    chunks = split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # Indexing options
    indexing_method = st.selectbox(
        "Indexing Method",
        ["multi-representation", "raptors", "colbert"],
        help="Choose how to index the PDF text.",
    )
    st.write(f"**Indexing Method:** {indexing_method}")

    # Query translation options
    query_method = st.selectbox(
        "Query Translation Method",
        ["multi-query", "rag-fusion", "decomposition", "step-back", "hyde"],
        help="Choose a method to improve query retrieval.",
    )
    st.write(f"**Query Translation Method:** {query_method}")

    # User input
    query = st.text_input("Ask a question about the PDF:")
    if not query:
        return

    # Streamlit reruns the script on every widget change. The costly steps
    # (embedding every chunk, loading the LLM) therefore wait until a
    # question has actually been asked, instead of running on each slider move.
    with st.spinner("Indexing the PDF and generating an answer..."):
        vector_store = create_vector_store(chunks, indexing_method=indexing_method)
        llm = load_llm()
        result = query_pdf(vector_store, query, llm, query_method=query_method)

    # Accept either a bare answer string or a mapping with "answer" and
    # "source_text" entries, and render whichever arrived.
    if isinstance(result, dict):
        st.write("**Answer:**", result.get("answer", ""))
        st.write("**Source Text:**", result.get("source_text", ""))
    else:
        st.write("**Answer:**", result)
if __name__ == "__main__":
    # Render the app when this file runs as the main script
    # (which is how `streamlit run app.py` executes it).
    main()