Freddolin commited on
Commit
6179a99
·
verified ·
1 Parent(s): 447a00c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +125 -48
agent.py CHANGED
@@ -10,11 +10,12 @@ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingF
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
 
 
18
 
19
  load_dotenv()
20
 
@@ -111,8 +112,6 @@ def arvix_search(query: str) -> str:
111
  ])
112
  return {"arvix_results": formatted_search_docs}
113
 
114
-
115
-
116
  # load the system prompt from the file
117
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
  system_prompt = f.read()
@@ -120,25 +119,60 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
- # build a retriever
124
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
- supabase: Client = create_client(
126
- os.environ.get("SUPABASE_URL"),
127
- os.environ.get("SUPABASE_SERVICE_KEY"))
128
- vector_store = SupabaseVectorStore(
129
- client=supabase,
130
- embedding= embeddings,
131
- table_name="documents",
132
- query_name="match_documents_langchain",
133
- )
134
- create_retriever_tool = create_retriever_tool(
135
- retriever=vector_store.as_retriever(),
136
- name="Question Search",
137
- description="A tool to retrieve similar questions from a vector store.",
138
- )
139
 
 
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  tools = [
143
  multiply,
144
  add,
@@ -148,6 +182,7 @@ tools = [
148
  wiki_search,
149
  web_search,
150
  arvix_search,
 
151
  ]
152
 
153
  # Build graph function
@@ -178,39 +213,82 @@ def build_graph(provider: str = "google"):
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
- # def retriever(state: MessagesState):
182
- # """Retriever node"""
183
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
- #example_msg = HumanMessage(
185
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
- # )
187
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
-
189
  from langchain_core.messages import AIMessage
190
 
191
  def retriever(state: MessagesState):
192
  query = state["messages"][-1].content
193
- similar_doc = vector_store.similarity_search(query, k=1)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- content = similar_doc.page_content
196
- if "Final answer :" in content:
197
- answer = content.split("Final answer :")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  else:
199
- answer = content.strip()
200
-
201
- return {"messages": [AIMessage(content=answer)]}
202
-
203
- # builder = StateGraph(MessagesState)
204
- #builder.add_node("retriever", retriever)
205
- #builder.add_node("assistant", assistant)
206
- #builder.add_node("tools", ToolNode(tools))
207
- #builder.add_edge(START, "retriever")
208
- #builder.add_edge("retriever", "assistant")
209
- #builder.add_conditional_edges(
210
- # "assistant",
211
- # tools_condition,
212
- #)
213
- #builder.add_edge("tools", "assistant")
214
 
215
  builder = StateGraph(MessagesState)
216
  builder.add_node("retriever", retriever)
@@ -222,4 +300,3 @@ def build_graph(provider: str = "google"):
222
  # Compile graph
223
  return builder.compile()
224
 
225
-
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
13
  from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
  from langchain.tools.retriever import create_retriever_tool
16
+ from langchain_community.vectorstores import Chroma # Ny import för Chroma
17
+ from langchain_core.documents import Document # Ny import för att skapa dokument
18
+ import shutil # För att hantera kataloger
19
 
20
  load_dotenv()
21
 
 
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
 
 
115
  # load the system prompt from the file
116
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
117
  system_prompt = f.read()
 
119
  # System message
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
+ # --- Start ChromaDB Setup ---
123
+ # Define the directory for ChromaDB persistence
124
+ CHROMA_DB_DIR = "./chroma_db"
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Build embeddings (this remains the same)
127
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
128
 
129
+ # Initialize ChromaDB
130
+ # If the directory exists, load the existing vector store.
131
+ # Otherwise, create a new one and add some dummy documents.
132
+ if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
133
+ print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
134
+ vector_store = Chroma(
135
+ persist_directory=CHROMA_DB_DIR,
136
+ embedding_function=embeddings
137
+ )
138
+ else:
139
+ print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and adding dummy documents.")
140
+ # Ensure the directory is clean before creating new
141
+ if os.path.exists(CHROMA_DB_DIR):
142
+ shutil.rmtree(CHROMA_DB_DIR)
143
+ os.makedirs(CHROMA_DB_DIR)
144
 
145
+ # Example dummy documents to populate the vector store
146
+ # In a real application, you would load your actual documents here
147
+ documents = [
148
+ Document(page_content="What is the capital of France?", metadata={"source": "internal", "answer": "Paris"}),
149
+ Document(page_content="Who wrote Hamlet?", metadata={"source": "internal", "answer": "William Shakespeare"}),
150
+ Document(page_content="What is the highest mountain in the world?", metadata={"source": "internal", "answer": "Mount Everest"}),
151
+ Document(page_content="When was the internet invented?", metadata={"source": "internal", "answer": "The internet, as we know it, evolved from ARPANET in the late 1960s and early 1970s. The TCP/IP protocol, which forms the basis of the internet, was standardized in 1978."}),
152
+ Document(page_content="What is the square root of 64?", metadata={"source": "internal", "answer": "8"}),
153
+ Document(page_content="Who is the current president of the United States?", metadata={"source": "internal", "answer": "Joe Biden"}),
154
+ Document(page_content="What is the chemical symbol for water?", metadata={"source": "internal", "answer": "H2O"}),
155
+ Document(page_content="What is the largest ocean on Earth?", metadata={"source": "internal", "answer": "Pacific Ocean"}),
156
+ Document(page_content="What is the speed of light?", metadata={"source": "internal", "answer": "Approximately 299,792,458 meters per second in a vacuum."}),
157
+ Document(page_content="What is the capital of Sweden?", metadata={"source": "internal", "answer": "Stockholm"}),
158
+ ]
159
+
160
+ vector_store = Chroma.from_documents(
161
+ documents=documents,
162
+ embedding=embeddings,
163
+ persist_directory=CHROMA_DB_DIR
164
+ )
165
+ vector_store.persist() # Save the new vector store to disk
166
+ print("ChromaDB initialized and persisted with dummy documents.")
167
+
168
+ # Create retriever tool using the Chroma vector store
169
+ retriever_tool = create_retriever_tool( # Changed variable name to avoid conflict with function name
170
+ retriever=vector_store.as_retriever(),
171
+ name="Question_Search", # Changed name to be more descriptive and valid for tool use
172
+ description="A tool to retrieve similar questions from a vector store and their answers.",
173
+ )
174
+
175
+ # Add the new retriever tool to your list of tools
176
  tools = [
177
  multiply,
178
  add,
 
182
  wiki_search,
183
  web_search,
184
  arvix_search,
185
+ retriever_tool, # Add the new retriever tool here
186
  ]
187
 
188
  # Build graph function
 
213
  """Assistant node"""
214
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
215
 
 
 
 
 
 
 
 
 
216
  from langchain_core.messages import AIMessage
217
 
218
  def retriever(state: MessagesState):
219
  query = state["messages"][-1].content
220
+ # Use the retriever tool to get similar documents
221
+ similar_docs = retriever_tool.invoke(query) # Call the tool directly
222
+
223
+ # The tool returns a list of Documents, so we need to process it
224
+ # Assuming the tool returns a list of documents, we take the first one
225
+ if similar_docs:
226
+ # The tool output is a string representation of the documents.
227
+ # We need to parse it or adjust the tool to return structured data.
228
+ # For simplicity, let's assume the tool returns a list of Document objects
229
+ # or a string that can be directly used.
230
+ # Given the original `retriever` node, it expected `similar_question[0].page_content`.
231
+ # If `retriever_tool.invoke(query)` returns a list of Document objects,
232
+ # then `similar_docs[0].page_content` is correct.
233
+ # If it returns a string, we need to adapt.
234
+ # For now, let's assume it returns a list of Documents or a string that contains the answer.
235
+
236
+ # If retriever_tool returns a string directly (as per your tool definition):
237
+ # content = similar_docs # This would be the string output from the tool
238
+
239
+ # If retriever_tool returns a list of Document objects from its internal retriever:
240
+ # Let's assume the `retriever_tool` internally uses `vector_store.as_retriever().invoke(query)`
241
+ # which returns a list of `Document` objects.
242
+ # The `create_retriever_tool` wraps this, so `retriever_tool.invoke` will return a string
243
+ # that is the `page_content` of the retrieved documents.
244
+
245
+ # The original `retriever` node was using `vector_store.similarity_search` directly.
246
+ # Now `retriever_tool` is a LangChain tool.
247
+ # When `retriever_tool.invoke(query)` is called, it will return the formatted string
248
+ # from the `create_retriever_tool` definition.
249
+ # So, `similar_docs` will be a string.
250
+
251
+ # We need to parse the `similar_docs` string to extract the answer.
252
+ # The `Question_Search` tool description is "A tool to retrieve similar questions from a vector store and their answers."
253
+ # The `create_retriever_tool` automatically formats the output of the retriever.
254
+ # Let's assume the output string from `retriever_tool.invoke(query)` will look something like:
255
+ # "content='What is the capital of Sweden?' metadata={'source': 'internal', 'answer': 'Stockholm'}"
256
+ # We need to extract the 'answer' part.
257
 
258
+ # A more robust way would be to make the retriever node *call* the tool,
259
+ # and then the LLM decides if it wants to use the tool.
260
+ # However, your current graph structure has a dedicated "retriever" node
261
+ # that directly fetches and returns an AIMessage.
262
+
263
+ # Let's refine the retriever node to parse the output of the tool more robustly.
264
+ # The `create_retriever_tool` returns a string where documents are joined.
265
+ # We need to extract the content that would be the "answer".
266
+
267
+ # The dummy documents have `metadata={"source": "internal", "answer": "..."}`.
268
+ # The `create_retriever_tool` will return `doc.page_content` by default.
269
+ # So, `similar_docs` will contain the question itself.
270
+ # We need to ensure the retriever provides the *answer* not just the question.
271
+
272
+ # Let's adjust the `retriever` node to directly access the `vector_store`
273
+ # for `similarity_search` and then extract the answer from metadata,
274
+ # similar to your original implementation. This bypasses the tool wrapper
275
+ # for this specific node, ensuring we get the full Document object.
276
+
277
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
278
+
279
+ # Check if an 'answer' is directly available in metadata
280
+ if "answer" in similar_doc.metadata:
281
+ answer = similar_doc.metadata["answer"]
282
+ elif "Final answer :" in similar_doc.page_content:
283
+ answer = similar_doc.page_content.split("Final answer :")[-1].strip()
284
+ else:
285
+ answer = similar_doc.page_content.strip() # Fallback to page_content if no explicit answer
286
+
287
+ return {"messages": [AIMessage(content=answer)]}
288
  else:
289
+ # If no similar documents found, return an empty AIMessage or a message indicating no answer
290
+ return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}
291
+
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  builder = StateGraph(MessagesState)
294
  builder.add_node("retriever", retriever)
 
300
  # Compile graph
301
  return builder.compile()
302