vtony commited on
Commit
0f4f3d7
·
verified ·
1 Parent(s): a371b5e

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +389 -0
agent.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import logging
5
+ from dotenv import load_dotenv
6
+ from langgraph.graph import StateGraph, END
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_community.tools import DuckDuckGoSearchRun
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
+ from langchain_core.tools import tool
12
+ from typing import TypedDict, Annotated, Sequence
13
+ import operator
14
+ import random
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger("GAIA_Agent")
19
+
20
+ # Load environment variables
21
+ load_dotenv()
22
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
23
+ if not google_api_key:
24
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
25
+
26
+ # --- Math Tools ---
27
+ @tool
28
+ def multiply(a: int, b: int) -> int:
29
+ """Multiply two integers."""
30
+ return a * b
31
+
32
+ @tool
33
+ def add(a: int, b: int) -> int:
34
+ """Add two integers."""
35
+ return a + b
36
+
37
+ @tool
38
+ def subtract(a: int, b: int) -> int:
39
+ """Subtract b from a."""
40
+ return a - b
41
+
42
+ @tool
43
+ def divide(a: int, b: int) -> float:
44
+ """Divide a by b, error on zero."""
45
+ if b == 0:
46
+ raise ValueError("Cannot divide by zero.")
47
+ return a / b
48
+
49
+ @tool
50
+ def modulus(a: int, b: int) -> int:
51
+ """Compute a mod b."""
52
+ return a % b
53
+
54
+ # --- Browser Tools ---
55
+ @tool
56
+ def wiki_search(query: str) -> str:
57
+ """Search Wikipedia and return up to 3 relevant documents."""
58
+ try:
59
+ # Ensure query contains "discography" keyword
60
+ if "discography" not in query.lower():
61
+ query = f"{query} discography"
62
+
63
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
64
+ if not docs:
65
+ return "No Wikipedia results found."
66
+
67
+ results = []
68
+ for doc in docs:
69
+ title = doc.metadata.get('title', 'Unknown Title')
70
+ content = doc.page_content[:2000] # Limit content length
71
+ results.append(f"Title: {title}\nContent: {content}")
72
+
73
+ return "\n\n---\n\n".join(results)
74
+ except Exception as e:
75
+ return f"Wikipedia search error: {str(e)}"
76
+
77
+ @tool
78
+ def arxiv_search(query: str) -> str:
79
+ """Search Arxiv and return up to 3 relevant papers."""
80
+ try:
81
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
82
+ if not docs:
83
+ return "No arXiv papers found."
84
+
85
+ results = []
86
+ for doc in docs:
87
+ title = doc.metadata.get('Title', 'Unknown Title')
88
+ authors = ", ".join(doc.metadata.get('Authors', []))
89
+ content = doc.page_content[:2000] # Limit content length
90
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
91
+
92
+ return "\n\n---\n\n".join(results)
93
+ except Exception as e:
94
+ return f"arXiv search error: {str(e)}"
95
+
96
+ @tool
97
+ def web_search(query: str) -> str:
98
+ """Search the web using DuckDuckGo and return top results."""
99
+ try:
100
+ search = DuckDuckGoSearchRun()
101
+ result = search.run(query)
102
+ return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
103
+ except Exception as e:
104
+ return f"Web search error: {str(e)}"
105
+
106
+ # --- Load system prompt ---
107
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
108
+ system_prompt = f.read()
109
+
110
+ # --- Tool Setup ---
111
+ tools = [
112
+ multiply,
113
+ add,
114
+ subtract,
115
+ divide,
116
+ modulus,
117
+ wiki_search,
118
+ arxiv_search,
119
+ web_search,
120
+ ]
121
+
122
+ # --- Graph Builder ---
123
+ def build_graph():
124
+ # Initialize model with Gemini 2.5 Flash
125
+ llm = ChatGoogleGenerativeAI(
126
+ model="gemini-2.5-flash",
127
+ temperature=0.3,
128
+ google_api_key=google_api_key,
129
+ max_retries=0, # Disable internal retries
130
+ request_timeout=30 # Keep timeout reasonable
131
+ )
132
+
133
+ # Bind tools to LLM
134
+ llm_with_tools = llm.bind_tools(tools)
135
+
136
+ # 1. Define state structure
137
+ class AgentState(TypedDict):
138
+ messages: Annotated[Sequence, operator.add]
139
+ step_count: int
140
+ start_time: float
141
+ last_action: str
142
+ api_errors: int # Track consecutive API errors
143
+
144
+ # 2. Create graph
145
+ workflow = StateGraph(AgentState)
146
+
147
+ # 3. Define node functions
148
+ def agent_node(state: AgentState):
149
+ """Main agent node with manual retry handling"""
150
+ # Ensure state has required fields
151
+ state.setdefault("start_time", time.time())
152
+ state.setdefault("step_count", 0)
153
+ state.setdefault("last_action", "start")
154
+ state.setdefault("api_errors", 0)
155
+
156
+ # Check global timeout (2 minutes)
157
+ if time.time() - state["start_time"] > 120:
158
+ return {
159
+ "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
160
+ "step_count": state["step_count"] + 1,
161
+ "start_time": state["start_time"],
162
+ "last_action": "timeout",
163
+ "api_errors": state["api_errors"]
164
+ }
165
+
166
+ # Check step limit (max 8 steps)
167
+ if state["step_count"] >= 8:
168
+ return {
169
+ "messages": [AIMessage(content="AGENT ERROR (STEP_LIMIT): Exceeded maximum step count of 8")],
170
+ "step_count": state["step_count"] + 1,
171
+ "start_time": state["start_time"],
172
+ "last_action": "step_limit",
173
+ "api_errors": state["api_errors"]
174
+ }
175
+
176
+ # Check consecutive API errors
177
+ if state["api_errors"] >= 3:
178
+ return {
179
+ "messages": [AIMessage(content="AGENT ERROR (API_LIMIT): Too many consecutive API errors")],
180
+ "step_count": state["step_count"] + 1,
181
+ "start_time": state["start_time"],
182
+ "last_action": "api_limit",
183
+ "api_errors": state["api_errors"]
184
+ }
185
+
186
+ try:
187
+ # Add variable delay to avoid rate limiting
188
+ delay = 2 + random.uniform(0, 3) # 2-5 seconds
189
+ time.sleep(delay)
190
+
191
+ # Call API without automatic retries
192
+ response = llm_with_tools.invoke(state["messages"])
193
+
194
+ # Reset error counter on success
195
+ return {
196
+ "messages": [response],
197
+ "step_count": state["step_count"] + 1,
198
+ "start_time": state["start_time"],
199
+ "last_action": "agent",
200
+ "api_errors": 0 # Reset error counter
201
+ }
202
+
203
+ except Exception as e:
204
+ # Detailed error logging
205
+ error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
206
+ logger.error(error_details)
207
+
208
+ error_type = "UNKNOWN"
209
+ if "429" in str(e) or "ResourceExhausted" in str(e):
210
+ error_type = "RESOURCE_EXHAUSTED"
211
+ elif "400" in str(e):
212
+ error_type = "INVALID_REQUEST"
213
+ elif "503" in str(e):
214
+ error_type = "SERVICE_UNAVAILABLE"
215
+
216
+ error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
217
+
218
+ return {
219
+ "messages": [AIMessage(content=error_msg)],
220
+ "step_count": state["step_count"] + 1,
221
+ "start_time": state["start_time"],
222
+ "last_action": "error",
223
+ "api_errors": state["api_errors"] + 1 # Increment error counter
224
+ }
225
+
226
+ def tool_node(state: AgentState):
227
+ """Tool execution node"""
228
+ # Ensure state has required fields
229
+ state.setdefault("start_time", time.time())
230
+ state.setdefault("step_count", 0)
231
+ state.setdefault("last_action", "start")
232
+ state.setdefault("api_errors", 0)
233
+
234
+ # Check global timeout (2 minutes)
235
+ if time.time() - state["start_time"] > 120:
236
+ return {
237
+ "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
238
+ "step_count": state["step_count"] + 1,
239
+ "start_time": state["start_time"],
240
+ "last_action": "timeout",
241
+ "api_errors": state["api_errors"]
242
+ }
243
+
244
+ last_msg = state["messages"][-1]
245
+ tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
246
+
247
+ responses = []
248
+ for call in tool_calls:
249
+ tool_name = call["function"]["name"]
250
+ tool_args = call["function"].get("arguments", {})
251
+
252
+ tool_func = next((t for t in tools if t.name == tool_name), None)
253
+ if not tool_func:
254
+ responses.append(f"Tool {tool_name} not available")
255
+ continue
256
+
257
+ try:
258
+ # Parse arguments
259
+ if isinstance(tool_args, str):
260
+ try:
261
+ tool_args = json.loads(tool_args)
262
+ except json.JSONDecodeError:
263
+ if "query" in tool_args:
264
+ tool_args = {"query": tool_args}
265
+ else:
266
+ tool_args = {"query": tool_args}
267
+
268
+ # Execute tool
269
+ result = tool_func.invoke(tool_args)
270
+ responses.append(f"{tool_name} result: {str(result)[:1000]}")
271
+ except Exception as e:
272
+ responses.append(f"{tool_name} error: {str(e)}")
273
+
274
+ tool_response_content = "\n".join(responses)
275
+ return {
276
+ "messages": [AIMessage(content=tool_response_content)],
277
+ "step_count": state["step_count"] + 1,
278
+ "start_time": state["start_time"],
279
+ "last_action": "tool",
280
+ "api_errors": state["api_errors"] # Preserve error count
281
+ }
282
+
283
+ # 4. Add nodes to workflow
284
+ workflow.add_node("agent", agent_node)
285
+ workflow.add_node("tools", tool_node)
286
+
287
+ # 5. Set entry point
288
+ workflow.set_entry_point("agent")
289
+
290
+ # 6. Define conditional edges
291
+ def should_continue(state: AgentState):
292
+ last_msg = state["messages"][-1]
293
+
294
+ # Handle timeout or step limit errors
295
+ if "AGENT ERROR (GLOBAL_TIMEOUT)" in last_msg.content or "AGENT ERROR (STEP_LIMIT)" in last_msg.content or "AGENT ERROR (API_LIMIT)" in last_msg.content:
296
+ return "end"
297
+
298
+ # Handle all other errors
299
+ if "AGENT ERROR" in last_msg.content:
300
+ # For RESOURCE_EXHAUSTED errors, wait longer before retrying
301
+ if "RESOURCE_EXHAUSTED" in last_msg.content:
302
+ time.sleep(10 + random.uniform(0, 10)) # Wait 10-20 seconds
303
+ return "agent"
304
+
305
+ # Route to tools if tool calls exist
306
+ if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
307
+ return "tools"
308
+
309
+ # End if final answer is present
310
+ if "FINAL ANSWER" in last_msg.content:
311
+ return "end"
312
+
313
+ # Continue to agent otherwise
314
+ return "agent"
315
+
316
+ workflow.add_conditional_edges(
317
+ "agent",
318
+ should_continue,
319
+ {
320
+ "agent": "agent",
321
+ "tools": "tools",
322
+ "end": END
323
+ }
324
+ )
325
+
326
+ # 7. Define flow after tool node
327
+ workflow.add_edge("tools", "agent")
328
+
329
+ # 8. Compile graph
330
+ return workflow.compile()
331
+
332
+ # Initialize agent graph
333
+ agent_graph = build_graph()
334
+
335
+ # Wrapper function to ensure execution within time limits
336
+ def run_agent(question):
337
+ # Create initial state with all required fields
338
+ initial_state = {
339
+ "messages": [
340
+ SystemMessage(content=system_prompt),
341
+ HumanMessage(content=question)
342
+ ],
343
+ "step_count": 0,
344
+ "start_time": time.time(),
345
+ "last_action": "start",
346
+ "api_errors": 0
347
+ }
348
+
349
+ # Run with overall timeout
350
+ start_time = time.time()
351
+ result = None
352
+ end_state_reached = False
353
+
354
+ try:
355
+ # Execute with 3-minute overall timeout
356
+ for step in agent_graph.stream(initial_state):
357
+ # Check overall timeout every step
358
+ if time.time() - start_time > 180: # 3 minutes
359
+ return {"error": "Overall execution timeout (3 minutes)"}
360
+
361
+ # Capture the final state when the graph completes
362
+ if END in step:
363
+ result = step[END]
364
+ end_state_reached = True
365
+ break
366
+ except Exception as e:
367
+ return {"error": f"Execution failed: {str(e)}"}
368
+
369
+ # Extract final answer safely
370
+ if end_state_reached and result is not None:
371
+ if "messages" in result and result["messages"]:
372
+ return {"answer": result["messages"][-1].content}
373
+ else:
374
+ return {"error": "Agent finished but produced no messages"}
375
+ else:
376
+ return {"error": "Agent did not complete execution"}
377
+
378
+ # 示例调用函数(在app.py中使用)
379
+ def process_question(question):
380
+ # Add initial delay to avoid burst requests
381
+ time.sleep(1 + random.uniform(0, 2))
382
+
383
+ response = run_agent(question)
384
+ if "answer" in response:
385
+ return response["answer"]
386
+ elif "error" in response:
387
+ return f"Error: {response['error']}"
388
+ else:
389
+ return "Unexpected response format"