Freddolin committed
Commit 9befc16 · verified · Parent: 198e1d4

Update agent.py

Files changed (1):
  1. agent.py +67 -103

agent.py CHANGED
@@ -13,9 +13,11 @@ from langchain_community.document_loaders import ArxivLoader
 from langchain_core.messages import SystemMessage, HumanMessage
 from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
-from langchain_community.vectorstores import Chroma  # New import for Chroma
-from langchain_core.documents import Document  # New import for creating documents
-import shutil  # For handling directories
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document
+import shutil
+import pandas as pd  # New import for pandas
+import json  # To parse the metadata column
 
 load_dotenv()
 
@@ -122,13 +124,14 @@ sys_msg = SystemMessage(content=system_prompt)
 # --- Start ChromaDB Setup ---
 # Define the directory for ChromaDB persistence
 CHROMA_DB_DIR = "./chroma_db"
+CSV_FILE_PATH = "./supabase.docs.csv"  # Path to your CSV file
 
 # Build embeddings (this remains the same)
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
 
 # Initialize ChromaDB
-# If the directory exists, load the existing vector store.
-# Otherwise, create a new one and add some dummy documents.
+# If the directory exists and contains data, load the existing vector store.
+# Otherwise, create a new one and add documents from the CSV file.
 if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
     print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
     vector_store = Chroma(
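
Note: the loader in the next hunk reads `row["content"]` and `row["metadata"]`, so `supabase.docs.csv` is assumed to have those two columns, with the answer embedded after a `Final answer :` marker inside `content`. Illustrative rows (not from the repo):

```csv
content,metadata
"What is the capital of France? Final answer : Paris","{'source': 'internal'}"
"Who wrote Hamlet? Final answer : William Shakespeare","{'source': 'internal'}"
```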
@@ -136,40 +139,63 @@ if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
         embedding_function=embeddings
     )
 else:
-    print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and adding dummy documents.")
+    print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and loading documents from {CSV_FILE_PATH}.")
     # Ensure the directory is clean before creating new
     if os.path.exists(CHROMA_DB_DIR):
         shutil.rmtree(CHROMA_DB_DIR)
     os.makedirs(CHROMA_DB_DIR)
 
-    # Example dummy documents to populate the vector store
-    # In a real application, you would load your actual documents here
-    documents = [
-        Document(page_content="What is the capital of France?", metadata={"source": "internal", "answer": "Paris"}),
-        Document(page_content="Who wrote Hamlet?", metadata={"source": "internal", "answer": "William Shakespeare"}),
-        Document(page_content="What is the highest mountain in the world?", metadata={"source": "internal", "answer": "Mount Everest"}),
-        Document(page_content="When was the internet invented?", metadata={"source": "internal", "answer": "The internet, as we know it, evolved from ARPANET in the late 1960s and early 1970s. The TCP/IP protocol, which forms the basis of the internet, was standardized in 1978."}),
-        Document(page_content="What is the square root of 64?", metadata={"source": "internal", "answer": "8"}),
-        Document(page_content="Who is the current president of the United States?", metadata={"source": "internal", "answer": "Joe Biden"}),
-        Document(page_content="What is the chemical symbol for water?", metadata={"source": "internal", "answer": "H2O"}),
-        Document(page_content="What is the largest ocean on Earth?", metadata={"source": "internal", "answer": "Pacific Ocean"}),
-        Document(page_content="What is the speed of light?", metadata={"source": "internal", "answer": "Approximately 299,792,458 meters per second in a vacuum."}),
-        Document(page_content="What is the capital of Sweden?", metadata={"source": "internal", "answer": "Stockholm"}),
-    ]
+    # Load data from the CSV file
+    if not os.path.exists(CSV_FILE_PATH):
+        raise FileNotFoundError(f"CSV file not found at {CSV_FILE_PATH}. Please ensure it's in the root directory.")
 
-    vector_store = Chroma.from_documents(
-        documents=documents,
-        embedding=embeddings,
-        persist_directory=CHROMA_DB_DIR
-    )
-    vector_store.persist()  # Save the new vector store to disk
-    print("ChromaDB initialized and persisted with dummy documents.")
+    df = pd.read_csv(CSV_FILE_PATH)
+    documents = []
+    for index, row in df.iterrows():
+        content = row["content"]
+
+        # Extract the question part from the content
+        # Assuming the question is everything before "Final answer :"
+        question_part = content.split("Final answer :")[0].strip()
+
+        # Extract the final answer part from the content
+        final_answer_part = content.split("Final answer :")[-1].strip() if "Final answer :" in content else ""
+
+        # Parse the metadata string into a dictionary
+        # The metadata column might be stored as a string representation of a dictionary
+        try:
+            metadata = json.loads(row["metadata"].replace("'", "\""))  # Replace single quotes for valid JSON
+        except json.JSONDecodeError:
+            metadata = {}  # Fallback if parsing fails
+
+        # Add the extracted final answer to the metadata for easy retrieval
+        metadata["final_answer"] = final_answer_part
+
+        # Create a Document object. The page_content should be the question for similarity search.
+        # The answer will be in metadata.
+        documents.append(Document(page_content=question_part, metadata=metadata))
+
+    if not documents:
+        print("No documents loaded from CSV. ChromaDB will be empty.")
+        # Create an empty ChromaDB if no documents are found
+        vector_store = Chroma(
+            persist_directory=CHROMA_DB_DIR,
+            embedding_function=embeddings
+        )
+    else:
+        vector_store = Chroma.from_documents(
+            documents=documents,
+            embedding=embeddings,
+            persist_directory=CHROMA_DB_DIR
+        )
+        vector_store.persist()  # Save the new vector store to disk
+        print(f"ChromaDB initialized and persisted with {len(documents)} documents from CSV.")
 
 # Create retriever tool using the Chroma vector store
-retriever_tool = create_retriever_tool(  # Changed variable name to avoid conflict with function name
+retriever_tool = create_retriever_tool(
     retriever=vector_store.as_retriever(),
-    name="Question_Search",  # Changed name to be more descriptive and valid for tool use
-    description="A tool to retrieve similar questions from a vector store and their answers.",
+    name="Question_Search",
+    description="A tool to retrieve similar questions from a vector store. The retrieved document's metadata contains the 'final_answer' to the question.",
 )
 
 # Add the new retriever tool to your list of tools
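
A note on the metadata parse above: `json.loads(raw.replace("'", "\""))` breaks as soon as a value contains an apostrophe or an embedded quote (e.g. `{'answer': "it's 8"}`). If the column holds Python-dict repr strings, `ast.literal_eval` is the more tolerant parse. A minimal sketch; the helper name is hypothetical, and it keeps the same empty-dict fallback as the `except` branch above:

```python
import ast
import json

def parse_metadata(raw: str) -> dict:
    """Hypothetical helper: parse a metadata cell that may be JSON or a Python-dict repr."""
    for parse in (json.loads, ast.literal_eval):
        try:
            value = parse(raw)
            if isinstance(value, dict):
                return value
        except (ValueError, SyntaxError):
            # json.JSONDecodeError subclasses ValueError; literal_eval raises ValueError/SyntaxError
            continue
    return {}  # fallback if neither parser yields a dict
```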
@@ -182,21 +208,17 @@ tools = [
     wiki_search,
     web_search,
     arvix_search,
-    retriever_tool,  # Add the new retriever tool here
+    retriever_tool,
 ]
 
 # Build graph function
 def build_graph(provider: str = "google"):
     """Build the graph"""
-    # Load environment variables from .env file
     if provider == "google":
-        # Google Gemini
         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
     elif provider == "groq":
-        # Groq https://console.groq.com/docs/models
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
+        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
     elif provider == "huggingface":
-        # TODO: Add huggingface endpoint
         llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -205,10 +227,9 @@ def build_graph(provider: str = "google"):
         )
     else:
         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-    # Bind tools to LLM
+
    llm_with_tools = llm.bind_tools(tools)
 
-    # Node
     def assistant(state: MessagesState):
         """Assistant node"""
         return {"messages": [llm_with_tools.invoke(state["messages"])]}
@@ -217,86 +238,29 @@
 
     def retriever(state: MessagesState):
         query = state["messages"][-1].content
-        # Use the retriever tool to get similar documents
-        similar_docs = retriever_tool.invoke(query)  # Call the tool directly
+        # Use the vector_store directly for similarity search to get the full Document object
+        similar_docs = vector_store.similarity_search(query, k=1)
 
-        # The tool returns a list of Documents, so we need to process it
-        # Assuming the tool returns a list of documents, we take the first one
         if similar_docs:
-            # The tool output is a string representation of the documents.
-            # We need to parse it or adjust the tool to return structured data.
-            # For simplicity, let's assume the tool returns a list of Document objects
-            # or a string that can be directly used.
-            # Given the original `retriever` node, it expected `similar_question[0].page_content`.
-            # If `retriever_tool.invoke(query)` returns a list of Document objects,
-            # then `similar_docs[0].page_content` is correct.
-            # If it returns a string, we need to adapt.
-            # For now, let's assume it returns a list of Documents or a string that contains the answer.
-
-            # If retriever_tool returns a string directly (as per your tool definition):
-            # content = similar_docs # This would be the string output from the tool
-
-            # If retriever_tool returns a list of Document objects from its internal retriever:
-            # Let's assume the `retriever_tool` internally uses `vector_store.as_retriever().invoke(query)`
-            # which returns a list of `Document` objects.
-            # The `create_retriever_tool` wraps this, so `retriever_tool.invoke` will return a string
-            # that is the `page_content` of the retrieved documents.
-
-            # The original `retriever` node was using `vector_store.similarity_search` directly.
-            # Now `retriever_tool` is a LangChain tool.
-            # When `retriever_tool.invoke(query)` is called, it will return the formatted string
-            # from the `create_retriever_tool` definition.
-            # So, `similar_docs` will be a string.
-
-            # We need to parse the `similar_docs` string to extract the answer.
-            # The `Question_Search` tool description is "A tool to retrieve similar questions from a vector store and their answers."
-            # The `create_retriever_tool` automatically formats the output of the retriever.
-            # Let's assume the output string from `retriever_tool.invoke(query)` will look something like:
-            # "content='What is the capital of Sweden?' metadata={'source': 'internal', 'answer': 'Stockholm'}"
-            # We need to extract the 'answer' part.
-
-            # A more robust way would be to make the retriever node *call* the tool,
-            # and then the LLM decides if it wants to use the tool.
-            # However, your current graph structure has a dedicated "retriever" node
-            # that directly fetches and returns an AIMessage.
-
-            # Let's refine the retriever node to parse the output of the tool more robustly.
-            # The `create_retriever_tool` returns a string where documents are joined.
-            # We need to extract the content that would be the "answer".
-
-            # The dummy documents have `metadata={"source": "internal", "answer": "..."}`.
-            # The `create_retriever_tool` will return `doc.page_content` by default.
-            # So, `similar_docs` will contain the question itself.
-            # We need to ensure the retriever provides the *answer* not just the question.
-
-            # Let's adjust the `retriever` node to directly access the `vector_store`
-            # for `similarity_search` and then extract the answer from metadata,
-            # similar to your original implementation. This bypasses the tool wrapper
-            # for this specific node, ensuring we get the full Document object.
-
-            similar_doc = vector_store.similarity_search(query, k=1)[0]
-
-            # Check if an 'answer' is directly available in metadata
-            if "answer" in similar_doc.metadata:
-                answer = similar_doc.metadata["answer"]
+            similar_doc = similar_docs[0]
+            # Prioritize 'final_answer' from metadata, then check page_content
+            if "final_answer" in similar_doc.metadata and similar_doc.metadata["final_answer"]:
+                answer = similar_doc.metadata["final_answer"]
             elif "Final answer :" in similar_doc.page_content:
                 answer = similar_doc.page_content.split("Final answer :")[-1].strip()
             else:
                 answer = similar_doc.page_content.strip()  # Fallback to page_content if no explicit answer
 
+            # The system prompt expects "FINAL ANSWER: [ANSWER]".
+            # We should return the extracted answer directly, as the prompt handles the formatting.
             return {"messages": [AIMessage(content=answer)]}
         else:
-            # If no similar documents found, return an empty AIMessage or a message indicating no answer
             return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}
 
-
     builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever)
-
-    # Retriever is both entry and finish point
     builder.set_entry_point("retriever")
     builder.set_finish_point("retriever")
 
-    # Compile graph
     return builder.compile()
 
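End-to-end, the compiled graph now starts and finishes at `retriever`, so answering a question is a single lookup against the Chroma store. A minimal usage sketch, assuming the CSV is present and the provider's API key is set:

```python
# Minimal sketch: build the graph and run one question through the retriever node.
graph = build_graph(provider="google")
result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
print(result["messages"][-1].content)  # the matched document's final_answer, if any
```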
 