MathWizard1729 committed
Commit 41bdb8e · verified · 1 Parent(s): 38cae35

Update app.py

Files changed (1):
  1. app.py +249 -61
app.py CHANGED
@@ -1,64 +1,252 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-if __name__ == "__main__":
-    demo.launch()
+import streamlit as st
+import boto3
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_aws import BedrockEmbeddings
+# --- CHANGED: Import Qdrant instead of Chroma ---
+from langchain_qdrant import Qdrant
+# --- Optional: If you need direct Qdrant client interaction or for advanced setups ---
+# from qdrant_client import QdrantClient, models
+
+from langchain_aws import ChatBedrock
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema import StrOutputParser
+from langchain.schema.runnable import RunnablePassthrough
+import os
+from dotenv import load_dotenv
+
+# --- Load Environment Variables ---
+load_dotenv()  # This loads variables from a .env file
+
+# --- Streamlit UI Setup (MUST BE THE FIRST STREAMLIT COMMAND) ---
+st.set_page_config(
+    page_title="Math Research Paper RAG Bot",
+    page_icon="📚",
+    layout="wide"
+)
+
+st.title("📚 Math Research Paper RAG Chatbot")
+st.markdown(
+    """
+    Upload a mathematical research paper (PDF) and ask questions about its content.
+    This bot uses Amazon Bedrock (Claude 3 Sonnet for reasoning, Titan Embeddings for vectors)
+    and **QdrantDB** for Retrieval-Augmented Generation.
+
+    **Note:** This application requires AWS credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
+    and a region (`AWS_REGION`) to be set in a `.env` file or as environment variables.
+    The Qdrant vector store is **in-memory** and will be reset on app restart.
+    """
+)
+
+# --- Configuration ---
+# Set AWS region (adjust if needed, loaded from .env or env var)
+AWS_REGION = os.getenv("AWS_REGION")
+if not AWS_REGION:
+    st.error("AWS_REGION not found in environment variables or .env file. Please set it.")
+    st.stop()
+
+# Bedrock model IDs
+EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v1"
+LLM_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
+
+# --- Qdrant Specific Configuration ---
+QDRANT_COLLECTION_NAME = "math_research_papers_collection"
+EMBEDDING_DIMENSION = 1536  # Titan Text Embeddings output 1536-dimensional vectors
+
+# --- Initialize Bedrock Client (once) ---
+@st.cache_resource
+def get_bedrock_client():
+    """Initializes and returns a boto3 Bedrock client.
+    Returns: Tuple (boto3_client, success_bool, error_message_str or None)
+    """
+    try:
+        client = boto3.client(
+            service_name="bedrock-runtime",
+            region_name=AWS_REGION
+        )
+        return client, True, None  # Success: client, True, no error message
+    except Exception as e:
+        return None, False, str(e)  # Failure: None, False, error message
+
+# Get the client and check its status
+bedrock_client, bedrock_success, bedrock_error_msg = get_bedrock_client()
+
+if not bedrock_success:
+    st.error(f"Error connecting to AWS Bedrock. Please check your AWS credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and region (AWS_REGION) in your .env file or environment variables. Error: {bedrock_error_msg}")
+    st.stop()  # Stop execution if Bedrock client cannot be initialized
+else:
+    st.success(f"Successfully connected to AWS Bedrock in {AWS_REGION}!")
+
+
+# --- LangChain Components ---
+@st.cache_resource
+def get_embeddings_model(_client):  # Prepend underscore to tell Streamlit not to hash
+    """Returns the BedrockEmbeddings model."""
+    return BedrockEmbeddings(client=_client, model_id=EMBEDDING_MODEL_ID)
+
+@st.cache_resource
+def get_llm_model(_client):  # Prepend underscore to tell Streamlit not to hash
+    """Returns the Bedrock LLM model for Claude 3 Sonnet."""
+    return ChatBedrock(
+        client=_client,
+        model_id=LLM_MODEL_ID,
+        streaming=False,
+        temperature=0.1,
+        model_kwargs={"max_tokens": 4000}
+    )
+
+# --- PDF Processing and Vector Store Creation ---
+def create_vector_store(pdf_file_path):
+    """
+    Loads PDF, chunks it contextually for mathematical papers,
+    creates embeddings, and stores them in QdrantDB (in-memory).
+    """
+    with st.spinner("Loading PDF and creating vector store..."):
+        # 1. Load PDF (one Document per page; chunking happens in step 2,
+        #    so plain load() avoids splitting the text twice)
+        loader = PyPDFLoader(pdf_file_path)
+        pages = loader.load()
+        st.info(f"Loaded {len(pages)} pages from the PDF.")
+
+        # 2. Contextual Chunking for Mathematical Papers
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1500,  # Increased chunk size for math papers
+            chunk_overlap=150,  # Generous overlap to maintain context
+            separators=[
+                "\n\n",  # Prefer splitting by paragraphs
+                "\n",  # Then by newlines (might break equations, but less often than fixed-width splits)
+                " ",  # Then by spaces
+                "",  # Fallback
+            ],
+            length_function=len,
+            is_separator_regex=False,
+        )
+        chunks = text_splitter.split_documents(pages)
+        st.info(f"Split PDF into {len(chunks)} chunks.")
+
+        # 3. Create Embeddings and QdrantDB
+        embeddings = get_embeddings_model(bedrock_client)
+
+        # --- CHANGED: Qdrant vector store creation ---
+        vector_store = Qdrant.from_documents(
+            documents=chunks,
+            embedding=embeddings,
+            location=":memory:",  # Use in-memory Qdrant instance
+            collection_name=QDRANT_COLLECTION_NAME,
+            # For persistent Qdrant (requires a running Qdrant server):
+            # url="http://localhost:6333",  # Or your Qdrant Cloud URL
+            # api_key="YOUR_QDRANT_CLOUD_API_KEY",  # Only for Qdrant Cloud
+            # prefer_grpc=True,  # Set to True if using gRPC for Qdrant Cloud
+            # force_recreate=True,  # Use with caution: deletes existing collection
+        )
+        # Note: LangChain's Qdrant integration will automatically create the collection
+        # if it doesn't exist, inferring vector_size from the embeddings.
+
+        st.success("Vector store created and ready!")
+        return vector_store
+
+# --- RAG Chain Construction ---
+def get_rag_chain(vector_store):
+    """Constructs the RAG chain using LCEL."""
+    retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Retrieve top 5 relevant chunks
+    llm = get_llm_model(bedrock_client)
+
+    # Prompt Template optimized for mathematical research papers
+    prompt_template = ChatPromptTemplate.from_messages(
+        [
+            ("system",
+             "You are an expert AI assistant specialized in analyzing and explaining mathematical research papers. "
+             "Your goal is to provide precise, accurate, and concise answers based *only* on the provided context from the research paper. "
+             "When answering, focus on definitions, theorems, proofs, key mathematical concepts, and experimental results. "
+             "If the user asks about a mathematical notation, try to explain its meaning from the context. "
+             "If the answer is not found in the context, explicitly state that you cannot find the information within the provided document. "
+             "Do not invent information or make assumptions outside the given text.\n\n"
+             "Context:\n{context}"),
+            ("user", "{question}"),
+        ]
+    )
+
+    rag_chain = (
+        {"context": retriever, "question": RunnablePassthrough()}
+        | prompt_template
+        | llm
+        | StrOutputParser()
+    )
+    return rag_chain
+
+# --- Streamlit UI Main Logic ---
+
+# Initialize chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+# Initialize vector store and RAG chain
+if "vector_store" not in st.session_state:
+    st.session_state.vector_store = None
+if "rag_chain" not in st.session_state:
+    st.session_state.rag_chain = None
+if "pdf_uploaded" not in st.session_state:
+    st.session_state.pdf_uploaded = False
+
+
+# Sidebar for PDF Upload
+with st.sidebar:
+    st.header("Upload PDF")
+    uploaded_file = st.file_uploader(
+        "Choose a PDF file",
+        type="pdf",
+        accept_multiple_files=False,
+        key="pdf_uploader"
+    )
+
+    if uploaded_file and not st.session_state.pdf_uploaded:
+        # Save the uploaded file temporarily
+        with open("temp_doc.pdf", "wb") as f:
+            f.write(uploaded_file.getbuffer())
+
+        st.session_state.vector_store = create_vector_store("temp_doc.pdf")
+        st.session_state.rag_chain = get_rag_chain(st.session_state.vector_store)
+        st.session_state.pdf_uploaded = True
+        st.success("PDF processed successfully! You can now ask questions.")
+        # Clean up temporary file
+        os.remove("temp_doc.pdf")
+    elif st.session_state.pdf_uploaded:
+        st.info("PDF already processed. Ready for questions!")
+
+
+# Display chat messages from history on app rerun
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# Accept user input
+if prompt := st.chat_input("Ask a question about the paper..."):
+    if not st.session_state.pdf_uploaded:
+        st.warning("Please upload a PDF first to start asking questions.")
+    else:
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # Get response from RAG chain
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                try:
+                    full_response = st.session_state.rag_chain.invoke(prompt)
+                    st.markdown(full_response, unsafe_allow_html=True)
+
+                    # Add assistant response to chat history
+                    st.session_state.messages.append({"role": "assistant", "content": full_response})
+                except Exception as e:
+                    st.error(f"An error occurred during response generation: {e}")
+                    st.warning("Please try again or check your AWS Bedrock access permissions.")
+
+# Optional: Clear chat and uploaded PDF
+if st.session_state.pdf_uploaded:
+    if st.sidebar.button("Clear Chat and Upload New PDF"):
+        st.session_state.messages = []
+        st.session_state.vector_store = None
+        st.session_state.rag_chain = None
+        st.session_state.pdf_uploaded = False
+        st.cache_resource.clear()  # Clear streamlit caches for a clean slate
+        st.rerun()
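
Notes on running this commit: the app reads AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION from the environment or a .env file (via load_dotenv()), needs Bedrock access to amazon.titan-embed-text-v1 and anthropic.claude-3-sonnet-20240229-v1:0 in that region, and is started with `streamlit run app.py`.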
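
The in-memory store (location=":memory:") loses all vectors on every restart, as the UI note warns. A minimal sketch of the persistent variant that the inline comments hint at, as a drop-in replacement for the Qdrant.from_documents(...) call above; it assumes a Qdrant server is already running (for example via `docker run -p 6333:6333 qdrant/qdrant`), and http://localhost:6333 is just Qdrant's default address used as an example:

    # Sketch: persistent Qdrant instead of location=":memory:".
    # `url`, `api_key`, `prefer_grpc`, and `force_recreate` are the parameters
    # already named in the commented-out lines of the commit.
    vector_store = Qdrant.from_documents(
        documents=chunks,
        embedding=embeddings,
        url="http://localhost:6333",  # or a Qdrant Cloud URL
        # api_key="...",              # only needed for Qdrant Cloud
        collection_name=QDRANT_COLLECTION_NAME,
        force_recreate=True,          # caution: drops an existing collection
    )

With a server-backed collection, the "PDF already processed" session flag could outlive an app restart, since the vectors would no longer need to be rebuilt from the PDF each time.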
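
On the LCEL chain in get_rag_chain: invoking it with a plain question string fans that string out to both dict entries, so the retriever turns it into the {context} documents while RunnablePassthrough() forwards it unchanged as {question}, and StrOutputParser() reduces the ChatBedrock message to a plain string. A hypothetical smoke test outside the Streamlit UI ("paper.pdf" is a placeholder path; the st.* calls inside the helpers only emit missing-ScriptRunContext warnings when run this way):

    # Hypothetical quick check of the chain, reusing the commit's helpers.
    vs = create_vector_store("paper.pdf")
    chain = get_rag_chain(vs)
    print(chain.invoke("State the main theorem of the paper."))  # plain string out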