MathWizard1729 committed on
Commit 6c67e85 · verified · 1 Parent(s): 9ce288c

Update app.py

Files changed (1)
  1. app.py +232 -210
app.py CHANGED
@@ -1,217 +1,239 @@
- import os
- import logging
- import gradio as gr
- from dotenv import load_dotenv
- from langchain_aws import ChatBedrock
  from langchain_chroma import Chroma
- from langchain_community.embeddings import BedrockEmbeddings
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.runnables import RunnablePassthrough
- from langchain_core.output_parsers import StrOutputParser
- from botocore.exceptions import ClientError
- from indexer import index_uploaded_pdfs
-
- # Set up logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- # Global variables to store state
- vector_store = None
- indexing_status = None
- mode = "General Chat"
- chat_history = []
-
- def load_environment():
-     """Load environment variables from .env file or system environment."""
-     load_dotenv()
-     required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION']
-     for var in required_vars:
-         if not os.getenv(var):
-             logger.error(f"Missing environment variable: {var}")
-             raise ValueError(f"Missing environment variable: {var}")
-     logger.info("Environment variables loaded successfully")
-
- def initialize_embeddings():
-     """Initialize Amazon Bedrock embeddings."""
-     try:
-         embeddings = BedrockEmbeddings(
-             model_id="amazon.titan-embed-text-v1",
-             region_name=os.getenv("AWS_REGION")
-         )
-         logger.info("Initialized Bedrock embeddings")
-         return embeddings
-     except ClientError as e:
-         logger.error(f"Error initializing Bedrock embeddings: {str(e)}")
-         raise
-
- def initialize_vector_store(db_directory="./chroma_db", collection_name="pdf_rag"):
-     """Initialize Chroma vector store."""
      try:
-         embeddings = initialize_embeddings()
-         vector_store = Chroma(
-             collection_name=collection_name,
-             embedding_function=embeddings,
-             persist_directory=db_directory
          )
-         logger.info(f"Initialized Chroma vector store from {db_directory}")
-         return vector_store
      except Exception as e:
-         logger.error(f"Error initializing Chroma vector store: {str(e)}")
-         raise
-
- def initialize_llm():
-     """Initialize Anthropic Claude model via Bedrock."""
-     try:
-         llm = ChatBedrock(
-             model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
-             region_name=os.getenv("AWS_REGION"),
-             model_kwargs={"max_tokens": 1000}
-         )
-         logger.info("Initialized Claude 3.5 Sonnet model")
-         return llm
-     except ClientError as e:
-         logger.error(f"Error initializing Claude model: {str(e)}")
-         raise
-
- def create_rag_chain(vector_store, llm):
-     """Create RAG chain with vector store and LLM."""
-     try:
-         retriever = vector_store.as_retriever(search_kwargs={"k": 3})
-         prompt_template = """You are a helpful assistant. Use the following context to answer the user's question, focusing on extracting relevant skills or information if applicable.
- If you don't know the answer, say so, but try to provide a helpful response based on the context.
-
- Context:
- {context}
-
- Question: {question}
-
- Answer:
- """
-         prompt = ChatPromptTemplate.from_template(prompt_template)
-         rag_chain = (
-             {"context": retriever | (lambda docs: "\n\n".join(doc.page_content for doc in docs)), "question": RunnablePassthrough()}
-             | prompt
-             | llm
-             | StrOutputParser()
          )
-         logger.info("Initialized RAG chain")
-         return rag_chain
-     except Exception as e:
-         logger.error(f"Error creating RAG chain: {str(e)}")
-         raise
-
- def create_general_chat_chain(llm):
-     """Create a general chat chain without RAG."""
-     try:
-         prompt_template = """You are a helpful assistant. Answer the user's question to the best of your knowledge.
-
- Question: {question}
-
- Answer:
- """
-         prompt = ChatPromptTemplate.from_template(prompt_template)
-         chat_chain = (
-             {"question": RunnablePassthrough()}
-             | prompt
-             | llm
-             | StrOutputParser()
          )
