import streamlit as st
import boto3
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_aws import BedrockEmbeddings
# --- CHANGED: Import Qdrant instead of Chroma ---
from langchain_qdrant import Qdrant 
# --- Optional: If you need direct Qdrant client interaction or for advanced setups ---
# from qdrant_client import QdrantClient, models 

from langchain_aws import ChatBedrock
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import os
from dotenv import load_dotenv # Import load_dotenv

# --- Load Environment Variables ---
load_dotenv() # This loads variables from .env file
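# boto3 picks up AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY from the environment,
# so loading the .env file here is enough; no credentials are passed explicitly below.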

# --- Streamlit UI Setup (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(
    page_title="Math Research Paper RAG Bot",
    page_icon="πŸ“š",
    layout="wide"
)

st.title("πŸ“š Math Research Paper RAG Chatbot")
st.markdown(
    """
    Upload a mathematical research paper (PDF) and ask questions about its content. 
    This bot uses Amazon Bedrock (Claude 3 Sonnet for reasoning, Titan Embeddings for vectors) 
    and **Qdrant** for Retrieval-Augmented Generation.
    
    **Note:** This application requires AWS credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`) 
    and region (`AWS_REGION`) to be set up in a `.env` file or environment variables.
    The Qdrant vector store is **in-memory** and will be reset on app restart.
    """
)

# --- Configuration ---
# AWS region is read from the environment (.env file or exported variable)
AWS_REGION = os.getenv("AWS_REGION") 
if not AWS_REGION:
    st.error("AWS_REGION not found in environment variables or .env file. Please set it.")
    st.stop()

# Bedrock model IDs
EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v1"
LLM_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" 

# --- Qdrant Specific Configuration ---
QDRANT_COLLECTION_NAME = "math_research_papers_collection"
EMBEDDING_DIMENSION = 1536 # Titan Text Embeddings output 1536-dimensional vectors
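# Note: EMBEDDING_DIMENSION is kept for reference only; it is not passed to Qdrant below,
# because the LangChain integration infers the vector size from the embedding model.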

# --- Initialize Bedrock Client (once) ---
@st.cache_resource
def get_bedrock_client():
    """Initializes and returns a boto3 Bedrock client.
    Returns: Tuple (boto3_client, success_bool, error_message_str or None)
    """
    try:
        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=AWS_REGION
        )
        return client, True, None # Success: client, True, no error message
    except Exception as e:
        return None, False, str(e) # Failure: None, False, error message

# Get the client and check its status
bedrock_client, bedrock_success, bedrock_error_msg = get_bedrock_client()

if not bedrock_success:
    st.error(f"Error connecting to AWS Bedrock. Please check your AWS credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and region (AWS_REGION) in your .env file or environment variables. Error: {bedrock_error_msg}")
    st.stop() # Stop execution if Bedrock client cannot be initialized
else:
    st.success(f"Successfully connected to AWS Bedrock in {AWS_REGION}!")


# --- LangChain Components ---
@st.cache_resource
def get_embeddings_model(_client): # Prepend underscore to tell Streamlit not to hash
    """Returns the BedrockEmbeddings model."""
    return BedrockEmbeddings(client=_client, model_id=EMBEDDING_MODEL_ID)

@st.cache_resource
def get_llm_model(_client): # Prepend underscore to tell Streamlit not to hash
    """Returns the Bedrock LLM model for Claude 3 Sonnet."""
    return ChatBedrock(
        client=_client,
        model_id=LLM_MODEL_ID,
        streaming=False,
        temperature=0.1,  # low temperature for precise, factual answers
        model_kwargs={"max_tokens": 4000}
    )

# --- PDF Processing and Vector Store Creation ---
def create_vector_store(pdf_file_path):
    """
    Loads PDF, chunks it contextually for mathematical papers,
    creates embeddings, and stores them in an in-memory Qdrant collection.
    """
    with st.spinner("Loading PDF and creating vector store..."):
        # 1. Load PDF
        loader = PyPDFLoader(pdf_file_path)
        pages = loader.load()  # one Document per page; chunking is handled by the splitter below
        st.info(f"Loaded {len(pages)} pages from the PDF.")

        # 2. Contextual Chunking for Mathematical Papers
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1500,  # Larger chunks keep definitions, theorem statements, and proofs together
            chunk_overlap=150, # Generous overlap to preserve context across chunk boundaries
            separators=[
                "\n\n",  # Prefer splitting by paragraphs
                "\n",    # Then by single newlines (may still split equations, but less often than fixed-width cuts)
                " ",     # Then by spaces
                "",      # Fallback
            ],
            length_function=len,
            is_separator_regex=False,
        )
        chunks = text_splitter.split_documents(pages)
        st.info(f"Split PDF into {len(chunks)} chunks.")

        # 3. Create Embeddings and the Qdrant vector store
        embeddings = get_embeddings_model(bedrock_client)
        
        # --- CHANGED: Qdrant vector store creation ---
        vector_store = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            location=":memory:",  # Use in-memory Qdrant instance
            collection_name=QDRANT_COLLECTION_NAME,
            # For persistent Qdrant (requires a running Qdrant server):
            # url="http://localhost:6333", # Or your Qdrant Cloud URL
            # api_key="YOUR_QDRANT_CLOUD_API_KEY", # Only for Qdrant Cloud
            # prefer_grpc=True # Set to True if using gRPC for Qdrant Cloud
            # force_recreate=True # Use with caution: deletes existing collection
        )
        # Note: LangChain's Qdrant integration will automatically create the collection
        # if it doesn't exist, inferring vector_size from embeddings.
        
        st.success("Vector store created and ready!")
        return vector_store

# --- RAG Chain Construction ---
def get_rag_chain(vector_store):
    """Constructs the RAG chain using LCEL."""
    retriever = vector_store.as_retriever(search_kwargs={"k": 5}) # Retrieve top 5 relevant chunks
    llm = get_llm_model(bedrock_client) 

    # Prompt Template optimized for mathematical research papers
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", 
             "You are an expert AI assistant specialized in analyzing and explaining mathematical research papers. "
             "Your goal is to provide precise, accurate, and concise answers based *only* on the provided context from the research paper. "
             "When answering, focus on definitions, theorems, proofs, key mathematical concepts, and experimental results. "
             "If the user asks about a mathematical notation, try to explain its meaning from the context. "
             "If the answer is not found in the context, explicitly state that you cannot find the information within the provided document. "
             "Do not invent information or make assumptions outside the given text.\n\n"
             "Context:\n{context}"),
            ("user", "{question}"),
        ]
    )

    # LCEL pipeline: the user's question string goes to the retriever (its documents
    # fill {context}) and is passed through unchanged as {question}; the filled prompt
    # is sent to the LLM and the reply is parsed to a plain string.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | llm
        | StrOutputParser()
    )
    return rag_chain

# --- Streamlit UI Main Logic ---
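# Streamlit reruns this script top-to-bottom on every interaction, so anything that
# must survive a rerun (chat history, vector store, RAG chain) is kept in st.session_state.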

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize vector store and RAG chain
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
if "rag_chain" not in st.session_state:
    st.session_state.rag_chain = None
if "pdf_uploaded" not in st.session_state:
    st.session_state.pdf_uploaded = False


# Sidebar for PDF Upload
with st.sidebar:
    st.header("Upload PDF")
    uploaded_file = st.file_uploader(
        "Choose a PDF file",
        type="pdf",
        accept_multiple_files=False,
        key="pdf_uploader"
    )

    if uploaded_file and not st.session_state.pdf_uploaded:
        # Save the uploaded file temporarily
        with open("temp_doc.pdf", "wb") as f:
            f.write(uploaded_file.getbuffer())
        
        st.session_state.vector_store = create_vector_store("temp_doc.pdf")
        st.session_state.rag_chain = get_rag_chain(st.session_state.vector_store)
        st.session_state.pdf_uploaded = True
        st.success("PDF processed successfully! You can now ask questions.")
        # Clean up temporary file
        os.remove("temp_doc.pdf")
    elif st.session_state.pdf_uploaded:
        st.info("PDF already processed. Ready for questions!")


# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask a question about the paper..."):
    if not st.session_state.pdf_uploaded:
        st.warning("Please upload a PDF first to start asking questions.")
    else:
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get response from RAG chain
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    full_response = st.session_state.rag_chain.invoke(prompt) 
                    st.markdown(full_response, unsafe_allow_html=True) 
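                    # unsafe_allow_html renders any HTML in the model output as-is;
                    # drop it to have Streamlit escape HTML in responses.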

                    # Add assistant response to chat history
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                except Exception as e:
                    st.error(f"An error occurred during response generation: {e}")
                    st.warning("Please try again or check your AWS Bedrock access permissions.")

# Optional: Clear chat and uploaded PDF
if st.session_state.pdf_uploaded:
    if st.sidebar.button("Clear Chat and Upload New PDF"):
        st.session_state.messages = []
        st.session_state.vector_store = None
        st.session_state.rag_chain = None
        st.session_state.pdf_uploaded = False
        st.cache_resource.clear() # Clear streamlit caches for a clean slate
        st.rerun()