vtony committed on
Commit
dcee690
·
verified ·
1 Parent(s): b3cbf22

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -389
agent.py DELETED
@@ -1,389 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- import logging
5
- from dotenv import load_dotenv
6
- from langgraph.graph import StateGraph, END
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_community.tools import DuckDuckGoSearchRun
9
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
- from langchain_core.tools import tool
12
- from typing import TypedDict, Annotated, Sequence
13
- import operator
14
- import random
15
-
16
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GAIA_Agent")

# Load environment variables from a local .env file, then require the key.
load_dotenv()
# os.getenv already reads os.environ, so a single lookup suffices
# (the original `os.getenv(...) or os.environ.get(...)` was redundant).
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")
25
-
26
- # --- Math Tools ---
27
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of the two integers a and b."""
    product = a * b
    return product
31
-
32
@tool
def add(a: int, b: int) -> int:
    """Return the sum of the two integers a and b."""
    total = a + b
    return total
36
-
37
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference a - b."""
    difference = a - b
    return difference
41
-
42
@tool
def divide(a: int, b: int) -> float:
    """Return a divided by b; raises ValueError when b is zero."""
    if not b:
        raise ValueError("Cannot divide by zero.")
    quotient = a / b
    return quotient
48
-
49
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b (a mod b)."""
    remainder = a % b
    return remainder
53
-
54
- # --- Browser Tools ---
55
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents.

    Args:
        query: Free-text search query, passed to WikipediaLoader as-is.

    Returns:
        Up to three results formatted as "Title: ...\nContent: ..." and
        separated by "---", or a human-readable empty-result/error message.
    """
    # FIX: removed the hard-coded injection of "discography" into every
    # query — that was a benchmark-specific hack that silently rewrote
    # unrelated searches and contradicted this tool's stated contract.
    try:
        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."

        results = []
        for doc in docs:
            title = doc.metadata.get('title', 'Unknown Title')
            # Truncate so tool output stays within the model's context budget.
            content = doc.page_content[:2000]
            results.append(f"Title: {title}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Tools must not raise into the graph; surface errors as text.
        return f"Wikipedia search error: {str(e)}"
76
-
77
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers.

    Args:
        query: Free-text search query, passed to ArxivLoader as-is.

    Returns:
        Up to three results formatted as "Title: ...\nAuthors: ...\n
        Content: ..." separated by "---", or an empty-result/error message.
    """
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            # FIX: ArxivLoader exposes 'Authors' as a comma-separated
            # string; the old ", ".join(...) over it interleaved ", "
            # between individual characters. Join only real sequences.
            authors = doc.metadata.get('Authors', '')
            if isinstance(authors, (list, tuple)):
                authors = ", ".join(authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Tools must not raise into the graph; surface errors as text.
        return f"arXiv search error: {str(e)}"
95
-
96
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        hits = DuckDuckGoSearchRun().run(query)
        # Cap length so the tool output stays within the context budget.
        return f"Web search results for '{query}':\n{hits[:2000]}"
    except Exception as e:
        return f"Web search error: {str(e)}"
105
-
106
# --- Load system prompt ---
# Read once at import time; raises FileNotFoundError if system_prompt.txt
# is missing from the current working directory.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
109
-
110
# --- Tool Setup ---
# Every tool exposed to the LLM: bound via llm.bind_tools() and looked up
# by name in the tool-execution node inside build_graph().
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]
121
-
122
# --- Graph Builder ---
def build_graph():
    """Build and compile the LangGraph agent workflow.

    The graph alternates between an LLM node ("agent") and a tool-execution
    node ("tools"), guarded by a 2-minute global timeout, an 8-step limit,
    and a cap of 3 consecutive API errors.

    Returns:
        A compiled LangGraph runnable.
    """
    # Initialize model with Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=0,      # Disable internal retries; retried at graph level
        request_timeout=30  # Keep timeout reasonable
    )

    # Bind tools to LLM so it can emit tool calls
    llm_with_tools = llm.bind_tools(tools)

    # 1. Define state structure
    class AgentState(TypedDict):
        messages: Annotated[Sequence, operator.add]  # accumulated across nodes
        step_count: int    # total node executions so far
        start_time: float  # wall-clock start used by the global timeout
        last_action: str   # last node/outcome name, for debugging
        api_errors: int    # consecutive API errors (reset on success)

    # 2. Create graph
    workflow = StateGraph(AgentState)

    # 3. Define node functions
    def agent_node(state: AgentState):
        """Main agent node with manual retry handling."""
        # Ensure state has required fields
        state.setdefault("start_time", time.time())
        state.setdefault("step_count", 0)
        state.setdefault("last_action", "start")
        state.setdefault("api_errors", 0)

        # Check global timeout (2 minutes)
        if time.time() - state["start_time"] > 120:
            return {
                "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "timeout",
                "api_errors": state["api_errors"]
            }

        # Check step limit (max 8 steps)
        if state["step_count"] >= 8:
            return {
                "messages": [AIMessage(content="AGENT ERROR (STEP_LIMIT): Exceeded maximum step count of 8")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "step_limit",
                "api_errors": state["api_errors"]
            }

        # Check consecutive API errors
        if state["api_errors"] >= 3:
            return {
                "messages": [AIMessage(content="AGENT ERROR (API_LIMIT): Too many consecutive API errors")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "api_limit",
                "api_errors": state["api_errors"]
            }

        try:
            # Add variable delay to avoid rate limiting
            delay = 2 + random.uniform(0, 3)  # 2-5 seconds
            time.sleep(delay)

            # Call API without automatic retries
            response = llm_with_tools.invoke(state["messages"])

            # Reset error counter on success
            return {
                "messages": [response],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "agent",
                "api_errors": 0
            }

        except Exception as e:
            # Detailed error logging
            error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
            logger.error(error_details)

            # Classify the failure so should_continue can back off on quota
            error_type = "UNKNOWN"
            if "429" in str(e) or "ResourceExhausted" in str(e):
                error_type = "RESOURCE_EXHAUSTED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"
            elif "503" in str(e):
                error_type = "SERVICE_UNAVAILABLE"

            error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"

            return {
                "messages": [AIMessage(content=error_msg)],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "error",
                "api_errors": state["api_errors"] + 1  # Increment error counter
            }

    def tool_node(state: AgentState):
        """Tool execution node: run every tool call on the last message."""
        # Ensure state has required fields
        state.setdefault("start_time", time.time())
        state.setdefault("step_count", 0)
        state.setdefault("last_action", "start")
        state.setdefault("api_errors", 0)

        # Check global timeout (2 minutes)
        if time.time() - state["start_time"] > 120:
            return {
                "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "timeout",
                "api_errors": state["api_errors"]
            }

        last_msg = state["messages"][-1]

        # FIX: should_continue routes here based on last_msg.tool_calls, but
        # this node previously read only additional_kwargs["tool_calls"]
        # (the raw OpenAI-style payload), which can be empty for Gemini
        # responses — silently executing zero tools. Prefer LangChain's
        # normalized tool_calls ({"name", "args", "id"}) and keep the raw
        # payload as a fallback.
        pending = []
        for call in (getattr(last_msg, "tool_calls", None) or []):
            pending.append((call.get("name"), call.get("args", {})))
        if not pending:
            for call in last_msg.additional_kwargs.get("tool_calls", []):
                fn = call.get("function", {})
                pending.append((fn.get("name"), fn.get("arguments", {})))

        responses = []
        for tool_name, tool_args in pending:
            tool_func = next((t for t in tools if t.name == tool_name), None)
            if not tool_func:
                responses.append(f"Tool {tool_name} not available")
                continue

            try:
                # Arguments may arrive JSON-encoded; if decoding fails,
                # treat the raw string as a single "query" argument.
                # (The original if/else here had two identical branches.)
                if isinstance(tool_args, str):
                    try:
                        tool_args = json.loads(tool_args)
                    except json.JSONDecodeError:
                        tool_args = {"query": tool_args}

                # Execute tool
                result = tool_func.invoke(tool_args)
                responses.append(f"{tool_name} result: {str(result)[:1000]}")
            except Exception as e:
                responses.append(f"{tool_name} error: {str(e)}")

        tool_response_content = "\n".join(responses)
        return {
            "messages": [AIMessage(content=tool_response_content)],
            "step_count": state["step_count"] + 1,
            "start_time": state["start_time"],
            "last_action": "tool",
            "api_errors": state["api_errors"]  # Preserve error count
        }

    # 4. Add nodes to workflow
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)

    # 5. Set entry point
    workflow.set_entry_point("agent")

    # 6. Define conditional edges
    def should_continue(state: AgentState):
        """Route from the agent node: retry, run tools, or finish."""
        last_msg = state["messages"][-1]

        # Fatal guard errors terminate the run
        if ("AGENT ERROR (GLOBAL_TIMEOUT)" in last_msg.content
                or "AGENT ERROR (STEP_LIMIT)" in last_msg.content
                or "AGENT ERROR (API_LIMIT)" in last_msg.content):
            return "end"

        # Recoverable errors loop back to the agent (backoff on quota)
        if "AGENT ERROR" in last_msg.content:
            if "RESOURCE_EXHAUSTED" in last_msg.content:
                time.sleep(10 + random.uniform(0, 10))  # Wait 10-20 seconds
            return "agent"

        # Route to tools if tool calls exist
        if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
            return "tools"

        # End if final answer is present
        if "FINAL ANSWER" in last_msg.content:
            return "end"

        # Continue to agent otherwise
        return "agent"

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END
        }
    )

    # 7. Define flow after tool node
    workflow.add_edge("tools", "agent")

    # 8. Compile graph
    return workflow.compile()
331
-
332
# Initialize agent graph
# Compiled once at import time; shared by every run_agent() call.
agent_graph = build_graph()
334
-
335
# Wrapper function to ensure execution within time limits
def run_agent(question):
    """Run the agent graph on one question.

    Args:
        question: The user question, appended after the system prompt.

    Returns:
        {"answer": str} on success, or {"error": str} on timeout/failure.
    """
    # Create initial state with all required fields
    initial_state = {
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question)
        ],
        "step_count": 0,
        "start_time": time.time(),
        "last_action": "start",
        "api_errors": 0
    }

    # Run with overall timeout
    start_time = time.time()
    result = None
    last_update = None

    try:
        # Execute with 3-minute overall timeout
        for step in agent_graph.stream(initial_state):
            # Check overall timeout every step
            if time.time() - start_time > 180:  # 3 minutes
                return {"error": "Overall execution timeout (3 minutes)"}

            # FIX: with langgraph's default stream mode each step is keyed
            # by node name, and END never appears — so the original END
            # check meant no answer was ever captured. Keep the END check
            # for stream modes that emit it, but also remember the latest
            # node update as the fallback final state.
            if END in step:
                result = step[END]
                break
            for update in step.values():
                if isinstance(update, dict) and update.get("messages"):
                    last_update = update
    except Exception as e:
        return {"error": f"Execution failed: {str(e)}"}

    # Extract final answer safely, preferring an explicit END state
    final_state = result if result is not None else last_update
    if final_state and final_state.get("messages"):
        return {"answer": final_state["messages"][-1].content}
    if result is not None:
        return {"error": "Agent finished but produced no messages"}
    return {"error": "Agent did not complete execution"}
377
-
378
# Example entry point (used from app.py)
def process_question(question):
    """Answer a question via the agent and return a plain string."""
    # Small random delay to avoid burst requests against the API
    time.sleep(1 + random.uniform(0, 2))

    response = run_agent(question)
    if "answer" in response:
        return response["answer"]
    if "error" in response:
        return f"Error: {response['error']}"
    return "Unexpected response format"