vtony committed on
Commit a371b5e · verified · 1 Parent(s): 26a62a5

Delete agent.py

Files changed (1)
  1. agent.py +0 -365
agent.py DELETED
@@ -1,365 +0,0 @@
import os
import time
import json
import logging
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.tools import tool
from tenacity import retry, stop_after_attempt, wait_exponential
from typing import TypedDict, Annotated, Sequence
import operator

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GAIA_Agent")

# Load environment variables
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")

# --- Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    return a % b

# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    try:
        # Ensure query contains "discography" keyword
        if "discography" not in query.lower():
            query = f"{query} discography"

        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."

        results = []
        for doc in docs:
            title = doc.metadata.get('title', 'Unknown Title')
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"Wikipedia search error: {str(e)}"

@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            # 'Authors' may be a plain string or a list depending on the loader version
            authors = doc.metadata.get('Authors', 'Unknown Authors')
            if isinstance(authors, (list, tuple)):
                authors = ", ".join(authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"

@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        search = DuckDuckGoSearchRun()
        result = search.run(query)
        return f"Web search results for '{query}':\n{result[:2000]}"  # Limit content length
    except Exception as e:
        return f"Web search error: {str(e)}"

# --- Load system prompt ---
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# --- Tool Setup ---
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]

# --- Graph Builder ---
def build_graph():
    # Initialize model with Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=2,      # Reduce retries to prevent long delays
        request_timeout=20  # Reduce timeout
    )

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # 1. Define state structure
    class AgentState(TypedDict):
        messages: Annotated[Sequence, operator.add]
        step_count: int
        start_time: float
        last_action: str

    # 2. Create graph
    workflow = StateGraph(AgentState)

    # 3. Define node functions
    def agent_node(state: AgentState):
        """Main agent node"""
        # Ensure state has required fields
        state.setdefault("start_time", time.time())
        state.setdefault("step_count", 0)
        state.setdefault("last_action", "start")

        # Check global timeout (2 minutes)
        if time.time() - state["start_time"] > 120:
            return {
                "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "timeout"
            }

        # Check step limit (max 8 steps)
        if state["step_count"] >= 8:
            return {
                "messages": [AIMessage(content="AGENT ERROR (STEP_LIMIT): Exceeded maximum step count of 8")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "step_limit"
            }

        try:
            # Add request delay to avoid rate limiting
            time.sleep(1)

            # Retry wrapper for the API call: stop_after_attempt(1) means a single
            # attempt here; transient failures are handled by max_retries on the LLM
            @retry(stop=stop_after_attempt(1),
                   wait=wait_exponential(multiplier=1, min=1, max=5))
            def invoke_with_retry():
                return llm_with_tools.invoke(state["messages"])

            response = invoke_with_retry()
            return {
                "messages": [response],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "agent"
            }

        except Exception as e:
            # Detailed error logging
            error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
            logger.error(error_details)

            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"
            elif "503" in str(e):
                error_type = "SERVICE_UNAVAILABLE"

            error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"

            return {
                "messages": [AIMessage(content=error_msg)],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "error"
            }

    def tool_node(state: AgentState):
        """Tool execution node"""
        # Ensure state has required fields
        state.setdefault("start_time", time.time())
        state.setdefault("step_count", 0)
        state.setdefault("last_action", "start")

        # Check global timeout (2 minutes)
        if time.time() - state["start_time"] > 120:
            return {
                "messages": [AIMessage(content="AGENT ERROR (GLOBAL_TIMEOUT): Execution exceeded 2-minute limit")],
                "step_count": state["step_count"] + 1,
                "start_time": state["start_time"],
                "last_action": "timeout"
            }

        last_msg = state["messages"][-1]
        # Prefer the parsed tool_calls attribute; fall back to the raw provider payload
        tool_calls = getattr(last_msg, "tool_calls", None) or last_msg.additional_kwargs.get("tool_calls", [])

        responses = []
        for call in tool_calls:
            # Handle both the parsed format ({"name": ..., "args": ...}) and the
            # OpenAI-style payload ({"function": {"name": ..., "arguments": ...}})
            if "function" in call:
                tool_name = call["function"]["name"]
                tool_args = call["function"].get("arguments", {})
            else:
                tool_name = call.get("name")
                tool_args = call.get("args", {})

            tool_func = next((t for t in tools if t.name == tool_name), None)
            if not tool_func:
                responses.append(f"Tool {tool_name} not available")
                continue

            try:
                # Parse arguments
                if isinstance(tool_args, str):
                    try:
                        tool_args = json.loads(tool_args)
                    except json.JSONDecodeError:
                        # Fall back to treating the raw string as a query argument
                        tool_args = {"query": tool_args}

                # Execute tool
                result = tool_func.invoke(tool_args)
                responses.append(f"{tool_name} result: {str(result)[:1000]}")
            except Exception as e:
                responses.append(f"{tool_name} error: {str(e)}")

        tool_response_content = "\n".join(responses)
        return {
            "messages": [AIMessage(content=tool_response_content)],
            "step_count": state["step_count"] + 1,
            "start_time": state["start_time"],
            "last_action": "tool"
        }

    # 4. Add nodes to workflow
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)

    # 5. Set entry point
    workflow.set_entry_point("agent")

    # 6. Define conditional edges
    def should_continue(state: AgentState):
        last_msg = state["messages"][-1]

        # Handle timeout or step limit errors
        if "AGENT ERROR (GLOBAL_TIMEOUT)" in last_msg.content or "AGENT ERROR (STEP_LIMIT)" in last_msg.content:
            return "end"

        # Handle all other errors
        if "AGENT ERROR" in last_msg.content:
            return "end"

        # Route to tools if tool calls exist
        if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
            return "tools"

        # End if final answer is present
        if "FINAL ANSWER" in last_msg.content:
            return "end"

        # Continue to agent otherwise
        return "agent"

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END
        }
    )

    # 7. Define flow after tool node
    workflow.add_edge("tools", "agent")

    # 8. Compile graph
    return workflow.compile()

# Initialize agent graph
agent_graph = build_graph()

# Wrapper function to ensure execution within time limits
def run_agent(question):
    # Create initial state with all required fields
    initial_state = {
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question)
        ],
        "step_count": 0,
        "start_time": time.time(),
        "last_action": "start"
    }

    # Run with overall timeout
    start_time = time.time()
    result = None
    end_state_reached = False

    try:
        # Execute with 2-minute overall timeout
        for step in agent_graph.stream(initial_state):
            # Check overall timeout every step
            if time.time() - start_time > 120:
                return {"error": "Overall execution timeout (2 minutes)"}

            # Capture the final state when the graph completes
            if END in step:
                result = step[END]
                end_state_reached = True
                break
    except Exception as e:
        return {"error": f"Execution failed: {str(e)}"}

    # Extract final answer safely
    if end_state_reached and result is not None:
        if "messages" in result and result["messages"]:
            return {"answer": result["messages"][-1].content}
        else:
            return {"error": "Agent finished but produced no messages"}
    else:
        return {"error": "Agent did not complete execution"}

# Example entry point (used in app.py)
def process_question(question):
    response = run_agent(question)
    if "answer" in response:
        return response["answer"]
    elif "error" in response:
        return f"Error: {response['error']}"
    else:
        return "Unexpected response format"