import streamlit as st
import boto3
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_aws import BedrockEmbeddings
# --- CHANGED: Import Qdrant instead of Chroma ---
from langchain_qdrant import Qdrant
# --- Optional: If you need direct Qdrant client interaction or for advanced setups ---
# from qdrant_client import QdrantClient, models
from langchain_aws import ChatBedrock
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
import os
from dotenv import load_dotenv  # Import load_dotenv

# --- Load Environment Variables ---
load_dotenv()  # This loads variables from the .env file
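# A .env file in the project root could look like the following (placeholder values
# for illustration only; supply your own credentials and region):
#
#     AWS_ACCESS_KEY_ID=your-access-key-id
#     AWS_SECRET_ACCESS_KEY=your-secret-access-key
#     AWS_REGION=us-east-1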
# --- Streamlit UI Setup (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(
    page_title="Math Research Paper RAG Bot",
    page_icon="π",
    layout="wide"
)

st.title("π Math Research Paper RAG Chatbot")
st.markdown(
    """
Upload a mathematical research paper (PDF) and ask questions about its content.
This bot uses Amazon Bedrock (Claude 3 Sonnet for reasoning, Titan Embeddings for vectors)
and **Qdrant** for Retrieval-Augmented Generation.

**Note:** This application requires AWS credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
and a region (`AWS_REGION`) to be set in a `.env` file or as environment variables.
The Qdrant vector store is **in-memory** and will be reset when the app restarts.
"""
)
# --- Configuration ---
# Set AWS region (adjust if needed, loaded from .env or an environment variable)
AWS_REGION = os.getenv("AWS_REGION")
if not AWS_REGION:
    st.error("AWS_REGION not found in environment variables or .env file. Please set it.")
    st.stop()

# Bedrock model IDs
EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v1"
LLM_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"

# --- Qdrant Specific Configuration ---
QDRANT_COLLECTION_NAME = "math_research_papers_collection"
EMBEDDING_DIMENSION = 1536  # Titan Text Embeddings output 1536-dimensional vectors
# --- Initialize Bedrock Client (once) ---
@st.cache_resource  # Cache the client so it is created once per session (cleared by the reset button below)
def get_bedrock_client():
    """Initializes and returns a boto3 Bedrock client.

    Returns: Tuple (boto3_client, success_bool, error_message_str or None)
    """
    try:
        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=AWS_REGION
        )
        return client, True, None  # Success: client, True, no error message
    except Exception as e:
        return None, False, str(e)  # Failure: None, False, error message


# Get the client and check its status
bedrock_client, bedrock_success, bedrock_error_msg = get_bedrock_client()
if not bedrock_success:
    st.error(f"Error connecting to AWS Bedrock. Please check your AWS credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and region (AWS_REGION) in your .env file or environment variables. Error: {bedrock_error_msg}")
    st.stop()  # Stop execution if the Bedrock client cannot be initialized
else:
    st.success(f"Successfully connected to AWS Bedrock in {AWS_REGION}!")
# --- LangChain Components ---
@st.cache_resource
def get_embeddings_model(_client):  # Underscore prefix tells Streamlit not to hash this argument
    """Returns the BedrockEmbeddings model."""
    return BedrockEmbeddings(client=_client, model_id=EMBEDDING_MODEL_ID)


@st.cache_resource
def get_llm_model(_client):  # Underscore prefix tells Streamlit not to hash this argument
    """Returns the Bedrock LLM model for Claude 3 Sonnet."""
    return ChatBedrock(
        client=_client,
        model_id=LLM_MODEL_ID,
        streaming=False,
        temperature=0.1,
        model_kwargs={"max_tokens": 4000}
    )
# --- PDF Processing and Vector Store Creation ---
def create_vector_store(pdf_file_path):
    """
    Loads a PDF, chunks it contextually for mathematical papers,
    creates embeddings, and stores them in Qdrant (in-memory).
    """
    with st.spinner("Loading PDF and creating vector store..."):
        # 1. Load PDF (one Document per page; chunking is done explicitly below)
        loader = PyPDFLoader(pdf_file_path)
        pages = loader.load()
        st.info(f"Loaded {len(pages)} pages from the PDF.")

        # 2. Contextual Chunking for Mathematical Papers
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1500,    # Increased chunk size for math papers
            chunk_overlap=150,  # Generous overlap to maintain context
            separators=[
                "\n\n",  # Prefer splitting by paragraphs
                "\n",    # Then by newlines (might break equations, but less likely than fixed-width splits)
                " ",     # Then by spaces
                "",      # Fallback
            ],
            length_function=len,
            is_separator_regex=False,
        )
        chunks = text_splitter.split_documents(pages)
        st.info(f"Split PDF into {len(chunks)} chunks.")

        # 3. Create Embeddings and the Qdrant Vector Store
        embeddings = get_embeddings_model(bedrock_client)
        # --- CHANGED: Qdrant vector store creation ---
        vector_store = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            location=":memory:",  # Use an in-memory Qdrant instance
            collection_name=QDRANT_COLLECTION_NAME,
            # For persistent Qdrant (requires a running Qdrant server):
            # url="http://localhost:6333",  # Or your Qdrant Cloud URL
            # api_key="YOUR_QDRANT_CLOUD_API_KEY",  # Only for Qdrant Cloud
            # prefer_grpc=True,  # Set to True if using gRPC for Qdrant Cloud
            # force_recreate=True,  # Use with caution: deletes the existing collection
        )
        # Note: LangChain's Qdrant integration will automatically create the collection
        # if it doesn't exist, inferring the vector size from the embeddings.
        st.success("Vector store created and ready!")
        return vector_store
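
# --- Optional (a minimal sketch, assuming a persistent Qdrant server at the default
# --- local URL; the helper name is illustrative and not used by the app): create the
# --- collection explicitly with the raw client instead of letting Qdrant.from_documents
# --- infer it. This is where EMBEDDING_DIMENSION from the configuration would be used.
# from qdrant_client import QdrantClient, models
#
# def create_qdrant_collection(url="http://localhost:6333"):
#     qdrant_client = QdrantClient(url=url)
#     qdrant_client.recreate_collection(
#         collection_name=QDRANT_COLLECTION_NAME,
#         vectors_config=models.VectorParams(
#             size=EMBEDDING_DIMENSION,         # 1536 for amazon.titan-embed-text-v1
#             distance=models.Distance.COSINE,  # cosine distance is typical for text embeddings
#         ),
#     )
#     return qdrant_client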
# --- RAG Chain Construction ---
def get_rag_chain(vector_store):
    """Constructs the RAG chain using LCEL."""
    retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Retrieve the top 5 relevant chunks
    llm = get_llm_model(bedrock_client)

    # Prompt template optimized for mathematical research papers
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system",
             "You are an expert AI assistant specialized in analyzing and explaining mathematical research papers. "
             "Your goal is to provide precise, accurate, and concise answers based *only* on the provided context from the research paper. "
             "When answering, focus on definitions, theorems, proofs, key mathematical concepts, and experimental results. "
             "If the user asks about a mathematical notation, try to explain its meaning from the context. "
             "If the answer is not found in the context, explicitly state that you cannot find the information within the provided document. "
             "Do not invent information or make assumptions outside the given text.\n\n"
             "Context:\n{context}"),
            ("user", "{question}"),
        ]
    )

    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | llm
        | StrOutputParser()
    )
    return rag_chain
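
# --- Optional debugging helper (a minimal sketch; the function name is illustrative
# --- and nothing in the UI below calls it): inspect which chunks the retriever returns
# --- for a question before they are passed to the LLM.
# def preview_retrieved_chunks(vector_store, question, k=5):
#     retriever = vector_store.as_retriever(search_kwargs={"k": k})
#     docs = retriever.invoke(question)  # list of LangChain Document objects
#     for i, doc in enumerate(docs, start=1):
#         page = doc.metadata.get("page", "?")
#         print(f"[{i}] page {page}: {doc.page_content[:200]}...")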
# --- Streamlit UI Main Logic ---
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize vector store and RAG chain
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
if "rag_chain" not in st.session_state:
    st.session_state.rag_chain = None
if "pdf_uploaded" not in st.session_state:
    st.session_state.pdf_uploaded = False

# Sidebar for PDF upload
with st.sidebar:
    st.header("Upload PDF")
    uploaded_file = st.file_uploader(
        "Choose a PDF file",
        type="pdf",
        accept_multiple_files=False,
        key="pdf_uploader"
    )

    if uploaded_file and not st.session_state.pdf_uploaded:
        # Save the uploaded file temporarily
        with open("temp_doc.pdf", "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.session_state.vector_store = create_vector_store("temp_doc.pdf")
        st.session_state.rag_chain = get_rag_chain(st.session_state.vector_store)
        st.session_state.pdf_uploaded = True
        st.success("PDF processed successfully! You can now ask questions.")

        # Clean up the temporary file
        os.remove("temp_doc.pdf")
    elif st.session_state.pdf_uploaded:
        st.info("PDF already processed. Ready for questions!")
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask a question about the paper..."):
    if not st.session_state.pdf_uploaded:
        st.warning("Please upload a PDF first to start asking questions.")
    else:
        # Add the user message to the chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get the response from the RAG chain
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    full_response = st.session_state.rag_chain.invoke(prompt)
                    st.markdown(full_response, unsafe_allow_html=True)
                    # Add the assistant response to the chat history
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                except Exception as e:
                    st.error(f"An error occurred during response generation: {e}")
                    st.warning("Please try again or check your AWS Bedrock access permissions.")
# Optional: Clear chat and uploaded PDF
if st.session_state.pdf_uploaded:
    if st.sidebar.button("Clear Chat and Upload New PDF"):
        st.session_state.messages = []
        st.session_state.vector_store = None
        st.session_state.rag_chain = None
        st.session_state.pdf_uploaded = False
        st.cache_resource.clear()  # Clear Streamlit resource caches for a clean slate
        st.rerun()
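
# --- To run the app locally (assuming this file is saved as app.py): install the
# --- dependencies (streamlit, boto3, python-dotenv, pypdf, langchain, langchain-aws,
# --- langchain-community, langchain-text-splitters, langchain-qdrant, qdrant-client)
# --- and start Streamlit with:
# ---     streamlit run app.py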