import streamlit as st
import boto3
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_aws import BedrockEmbeddings
# --- CHANGED: Import Qdrant instead of Chroma ---
from langchain_qdrant import Qdrant
# --- Optional: If you need direct Qdrant client interaction or for advanced setups ---
# from qdrant_client import QdrantClient, models
from langchain_aws import ChatBedrock
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
import os
from dotenv import load_dotenv

# --- Load Environment Variables ---
load_dotenv()  # Loads variables from the .env file

# --- Streamlit UI Setup (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(
    page_title="Math Research Paper RAG Bot",
    page_icon="π",
    layout="wide"
)
st.title("π Math Research Paper RAG Chatbot")
st.markdown(
    """
Upload a mathematical research paper (PDF) and ask questions about its content.
This bot uses Amazon Bedrock (Claude 3 Sonnet for reasoning, Titan Embeddings for vectors)
and **QdrantDB** for Retrieval-Augmented Generation.

**Note:** This application requires AWS credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
and a region (`AWS_REGION`) to be set in a `.env` file or as environment variables.
The Qdrant vector store is **in-memory** and will be reset on app restart.
"""
)
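
# The intro above mentions AWS credentials loaded from a .env file. A minimal
# example of that file is sketched below (placeholder values only; the region
# shown is an assumption, use whichever region has Bedrock model access enabled):
#
#   AWS_ACCESS_KEY_ID=<your-access-key-id>
#   AWS_SECRET_ACCESS_KEY=<your-secret-access-key>
#   AWS_REGION=us-east-1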

# --- Configuration ---
# Set AWS region (adjust if needed, loaded from .env or env var)
AWS_REGION = os.getenv("AWS_REGION")
if not AWS_REGION:
    st.error("AWS_REGION not found in environment variables or .env file. Please set it.")
    st.stop()

# Bedrock model IDs
EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v1"
LLM_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"

# --- Qdrant Specific Configuration ---
QDRANT_COLLECTION_NAME = "math_research_papers_collection"
EMBEDDING_DIMENSION = 1536  # Titan Text Embeddings output 1536-dimensional vectors
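
# --- Optional: manual Qdrant collection setup (hedged sketch, not called by this app) ---
# LangChain's Qdrant wrapper creates the collection automatically (see
# create_vector_store below), but if you point the app at a persistent Qdrant
# server you may prefer to create the collection yourself with qdrant_client.
# The URL below is an assumption for a locally running server.
#
# from qdrant_client import QdrantClient, models
#
# def ensure_collection(url: str = "http://localhost:6333") -> None:
#     client = QdrantClient(url=url)
#     if not client.collection_exists(QDRANT_COLLECTION_NAME):
#         client.create_collection(
#             collection_name=QDRANT_COLLECTION_NAME,
#             vectors_config=models.VectorParams(
#                 size=EMBEDDING_DIMENSION,         # 1536-dimensional Titan vectors
#                 distance=models.Distance.COSINE,  # cosine distance for text embeddings
#             ),
#         )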


# --- Initialize Bedrock Client (once) ---
@st.cache_resource
def get_bedrock_client():
    """Initializes and returns a boto3 Bedrock runtime client.

    Returns:
        Tuple of (boto3_client or None, success_bool, error_message_str or None).
    """
    try:
        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=AWS_REGION
        )
        return client, True, None  # Success: client, True, no error message
    except Exception as e:
        return None, False, str(e)  # Failure: None, False, error message

# Get the client and check its status
bedrock_client, bedrock_success, bedrock_error_msg = get_bedrock_client()

if not bedrock_success:
    st.error(
        "Error connecting to AWS Bedrock. Please check your AWS credentials "
        "(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and region (AWS_REGION) in your "
        f".env file or environment variables. Error: {bedrock_error_msg}"
    )
    st.stop()  # Stop execution if Bedrock client cannot be initialized
else:
    st.success(f"Successfully connected to AWS Bedrock in {AWS_REGION}!")


# --- LangChain Components ---
@st.cache_resource
def get_embeddings_model(_client):  # Prepend underscore to tell Streamlit not to hash
    """Returns the BedrockEmbeddings model."""
    return BedrockEmbeddings(client=_client, model_id=EMBEDDING_MODEL_ID)


@st.cache_resource
def get_llm_model(_client):  # Prepend underscore to tell Streamlit not to hash
    """Returns the Bedrock LLM model for Claude 3 Sonnet."""
    return ChatBedrock(
        client=_client,
        model_id=LLM_MODEL_ID,
        streaming=False,
        temperature=0.1,
        model_kwargs={"max_tokens": 4000},
    )


# --- PDF Processing and Vector Store Creation ---
def create_vector_store(pdf_file_path):
    """
    Loads a PDF, chunks it contextually for mathematical papers,
    creates embeddings, and stores them in QdrantDB (in-memory).
    """
    with st.spinner("Loading PDF and creating vector store..."):
        # 1. Load PDF (PyPDFLoader returns one Document per page)
        loader = PyPDFLoader(pdf_file_path)
        pages = loader.load()
        st.info(f"Loaded {len(pages)} pages from the PDF.")

        # 2. Contextual chunking for mathematical papers
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1500,    # Increased chunk size for math papers
            chunk_overlap=150,  # Generous overlap to maintain context
            separators=[
                "\n\n",  # Prefer splitting by paragraphs
                "\n",    # Then by newlines (might break equations, but less likely than fixed-width splits)
                " ",     # Then by spaces
                "",      # Fallback
            ],
            length_function=len,
            is_separator_regex=False,
        )
        chunks = text_splitter.split_documents(pages)
        st.info(f"Split PDF into {len(chunks)} chunks.")

        # 3. Create embeddings and the Qdrant vector store
        embeddings = get_embeddings_model(bedrock_client)
        # --- CHANGED: Qdrant vector store creation ---
        vector_store = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            location=":memory:",  # Use in-memory Qdrant instance
            collection_name=QDRANT_COLLECTION_NAME,
            # For persistent Qdrant (requires a running Qdrant server):
            # url="http://localhost:6333",          # Or your Qdrant Cloud URL
            # api_key="YOUR_QDRANT_CLOUD_API_KEY",  # Only for Qdrant Cloud
            # prefer_grpc=True,                     # Set to True if using gRPC for Qdrant Cloud
            # force_recreate=True,                  # Use with caution: deletes existing collection
        )
        # Note: LangChain's Qdrant integration automatically creates the collection
        # if it doesn't exist, inferring the vector size from the embeddings.

        st.success("Vector store created and ready!")
        return vector_store


# --- RAG Chain Construction ---
def get_rag_chain(vector_store):
    """Constructs the RAG chain using LCEL."""
    retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Retrieve top 5 relevant chunks
    llm = get_llm_model(bedrock_client)

    # Prompt template optimized for mathematical research papers
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system",
             "You are an expert AI assistant specialized in analyzing and explaining mathematical research papers. "
             "Your goal is to provide precise, accurate, and concise answers based *only* on the provided context from the research paper. "
             "When answering, focus on definitions, theorems, proofs, key mathematical concepts, and experimental results. "
             "If the user asks about a mathematical notation, try to explain its meaning from the context. "
             "If the answer is not found in the context, explicitly state that you cannot find the information within the provided document. "
             "Do not invent information or make assumptions outside the given text.\n\n"
             "Context:\n{context}"),
            ("user", "{question}"),
        ]
    )

    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | llm
        | StrOutputParser()
    )
    return rag_chain
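

# Note: the retriever's output (a list of Documents) is passed to the prompt's
# {context} slot as-is, so the model sees the Documents' default string form.
# If you want a cleaner context, one possible variation (a sketch, not what
# this app does) is to pipe the retrieved documents through a small formatter:
#
# def format_docs(docs):
#     return "\n\n".join(doc.page_content for doc in docs)
#
# rag_chain = (
#     {"context": retriever | format_docs, "question": RunnablePassthrough()}
#     | prompt_template
#     | llm
#     | StrOutputParser()
# )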


# --- Streamlit UI Main Logic ---
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize vector store and RAG chain
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
if "rag_chain" not in st.session_state:
    st.session_state.rag_chain = None
if "pdf_uploaded" not in st.session_state:
    st.session_state.pdf_uploaded = False

# Sidebar for PDF Upload
with st.sidebar:
    st.header("Upload PDF")
    uploaded_file = st.file_uploader(
        "Choose a PDF file",
        type="pdf",
        accept_multiple_files=False,
        key="pdf_uploader"
    )

    if uploaded_file and not st.session_state.pdf_uploaded:
        # Save the uploaded file temporarily
        with open("temp_doc.pdf", "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.session_state.vector_store = create_vector_store("temp_doc.pdf")
        st.session_state.rag_chain = get_rag_chain(st.session_state.vector_store)
        st.session_state.pdf_uploaded = True
        st.success("PDF processed successfully! You can now ask questions.")

        # Clean up temporary file
        os.remove("temp_doc.pdf")
    elif st.session_state.pdf_uploaded:
        st.info("PDF already processed. Ready for questions!")

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask a question about the paper..."):
    if not st.session_state.pdf_uploaded:
        st.warning("Please upload a PDF first to start asking questions.")
    else:
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get response from RAG chain
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    full_response = st.session_state.rag_chain.invoke(prompt)
                    st.markdown(full_response, unsafe_allow_html=True)
                    # Add assistant response to chat history
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                except Exception as e:
                    st.error(f"An error occurred during response generation: {e}")
                    st.warning("Please try again or check your AWS Bedrock access permissions.")

# Optional: Clear chat and uploaded PDF
if st.session_state.pdf_uploaded:
    if st.sidebar.button("Clear Chat and Upload New PDF"):
        st.session_state.messages = []
        st.session_state.vector_store = None
        st.session_state.rag_chain = None
        st.session_state.pdf_uploaded = False
        st.cache_resource.clear()  # Clear Streamlit caches for a clean slate
        st.rerun()