LamiaYT committed on
Commit 1204ffb · 1 Parent(s): 60f9f04
Files changed (1)
  1. agent.py +17 -52
agent.py CHANGED
@@ -2,13 +2,10 @@ import os
 import json
 from dotenv import load_dotenv
 
-# ---- Environment & Setup ----
 load_dotenv()
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
-
 hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
 
-# ---- Imports ----
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition, ToolNode
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -20,7 +17,7 @@ from langchain_core.messages import SystemMessage, HumanMessage
 from langchain_core.tools import tool
 from langchain.schema import Document
 
-# ---- Tools ----
+# ---- Tool Definitions (with docstrings) ----
 
 @tool
 def multiply(a: int, b: int) -> int:
@@ -51,87 +48,58 @@ def modulus(a: int, b: int) -> int:
 
 @tool
 def wiki_search(query: str) -> str:
-    """Search Wikipedia for the given query and return formatted documents."""
+    """Search Wikipedia for the query and return text of up to 2 documents."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
     formatted = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ]
+        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+        for doc in search_docs
     )
     return {"wiki_results": formatted}
 
 @tool
 def web_search(query: str) -> str:
-    """Search the web using Tavily API for the given query."""
+    """Search the web for the query using Tavily and return up to 3 results."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
     formatted = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ]
+        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+        for doc in search_docs
     )
     return {"web_results": formatted}
 
 @tool
 def arvix_search(query: str) -> str:
-    """Search Arxiv for academic papers related to the query."""
+    """Search Arxiv for the query and return content from up to 3 papers."""
     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
     formatted = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-            for doc in search_docs
-        ]
+        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+        for doc in search_docs
    )
     return {"arvix_results": formatted}
 
-@tool
-def similar_question_search(query: str) -> str:
-    """Searches for questions similar to the input query using a vector database."""
-    matched_docs = vector_store.similarity_search(query, 3)
-    formatted = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-            for doc in matched_docs
-        ]
-    )
-    return {"similar_questions": formatted}
-
-
-# ---- Embedding & Vector Store ----
-
+# Build vector store once
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
-json_QA = []
-with open('metadata.jsonl', 'r') as jsonl_file:
-    for line in jsonl_file:
-        json_QA.append(json.loads(line))
-
+json_QA = [json.loads(line) for line in open("metadata.jsonl", "r")]
 documents = [
     Document(
         page_content=f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}",
         metadata={"source": sample["task_id"]}
-    )
-    for sample in json_QA
+    ) for sample in json_QA
 ]
-
 vector_store = Chroma.from_documents(
     documents=documents,
     embedding=embeddings,
     persist_directory="./chroma_db",
     collection_name="my_collection"
 )
-vector_store.persist()
 print("Documents inserted:", vector_store._collection.count())
 
 @tool
 def similar_question_search(query: str) -> str:
+    """Search for questions similar to the input query using the vector store."""
     matched_docs = vector_store.similarity_search(query, 3)
     formatted = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-            for doc in matched_docs
-        ]
+        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+        for doc in matched_docs
    )
     return {"similar_questions": formatted}
 
@@ -143,17 +111,14 @@ Now, I will ask you a question. Report your thoughts, and finish your answer wit
 FINAL ANSWER: [YOUR FINAL ANSWER].
 YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings...
 """
-
 sys_msg = SystemMessage(content=system_prompt)
 
-# ---- Tool List ----
-
 tools = [
     multiply, add, subtract, divide, modulus,
     wiki_search, web_search, arvix_search, similar_question_search
 ]
 
-# ---- Graph Construction ----
+# ---- Graph Builder ----
 
 def build_graph(provider: str = "huggingface"):
     if provider == "huggingface":
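
For orientation, a minimal usage sketch of agent.py as it stands after this commit. It assumes build_graph() (whose body is cut off above) compiles and returns a LangGraph graph over MessagesState; the sample question and the import path are placeholders, not part of the commit.

# Minimal usage sketch. Assumes build_graph() returns a compiled LangGraph graph
# over MessagesState; the question below is a placeholder, not part of the commit.
from langchain_core.messages import HumanMessage
from agent import build_graph

graph = build_graph(provider="huggingface")
result = graph.invoke({"messages": [HumanMessage(content="What is 17 * 3?")]})
# The system prompt asks the model to end its reply with "FINAL ANSWER: ...".
print(result["messages"][-1].content)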