-         logger.info("Initialized general chat chain")
-         return chat_chain
-     except Exception as e:
-         logger.error(f"Error creating general chat chain: {str(e)}")
-         raise
-
- def handle_pdf_upload(uploaded_files):
-     """Handle PDF uploads and index them."""
-     global vector_store, indexing_status, mode
-     if uploaded_files:
-         try:
-             vector_store, indexing_status = index_uploaded_pdfs(uploaded_files)
-             if indexing_status["pdf_count"] > 0:
-                 mode = "PDF RAG"
-                 return (
-                     f"Indexed {indexing_status['pdf_count']} PDFs, "
-                     f"{indexing_status['page_count']} pages, "
-                     f"{indexing_status['chunk_count']} chunks. "
-                     f"Database stored at {indexing_status['db_location']}.\n\nMode switched to: {mode}"
-                 )
-             else:
-                 mode = "General Chat"
-                 return "No PDFs were indexed. Please upload valid PDF files.\n\nMode remains: General Chat"
-         except Exception as e:
-             logger.error(f"Error indexing PDFs: {str(e)}")
-             return f"Error indexing PDFs: {str(e)}\n\nMode remains: General Chat"
-     return "No PDFs uploaded.\n\nMode remains: General Chat"
-
- def chat(message, history):
-     """Handle chat interactions."""
-     global vector_store, mode, chat_history
-     try:
-         # Initialize LLM
-         llm = initialize_llm()
-
-         # Select appropriate chain
-         if vector_store and mode == "PDF RAG":
-             chain = create_rag_chain(vector_store, llm)
-         else:
-             chain = create_general_chat_chain(llm)
-
-         # Update chat history
-         chat_history = history or []
-         chat_history.append(("user", message))
-
-         # Get response
-         response = chain.invoke(message)
-
-         # Update chat history
-         chat_history.append(("assistant", response))
-
-         # Format history for Gradio
-         formatted_history = []
-         for role, content in chat_history:
-             if role == "user":
-                 formatted_history.append((content, None))
-             else:
-                 formatted_history.append((None, content))
-
-         return formatted_history, response
-     except Exception as e:
-         logger.error(f"Error generating response: {str(e)}")
-         return chat_history, f"Error generating response: {str(e)}"
-
- def main():
-     """Main function to create Gradio interface."""
-     try:
-         # Load environment
-         load_environment()
-
-         # Gradio interface
-         with gr.Blocks(title="Chatbot with Optional PDF Upload") as demo:
-             gr.Markdown("# Chatbot with Optional PDF Upload")
-             gr.Markdown("Chat with the bot directly or upload PDFs to enable RAG-based queries (e.g., extracting skills).")
-
-             # PDF uploader
-             pdf_input = gr.Files(label="Upload PDF files (optional)", file_types=[".pdf"])
-
-             # Indexing status display
-             indexing_output = gr.Textbox(label="Indexing Status", value=f"Current Mode: {mode}")
-
-             # Chat interface
-             chatbot = gr.Chatbot(label="Chat")
-             msg = gr.Textbox(label="Your Question", placeholder="Ask a question...")
-             clear = gr.Button("Clear Chat")
-
-             # Event handlers
-             pdf_input.upload(handle_pdf_upload, inputs=pdf_input, outputs=indexing_output)
-             msg.submit(chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
-             clear.click(lambda: ([], "Chat cleared.\n\nCurrent Mode: " + mode), None, [chatbot, msg])
-
-         return demo
-     except Exception as e:
-         logger.error(f"Gradio interface initialization failed: {str(e)}")
-         raise
 
- if __name__ == "__main__":
-     demo = main()
-     demo.launch()
+ import streamlit as st
+ import boto3
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_aws import BedrockEmbeddings
  from langchain_chroma import Chroma
+ from langchain_aws import ChatBedrock
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema import StrOutputParser
+ from langchain.schema.runnable import RunnablePassthrough
+ import os
+ from dotenv import load_dotenv  # Import load_dotenv
+
+ # --- Load Environment Variables ---
+ load_dotenv()  # This loads variables from .env file
+
+ # --- Streamlit UI Setup (MUST BE THE FIRST STREAMLIT COMMAND) ---
+ st.set_page_config(
+     page_title="Math Research Paper RAG Bot",
+     page_icon="📚",
+     layout="wide"
+ )
+
+ st.title("📚 Math Research Paper RAG Chatbot")
+ st.markdown(
+     """
+     Upload a mathematical research paper (PDF) and ask questions about its content.
+     This bot uses Amazon Bedrock (Claude 3 Sonnet for reasoning, Titan Embeddings for vectors)
+     and ChromaDB for Retrieval-Augmented Generation.
+
+     **Note:** This application requires AWS credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
+     and region (`AWS_REGION`) to be set up in a `.env` file or environment variables.
+     """
+ )
+
+ # --- Configuration ---
+ # Set AWS region (adjust if needed, loaded from .env or env var)
+ AWS_REGION = os.getenv("AWS_REGION")
+ if not AWS_REGION:
+     st.error("AWS_REGION not found in environment variables or .env file. Please set it.")
+     st.stop()
+
+ # Bedrock model IDs
+ EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v1"
+ # Claude 4 is not generally available via Bedrock. Using Claude 3 Sonnet.
+ LLM_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
+
+ # --- Initialize Bedrock Client (once) ---
+ @st.cache_resource
+ def get_bedrock_client():
+     """Initializes and returns a boto3 Bedrock client.
+     Returns: Tuple (boto3_client, success_bool, error_message_str or None)
+     """
      try:
+         client = boto3.client(
+             service_name="bedrock-runtime",
+             region_name=AWS_REGION
          )
+         # Optional: Verify credentials by trying a simple API call.
+         # This will raise an exception if permissions/credentials are wrong.
+         # client.list_foundation_models(byOutputModality='TEXT')
+         return client, True, None  # Success: client, True, no error message
      except Exception as e:
+         return None, False, str(e)  # Failure: None, False, error message
+
+ # Get the client and check its status
+ bedrock_client, bedrock_success, bedrock_error_msg = get_bedrock_client()
+
+ if not bedrock_success:
+     st.error(f"Error connecting to AWS Bedrock. Please check your AWS credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and region (AWS_REGION) in your .env file or environment variables. Error: {bedrock_error_msg}")
+     st.stop()  # Stop execution if Bedrock client cannot be initialized
+ else:
+     st.success(f"Successfully connected to AWS Bedrock in {AWS_REGION}!")
+
+
+ # --- LangChain Components ---
+ @st.cache_resource
+ def get_embeddings_model(_client):  # Prepend underscore to tell Streamlit not to hash
+     """Returns the BedrockEmbeddings model."""
+     return BedrockEmbeddings(client=_client, model_id=EMBEDDING_MODEL_ID)
+
+ @st.cache_resource
+ def get_llm_model(_client):  # Prepend underscore to tell Streamlit not to hash
+     """Returns the Bedrock LLM model for Claude 3 Sonnet."""
+     return ChatBedrock(
+         client=_client,
+         model_id=LLM_MODEL_ID,
+         streaming=False,  # <--- CHANGED: Set streaming to False
+         temperature=0.1,  # Lower temperature for factual accuracy in research
+         model_kwargs={"max_tokens": 4000}  # Claude 3 can handle larger outputs
+     )
+
+ # --- PDF Processing and Vector Store Creation ---
+ def create_vector_store(pdf_file_path):
+     """
+     Loads PDF, chunks it contextually for mathematical papers,
+     creates embeddings, and stores them in ChromaDB.
+     """
+     with st.spinner("Loading PDF and creating vector store..."):
+         # 1. Load PDF
+         loader = PyPDFLoader(pdf_file_path)
+         pages = loader.load_and_split()
+         st.info(f"Loaded {len(pages)} pages from the PDF.")
+
+         # 2. Contextual Chunking for Mathematical Papers
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=1500,    # Increased chunk size for math papers
+             chunk_overlap=150,  # Generous overlap to maintain context
+             separators=[
+                 "\n\n",  # Prefer splitting by paragraphs
+                 "\n",    # Then by newlines (might break equations but less likely than fixed char)
+                 " ",     # Then by spaces
+                 "",      # Fallback
+             ],
+             length_function=len,
+             is_separator_regex=False,
          )
