vtony commited on
Commit
c9fdd06
·
verified ·
1 Parent(s): 6553704

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +257 -0
  2. requirements.txt +9 -0
agent.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import logging
5
+ from dotenv import load_dotenv
6
+ from langgraph.graph import StateGraph, END
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_community.tools import DuckDuckGoSearchRun
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
+ from langchain_core.tools import tool
12
+ from tenacity import retry, stop_after_attempt, wait_exponential
13
+ from typing import TypedDict, Annotated, Sequence
14
+ import operator
15
+
16
+ # 配置日志
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger("GAIA_Agent")
19
+
20
+ # 加载环境变量
21
+ load_dotenv()
22
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
23
+ if not google_api_key:
24
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
25
+
26
+ # --- Math Tools ---
27
+ @tool
28
+ def multiply(a: int, b: int) -> int:
29
+ """Multiply two integers."""
30
+ return a * b
31
+
32
+ @tool
33
+ def add(a: int, b: int) -> int:
34
+ """Add two integers."""
35
+ return a + b
36
+
37
+ @tool
38
+ def subtract(a: int, b: int) -> int:
39
+ """Subtract b from a."""
40
+ return a - b
41
+
42
+ @tool
43
+ def divide(a: int, b: int) -> float:
44
+ """Divide a by b, error on zero."""
45
+ if b == 0:
46
+ raise ValueError("Cannot divide by zero.")
47
+ return a / b
48
+
49
+ @tool
50
+ def modulus(a: int, b: int) -> int:
51
+ """Compute a mod b."""
52
+ return a % b
53
+
54
+ # --- Browser Tools ---
55
+ @tool
56
+ def wiki_search(query: str) -> str:
57
+ """Search Wikipedia and return up to 3 relevant documents."""
58
+ try:
59
+ # 确保查询包含"discography"关键词
60
+ if "discography" not in query.lower():
61
+ query = f"{query} discography"
62
+
63
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
64
+ if not docs:
65
+ return "No Wikipedia results found."
66
+
67
+ results = []
68
+ for doc in docs:
69
+ title = doc.metadata.get('title', 'Unknown Title')
70
+ content = doc.page_content[:2000] # 限制内容长度
71
+ results.append(f"Title: {title}\nContent: {content}")
72
+
73
+ return "\n\n---\n\n".join(results)
74
+ except Exception as e:
75
+ return f"Wikipedia search error: {str(e)}"
76
+
77
+ @tool
78
+ def arxiv_search(query: str) -> str:
79
+ """Search Arxiv and return up to 3 relevant papers."""
80
+ try:
81
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
82
+ if not docs:
83
+ return "No arXiv papers found."
84
+
85
+ results = []
86
+ for doc in docs:
87
+ title = doc.metadata.get('Title', 'Unknown Title')
88
+ authors = ", ".join(doc.metadata.get('Authors', []))
89
+ content = doc.page_content[:2000] # 限制内容长度
90
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
91
+
92
+ return "\n\n---\n\n".join(results)
93
+ except Exception as e:
94
+ return f"arXiv search error: {str(e)}"
95
+
96
+ @tool
97
+ def web_search(query: str) -> str:
98
+ """Search the web using DuckDuckGo and return top results."""
99
+ try:
100
+ search = DuckDuckGoSearchRun()
101
+ result = search.run(query)
102
+ return f"Web search results for '{query}':\n{result[:2000]}" # 限制内容长度
103
+ except Exception as e:
104
+ return f"Web search error: {str(e)}"
105
+
106
+ # --- 加载系统提示 ---
107
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
108
+ system_prompt = f.read()
109
+
110
+ # --- 工具设置 ---
111
+ tools = [
112
+ multiply,
113
+ add,
114
+ subtract,
115
+ divide,
116
+ modulus,
117
+ wiki_search,
118
+ arxiv_search,
119
+ web_search,
120
+ ]
121
+
122
+ # --- 图构建器 ---
123
+ def build_graph():
124
+ # 初始化模型 - 使用指定的 gemini-2.5-flash
125
+ llm = ChatGoogleGenerativeAI(
126
+ model="gemini-2.5-flash", # 保持使用指定的模型
127
+ temperature=0.3,
128
+ google_api_key=google_api_key,
129
+ max_retries=5,
130
+ request_timeout=60
131
+ )
132
+
133
+ # 绑定工具到LLM
134
+ llm_with_tools = llm.bind_tools(tools)
135
+
136
+ # 1. 定义状态结构
137
+ class AgentState(TypedDict):
138
+ messages: Annotated[Sequence, operator.add]
139
+ retry_count: int
140
+
141
+ # 2. 创建图
142
+ workflow = StateGraph(AgentState)
143
+
144
+ # 3. 定义节点函数
145
+ def agent_node(state: AgentState):
146
+ """主代理节点"""
147
+ try:
148
+ # 添加请求间隔
149
+ time.sleep(2)
150
+
151
+ # 带重试的调用
152
+ @retry(stop=stop_after_attempt(5),
153
+ wait=wait_exponential(multiplier=1, min=4, max=30))
154
+ def invoke_with_retry():
155
+ return llm_with_tools.invoke(state["messages"])
156
+
157
+ response = invoke_with_retry()
158
+ return {"messages": [response], "retry_count": 0}
159
+
160
+ except Exception as e:
161
+ error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
162
+ logger.error(error_details)
163
+
164
+ error_type = "UNKNOWN"
165
+ if "429" in str(e):
166
+ error_type = "QUOTA_EXCEEDED"
167
+ elif "400" in str(e):
168
+ error_type = "INVALID_REQUEST"
169
+ elif "503" in str(e):
170
+ error_type = "SERVICE_UNAVAILABLE"
171
+
172
+ new_retry_count = state.get("retry_count", 0) + 1
173
+ error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
174
+
175
+ if new_retry_count < 3:
176
+ error_msg += "\n\nWill retry after delay..."
177
+ else:
178
+ error_msg += "\n\nMax retries exceeded. Please try again later."
179
+
180
+ return {"messages": [AIMessage(content=error_msg)], "retry_count": new_retry_count}
181
+
182
+ def tool_node(state: AgentState):
183
+ """工具执行节点"""
184
+ last_msg = state["messages"][-1]
185
+ tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
186
+
187
+ responses = []
188
+ for call in tool_calls:
189
+ tool_name = call["function"]["name"]
190
+ tool_args = call["function"].get("arguments", {})
191
+
192
+ tool_func = next((t for t in tools if t.name == tool_name), None)
193
+ if not tool_func:
194
+ responses.append(f"Tool {tool_name} not available")
195
+ continue
196
+
197
+ try:
198
+ if isinstance(tool_args, str):
199
+ try:
200
+ tool_args = json.loads(tool_args)
201
+ except json.JSONDecodeError:
202
+ if "query" in tool_args:
203
+ tool_args = {"query": tool_args}
204
+ else:
205
+ tool_args = {"query": tool_args}
206
+
207
+ result = tool_func.invoke(tool_args)
208
+ responses.append(f"{tool_name} result: {str(result)[:1000]}")
209
+ except Exception as e:
210
+ responses.append(f"{tool_name} error: {str(e)}")
211
+
212
+ tool_response_content = "\n".join(responses)
213
+ return {"messages": [AIMessage(content=tool_response_content)], "retry_count": 0}
214
+
215
+ # 4. 添加节点到工作流
216
+ workflow.add_node("agent", agent_node)
217
+ workflow.add_node("tools", tool_node)
218
+
219
+ # 5. 设置入口点
220
+ workflow.set_entry_point("agent")
221
+
222
+ # 6. 定义条件边
223
+ def should_continue(state: AgentState):
224
+ last_msg = state["messages"][-1]
225
+ retry_count = state.get("retry_count", 0)
226
+
227
+ if "AGENT ERROR" in last_msg.content:
228
+ if retry_count < 3:
229
+ return "agent"
230
+ return "end"
231
+
232
+ if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
233
+ return "tools"
234
+
235
+ if "FINAL ANSWER" in last_msg.content:
236
+ return "end"
237
+
238
+ return "agent"
239
+
240
+ workflow.add_conditional_edges(
241
+ "agent",
242
+ should_continue,
243
+ {
244
+ "agent": "agent",
245
+ "tools": "tools",
246
+ "end": END
247
+ }
248
+ )
249
+
250
+ # 7. 定义工具节点后的流向
251
+ workflow.add_edge("tools", "agent")
252
+
253
+ # 8. 编译图
254
+ return workflow.compile()
255
+
256
+ # 初始化代理图
257
+ agent_graph = build_graph()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ langchain-google-genai
2
+ langchain
3
+ langgraph
4
+ duckduckgo-search
5
+ wikipedia
6
+ arxiv
7
+ tenacity
8
+ python-dotenv
9
+ google-generativeai