vtony commited on
Commit
d5d9834
·
verified ·
1 Parent(s): c9fdd06

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -257
agent.py DELETED
@@ -1,257 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- import logging
5
- from dotenv import load_dotenv
6
- from langgraph.graph import StateGraph, END
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_community.tools import DuckDuckGoSearchRun
9
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
- from langchain_core.tools import tool
12
- from tenacity import retry, stop_after_attempt, wait_exponential
13
- from typing import TypedDict, Annotated, Sequence
14
- import operator
15
-
16
- # 配置日志
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger("GAIA_Agent")
19
-
20
- # 加载环境变量
21
- load_dotenv()
22
- google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
23
- if not google_api_key:
24
- raise ValueError("Missing GOOGLE_API_KEY environment variable")
25
-
26
- # --- Math Tools ---
27
- @tool
28
- def multiply(a: int, b: int) -> int:
29
- """Multiply two integers."""
30
- return a * b
31
-
32
- @tool
33
- def add(a: int, b: int) -> int:
34
- """Add two integers."""
35
- return a + b
36
-
37
- @tool
38
- def subtract(a: int, b: int) -> int:
39
- """Subtract b from a."""
40
- return a - b
41
-
42
- @tool
43
- def divide(a: int, b: int) -> float:
44
- """Divide a by b, error on zero."""
45
- if b == 0:
46
- raise ValueError("Cannot divide by zero.")
47
- return a / b
48
-
49
- @tool
50
- def modulus(a: int, b: int) -> int:
51
- """Compute a mod b."""
52
- return a % b
53
-
54
- # --- Browser Tools ---
55
- @tool
56
- def wiki_search(query: str) -> str:
57
- """Search Wikipedia and return up to 3 relevant documents."""
58
- try:
59
- # 确保查询包含"discography"关键词
60
- if "discography" not in query.lower():
61
- query = f"{query} discography"
62
-
63
- docs = WikipediaLoader(query=query, load_max_docs=3).load()
64
- if not docs:
65
- return "No Wikipedia results found."
66
-
67
- results = []
68
- for doc in docs:
69
- title = doc.metadata.get('title', 'Unknown Title')
70
- content = doc.page_content[:2000] # 限制内容长度
71
- results.append(f"Title: {title}\nContent: {content}")
72
-
73
- return "\n\n---\n\n".join(results)
74
- except Exception as e:
75
- return f"Wikipedia search error: {str(e)}"
76
-
77
- @tool
78
- def arxiv_search(query: str) -> str:
79
- """Search Arxiv and return up to 3 relevant papers."""
80
- try:
81
- docs = ArxivLoader(query=query, load_max_docs=3).load()
82
- if not docs:
83
- return "No arXiv papers found."
84
-
85
- results = []
86
- for doc in docs:
87
- title = doc.metadata.get('Title', 'Unknown Title')
88
- authors = ", ".join(doc.metadata.get('Authors', []))
89
- content = doc.page_content[:2000] # 限制内容长度
90
- results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
91
-
92
- return "\n\n---\n\n".join(results)
93
- except Exception as e:
94
- return f"arXiv search error: {str(e)}"
95
-
96
- @tool
97
- def web_search(query: str) -> str:
98
- """Search the web using DuckDuckGo and return top results."""
99
- try:
100
- search = DuckDuckGoSearchRun()
101
- result = search.run(query)
102
- return f"Web search results for '{query}':\n{result[:2000]}" # 限制内容长度
103
- except Exception as e:
104
- return f"Web search error: {str(e)}"
105
-
106
- # --- 加载系统提示 ---
107
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
108
- system_prompt = f.read()
109
-
110
- # --- 工具设置 ---
111
- tools = [
112
- multiply,
113
- add,
114
- subtract,
115
- divide,
116
- modulus,
117
- wiki_search,
118
- arxiv_search,
119
- web_search,
120
- ]
121
-
122
- # --- 图构建器 ---
123
- def build_graph():
124
- # 初始化模型 - 使用指定的 gemini-2.5-flash
125
- llm = ChatGoogleGenerativeAI(
126
- model="gemini-2.5-flash", # 保持使用指定的模型
127
- temperature=0.3,
128
- google_api_key=google_api_key,
129
- max_retries=5,
130
- request_timeout=60
131
- )
132
-
133
- # 绑定工具到LLM
134
- llm_with_tools = llm.bind_tools(tools)
135
-
136
- # 1. 定义状态结构
137
- class AgentState(TypedDict):
138
- messages: Annotated[Sequence, operator.add]
139
- retry_count: int
140
-
141
- # 2. 创建图
142
- workflow = StateGraph(AgentState)
143
-
144
- # 3. 定义节点函数
145
- def agent_node(state: AgentState):
146
- """主代理节点"""
147
- try:
148
- # 添加请求间隔
149
- time.sleep(2)
150
-
151
- # 带重试的调用
152
- @retry(stop=stop_after_attempt(5),
153
- wait=wait_exponential(multiplier=1, min=4, max=30))
154
- def invoke_with_retry():
155
- return llm_with_tools.invoke(state["messages"])
156
-
157
- response = invoke_with_retry()
158
- return {"messages": [response], "retry_count": 0}
159
-
160
- except Exception as e:
161
- error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
162
- logger.error(error_details)
163
-
164
- error_type = "UNKNOWN"
165
- if "429" in str(e):
166
- error_type = "QUOTA_EXCEEDED"
167
- elif "400" in str(e):
168
- error_type = "INVALID_REQUEST"
169
- elif "503" in str(e):
170
- error_type = "SERVICE_UNAVAILABLE"
171
-
172
- new_retry_count = state.get("retry_count", 0) + 1
173
- error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
174
-
175
- if new_retry_count < 3:
176
- error_msg += "\n\nWill retry after delay..."
177
- else:
178
- error_msg += "\n\nMax retries exceeded. Please try again later."
179
-
180
- return {"messages": [AIMessage(content=error_msg)], "retry_count": new_retry_count}
181
-
182
- def tool_node(state: AgentState):
183
- """工具执行节点"""
184
- last_msg = state["messages"][-1]
185
- tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
186
-
187
- responses = []
188
- for call in tool_calls:
189
- tool_name = call["function"]["name"]
190
- tool_args = call["function"].get("arguments", {})
191
-
192
- tool_func = next((t for t in tools if t.name == tool_name), None)
193
- if not tool_func:
194
- responses.append(f"Tool {tool_name} not available")
195
- continue
196
-
197
- try:
198
- if isinstance(tool_args, str):
199
- try:
200
- tool_args = json.loads(tool_args)
201
- except json.JSONDecodeError:
202
- if "query" in tool_args:
203
- tool_args = {"query": tool_args}
204
- else:
205
- tool_args = {"query": tool_args}
206
-
207
- result = tool_func.invoke(tool_args)
208
- responses.append(f"{tool_name} result: {str(result)[:1000]}")
209
- except Exception as e:
210
- responses.append(f"{tool_name} error: {str(e)}")
211
-
212
- tool_response_content = "\n".join(responses)
213
- return {"messages": [AIMessage(content=tool_response_content)], "retry_count": 0}
214
-
215
- # 4. 添加节点到工作流
216
- workflow.add_node("agent", agent_node)
217
- workflow.add_node("tools", tool_node)
218
-
219
- # 5. 设置入口点
220
- workflow.set_entry_point("agent")
221
-
222
- # 6. 定义条件边
223
- def should_continue(state: AgentState):
224
- last_msg = state["messages"][-1]
225
- retry_count = state.get("retry_count", 0)
226
-
227
- if "AGENT ERROR" in last_msg.content:
228
- if retry_count < 3:
229
- return "agent"
230
- return "end"
231
-
232
- if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
233
- return "tools"
234
-
235
- if "FINAL ANSWER" in last_msg.content:
236
- return "end"
237
-
238
- return "agent"
239
-
240
- workflow.add_conditional_edges(
241
- "agent",
242
- should_continue,
243
- {
244
- "agent": "agent",
245
- "tools": "tools",
246
- "end": END
247
- }
248
- )
249
-
250
- # 7. 定义工具节点后的流向
251
- workflow.add_edge("tools", "agent")
252
-
253
- # 8. 编译图
254
- return workflow.compile()
255
-
256
- # 初始化代理图
257
- agent_graph = build_graph()