+         chunks = text_splitter.split_documents(pages)
+         st.info(f"Split PDF into {len(chunks)} chunks.")
+
+         # 3. Create Embeddings and ChromaDB
+         # Pass the bedrock_client to the cached embedding model function
+         embeddings = get_embeddings_model(bedrock_client)
+         vector_store = Chroma.from_documents(
+             documents=chunks,
+             embedding=embeddings,
+             persist_directory="./chroma_db"  # Persist for faster reloads (optional)
          )
+         st.success("Vector store created and ready!")
+         return vector_store
 
+ # --- RAG Chain Construction ---
+ def get_rag_chain(vector_store):
+     """Constructs the RAG chain using LCEL."""
+     retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Retrieve top 5 relevant chunks
+     # Pass the bedrock_client to the cached LLM model function
+     llm = get_llm_model(bedrock_client)
+
+     # Prompt Template optimized for mathematical research papers
+     prompt_template = ChatPromptTemplate.from_messages(
+         [
+             ("system",
+              "You are an expert AI assistant specialized in analyzing and explaining mathematical research papers. "
+              "Your goal is to provide precise, accurate, and concise answers based *only* on the provided context from the research paper. "
+              "When answering, focus on definitions, theorems, proofs, key mathematical concepts, and experimental results. "
+              "If the user asks about a mathematical notation, try to explain its meaning from the context. "
+              "If the answer is not found in the context, explicitly state that you cannot find the information within the provided document. "
+              "Do not invent information or make assumptions outside the given text.\n\n"
+              "Context:\n{context}"),
+             ("user", "{question}"),
+         ]
+     )
+
+     rag_chain = (
+         {"context": retriever, "question": RunnablePassthrough()}
+         | prompt_template
+         | llm
+         | StrOutputParser()
+     )
+     return rag_chain
+
+ # --- Streamlit UI Main Logic ---
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Initialize vector store and RAG chain
+ if "vector_store" not in st.session_state:
+     st.session_state.vector_store = None
+ if "rag_chain" not in st.session_state:
+     st.session_state.rag_chain = None
+ if "pdf_uploaded" not in st.session_state:
+     st.session_state.pdf_uploaded = False
+
+
+ # Sidebar for PDF Upload
+ with st.sidebar:
+     st.header("Upload PDF")
+     uploaded_file = st.file_uploader(
+         "Choose a PDF file",
+         type="pdf",
+         accept_multiple_files=False,
+         key="pdf_uploader"
+     )
+
+     if uploaded_file and not st.session_state.pdf_uploaded:
+         # Save the uploaded file temporarily
+         with open("temp_doc.pdf", "wb") as f:
+             f.write(uploaded_file.getbuffer())
+
+         st.session_state.vector_store = create_vector_store("temp_doc.pdf")
+         st.session_state.rag_chain = get_rag_chain(st.session_state.vector_store)
+         st.session_state.pdf_uploaded = True
+         st.success("PDF processed successfully! You can now ask questions.")
+         # Clean up temporary file
+         os.remove("temp_doc.pdf")
+     elif st.session_state.pdf_uploaded:
+         st.info("PDF already processed. Ready for questions!")
+
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("Ask a question about the paper..."):
+     if not st.session_state.pdf_uploaded:
+         st.warning("Please upload a PDF first to start asking questions.")
+     else:
+         # Add user message to chat history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         # Get response from RAG chain
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 try:
+                     # <--- CHANGED: Use invoke() instead of stream()
+                     full_response = st.session_state.rag_chain.invoke(prompt)
+                     st.markdown(full_response, unsafe_allow_html=True)
+
+                     # Add assistant response to chat history
+                     st.session_state.messages.append({"role": "assistant", "content": full_response})
+                 except Exception as e:
+                     st.error(f"An error occurred during response generation: {e}")
+                     st.warning("Please try again or check your AWS Bedrock access permissions.")
+
+ # Optional: Clear chat and uploaded PDF
+ if st.session_state.pdf_uploaded:
+     if st.sidebar.button("Clear Chat and Upload New PDF"):
+         st.session_state.messages = []
+         st.session_state.vector_store = None
+         st.session_state.rag_chain = None
+         st.session_state.pdf_uploaded = False
+         st.cache_resource.clear()  # Clear streamlit caches for a clean slate
+         st.rerun()