CMAl3j0 commited on
Commit
321f759
·
verified ·
1 Parent(s): 81917a3

Custom Agent

Browse files
Files changed (1) hide show
  1. Agent +207 -0
Agent ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.prebuilt import create_react_agent
2
+ from langchain_community.tools.tavily_search import TavilySearchResults
3
+ from langchain_community.document_loaders import WikipediaLoader
4
+ from langchain_community.document_loaders import ArxivLoader
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from langchain_core.tools import tool
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
+ from langchain_core.messages import HumanMessage
10
+ from supabase import create_client, Client
11
+ import os
12
+
13
+ load_dotenv(find_dotenv())
14
+
15
+ DEFAULT_PROMPT = """
16
+ You are a helpful assistant tasked with answering questions using a set of tools.
17
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
18
+ FINAL ANSWER: [YOUR FINAL ANSWER].
19
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
20
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
21
+ """
22
+
23
+
24
+ @tool
25
+ def wiki_search(query: str) -> str:
26
+ """Search Wikipedia for a query and return maximum 2 results.
27
+ Args:
28
+ query: The search query."""
29
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
30
+ formatted_search_docs = "\n\n---\n\n".join(
31
+ [
32
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
33
+ for doc in search_docs
34
+ ]
35
+ )
36
+ return {"wiki_results": formatted_search_docs}
37
+
38
+
39
+ @tool
40
+ def web_search(query: str) -> str:
41
+ """Search Tavily for a query and return maximum 3 results.
42
+ Args:
43
+ query: The search query."""
44
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
45
+ formatted_search_docs = "\n\n---\n\n".join(
46
+ [
47
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
48
+ for doc in search_docs
49
+ ]
50
+ )
51
+ return {"web_results": formatted_search_docs}
52
+
53
+
54
+ @tool
55
+ def arvix_search(query: str) -> str:
56
+ """Search Arxiv for a query and return maximum 3 result.
57
+ Args:
58
+ query: The search query."""
59
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
60
+ formatted_search_docs = "\n\n---\n\n".join(
61
+ [
62
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
63
+ for doc in search_docs
64
+ ]
65
+ )
66
+ return {"arvix_results": formatted_search_docs}
67
+
68
+
69
+ @tool
70
+ def multiply(a: int, b: int) -> int:
71
+ """Multiply two numbers.
72
+ Args:
73
+ a: first int
74
+ b: second int
75
+ """
76
+ return a * b
77
+
78
+
79
+ @tool
80
+ def add(a: int, b: int) -> int:
81
+ """Add two numbers.
82
+ Args:
83
+ a: first int
84
+ b: second int
85
+ """
86
+ return a + b
87
+
88
+
89
+ @tool
90
+ def subtract(a: int, b: int) -> int:
91
+ """Subtract two numbers.
92
+ Args:
93
+ a: first int
94
+ b: second int
95
+ """
96
+ return a - b
97
+
98
+
99
+ @tool
100
+ def divide(a: int, b: int) -> int:
101
+ """Divide two numbers.
102
+ Args:
103
+ a: first int
104
+ b: second int
105
+ """
106
+ if b == 0:
107
+ raise ValueError("Cannot divide by zero.")
108
+ return a / b
109
+
110
+
111
+ @tool
112
+ def modulus(a: int, b: int) -> int:
113
+ """Get the modulus of two numbers.
114
+ Args:
115
+ a: first int
116
+ b: second int
117
+ """
118
+ return a % b
119
+
120
+
121
+ class CustomAgent:
122
+ def __init__(self):
123
+ print("CustomAgent initialized.")
124
+
125
+ # Initialize embeddings and vector store
126
+ self.embeddings = HuggingFaceEmbeddings(
127
+ model_name="sentence-transformers/all-mpnet-base-v2"
128
+ )
129
+
130
+ self.supabase: Client = create_client(
131
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
132
+ )
133
+
134
+ self.vector_store = SupabaseVectorStore(
135
+ client=self.supabase,
136
+ embedding=self.embeddings,
137
+ table_name="documents_1",
138
+ query_name="match_documents_1",
139
+ )
140
+
141
+ # Create the agent
142
+ self.agent = create_react_agent(
143
+ model="openai:gpt-4.1",
144
+ tools=[
145
+ web_search,
146
+ add,
147
+ subtract,
148
+ multiply,
149
+ divide,
150
+ modulus,
151
+ wiki_search,
152
+ arvix_search,
153
+ ],
154
+ prompt=DEFAULT_PROMPT,
155
+ )
156
+
157
+ def retriever(self, query: str):
158
+ """Retriever"""
159
+ similar_question = self.vector_store.similarity_search(query)
160
+ return HumanMessage(
161
+ content=f"Here I provide a similar question and answer for reference, you can use it to answer the question: \n\n{similar_question[0].page_content}",
162
+ )
163
+
164
+ def __call__(self, question: str) -> str:
165
+ """Run the agent on a question and return the answer."""
166
+ print(f"CustomAgent received question (first 50 chars): {question[:50]}...")
167
+
168
+ try:
169
+ answer = self.agent.invoke(
170
+ {
171
+ "messages": [
172
+ self.retriever(question),
173
+ HumanMessage(content=question),
174
+ ]
175
+ }
176
+ )
177
+ result = answer["messages"][-1].content
178
+
179
+ if "FINAL ANSWER: " in result:
180
+ final_answer_start = result.find("FINAL ANSWER: ") + len(
181
+ "FINAL ANSWER: "
182
+ )
183
+ extracted_answer = result[final_answer_start:].strip()
184
+ print(f"CustomAgent extracted answer: {extracted_answer}")
185
+ return extracted_answer
186
+ else:
187
+ print(
188
+ f"CustomAgent returning full answer (no FINAL ANSWER found): {result}"
189
+ )
190
+ return result
191
+
192
+ except Exception as e:
193
+ print(f"Error in CustomAgent: {e}")
194
+ return f"Error: {e}"
195
+
196
+
197
+ if __name__ == "__main__":
198
+ agent = CustomAgent()
199
+ agent(
200
+ "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
201
+ )
202
+ agent(
203
+ "How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?"
204
+ )
205
+ agent(
206
+ "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
207
+ )