vtony commited on
Commit
2a3fbbd
·
verified ·
1 Parent(s): 7033d04

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +233 -0
  2. system_prompt.txt +32 -0
agent.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from langgraph.graph import StateGraph, END
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
10
+ from langchain_core.tools import tool
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+ from typing import TypedDict, Annotated, Sequence
13
+ import operator
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
18
+ if not google_api_key:
19
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
20
+
21
+ # --- Math Tools ---
22
+ @tool
23
+ def multiply(a: int, b: int) -> int:
24
+ """Multiply two integers."""
25
+ return a * b
26
+
27
+ @tool
28
+ def add(a: int, b: int) -> int:
29
+ """Add two integers."""
30
+ return a + b
31
+
32
+ @tool
33
+ def subtract(a: int, b: int) -> int:
34
+ """Subtract b from a."""
35
+ return a - b
36
+
37
+ @tool
38
+ def divide(a: int, b: int) -> float:
39
+ """Divide a by b, error on zero."""
40
+ if b == 0:
41
+ raise ValueError("Cannot divide by zero.")
42
+ return a / b
43
+
44
+ @tool
45
+ def modulus(a: int, b: int) -> int:
46
+ """Compute a mod b."""
47
+ return a % b
48
+
49
+ # --- Browser Tools ---
50
+ @tool
51
+ def wiki_search(query: str) -> str:
52
+ """Search Wikipedia and return up to 3 relevant documents."""
53
+ try:
54
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
55
+ if not docs:
56
+ return "No Wikipedia results found."
57
+
58
+ results = []
59
+ for doc in docs:
60
+ title = doc.metadata.get('title', 'Unknown Title')
61
+ content = doc.page_content[:2000] # Limit content length
62
+ results.append(f"Title: {title}\nContent: {content}")
63
+
64
+ return "\n\n---\n\n".join(results)
65
+ except Exception as e:
66
+ return f"Wikipedia search error: {str(e)}"
67
+
68
+ @tool
69
+ def arxiv_search(query: str) -> str:
70
+ """Search Arxiv and return up to 3 relevant papers."""
71
+ try:
72
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
73
+ if not docs:
74
+ return "No arXiv papers found."
75
+
76
+ results = []
77
+ for doc in docs:
78
+ title = doc.metadata.get('Title', 'Unknown Title')
79
+ authors = ", ".join(doc.metadata.get('Authors', []))
80
+ content = doc.page_content[:2000] # Limit content length
81
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
82
+
83
+ return "\n\n---\n\n".join(results)
84
+ except Exception as e:
85
+ return f"arXiv search error: {str(e)}"
86
+
87
+ @tool
88
+ def web_search(query: str) -> str:
89
+ """Search the web using DuckDuckGo and return top results."""
90
+ try:
91
+ search = DuckDuckGoSearchRun()
92
+ result = search.run(query)
93
+ return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
94
+ except Exception as e:
95
+ return f"Web search error: {str(e)}"
96
+
97
+ # --- Load system prompt ---
98
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
99
+ system_prompt = f.read()
100
+
101
+ # --- Tool Setup ---
102
+ tools = [
103
+ multiply,
104
+ add,
105
+ subtract,
106
+ divide,
107
+ modulus,
108
+ wiki_search,
109
+ arxiv_search,
110
+ web_search,
111
+ ]
112
+
113
+ # --- Graph Builder ---
114
+ def build_graph():
115
+ # Initialize model with Gemini 2.5 Flash
116
+ llm = ChatGoogleGenerativeAI(
117
+ model="gemini-2.5-flash",
118
+ temperature=0.3,
119
+ google_api_key=google_api_key,
120
+ max_retries=3
121
+ )
122
+
123
+ # Bind tools to LLM
124
+ llm_with_tools = llm.bind_tools(tools)
125
+
126
+ # 1. 定义状态结构
127
+ class AgentState(TypedDict):
128
+ messages: Annotated[Sequence, operator.add]
129
+
130
+ # 2. 创建图
131
+ workflow = StateGraph(AgentState)
132
+
133
+ # 3. 定义节点函数
134
+ def agent_node(state: AgentState):
135
+ """主代理节点"""
136
+ try:
137
+ # 添加请求间隔
138
+ time.sleep(1)
139
+
140
+ # 带重试的调用
141
+ @retry(stop=stop_after_attempt(3),
142
+ wait=wait_exponential(multiplier=1, min=4, max=10))
143
+ def invoke_with_retry():
144
+ return llm_with_tools.invoke(state["messages"])
145
+
146
+ response = invoke_with_retry()
147
+ return {"messages": [response]}
148
+
149
+ except Exception as e:
150
+ error_type = "UNKNOWN"
151
+ if "429" in str(e):
152
+ error_type = "QUOTA_EXCEEDED"
153
+ elif "400" in str(e):
154
+ error_type = "INVALID_REQUEST"
155
+
156
+ error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
157
+ return {"messages": [AIMessage(content=error_msg)]}
158
+
159
+ def tool_node(state: AgentState):
160
+ """工具执行节点"""
161
+ last_msg = state["messages"][-1]
162
+ tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
163
+
164
+ responses = []
165
+ for call in tool_calls:
166
+ tool_name = call["function"]["name"]
167
+ tool_args = call["function"].get("arguments", {})
168
+
169
+ # 查找工具
170
+ tool_func = next((t for t in tools if t.name == tool_name), None)
171
+ if not tool_func:
172
+ responses.append(f"Tool {tool_name} not available")
173
+ continue
174
+
175
+ try:
176
+ # 解析参数
177
+ if isinstance(tool_args, str):
178
+ tool_args = json.loads(tool_args)
179
+
180
+ # 执行工具
181
+ result = tool_func.invoke(tool_args)
182
+ responses.append(f"{tool_name} result: {result[:1000]}") # 限制结果长度
183
+ except Exception as e:
184
+ responses.append(f"{tool_name} error: {str(e)}")
185
+
186
+ # 修复括号错误:确保正确关闭所有括号
187
+ tool_response_content = "\n".join(responses)
188
+ return {"messages": [AIMessage(content=tool_response_content)]}
189
+
190
+ # 4. 添加节点到工作流
191
+ workflow.add_node("agent", agent_node)
192
+ workflow.add_node("tools", tool_node)
193
+
194
+ # 5. 设置入口点
195
+ workflow.set_entry_point("agent")
196
+
197
+ # 6. 定义条件边
198
+ def should_continue(state: AgentState):
199
+ last_msg = state["messages"][-1]
200
+
201
+ # 错误情况直接结束
202
+ if "AGENT ERROR" in last_msg.content:
203
+ return "end"
204
+
205
+ # 有工具调用则转到工具节点
206
+ if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
207
+ return "tools"
208
+
209
+ # 包含最终答案则结束
210
+ if "FINAL ANSWER" in last_msg.content:
211
+ return "end"
212
+
213
+ # 其他情况继续代理处理
214
+ return "agent"
215
+
216
+ workflow.add_conditional_edges(
217
+ "agent",
218
+ should_continue,
219
+ {
220
+ "agent": "agent",
221
+ "tools": "tools",
222
+ "end": END
223
+ }
224
+ )
225
+
226
+ # 7. 定义工具节点后的流向
227
+ workflow.add_edge("tools", "agent")
228
+
229
+ # 8. 编译图
230
+ return workflow.compile()
231
+
232
+ # 初始化代理图
233
+ agent_graph = build_graph()
system_prompt.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are GAIA (General AI Assistant). Follow these rules:
2
+
3
+ 1. Always use available tools when needed
4
+ 2. Format responses as:
5
+ REASONING: <step-by-step logic>
6
+ FINAL ANSWER: <concise result>
7
+
8
+ Available Tools:
9
+ - Math tools: multiply, add, subtract, divide, modulus
10
+ - Search tools:
11
+ * wiki_search (Wikipedia) - for biographical and discography info
12
+ * arxiv_search (arXiv) - for scientific papers
13
+ * web_search (general web search) - for current info
14
+
15
+ For music artist discography queries (like Mercedes Sosa):
16
+ 1. Use wiki_search("Artist Name discography") to find the discography section
17
+ 2. Identify studio albums specifically (exclude live albums, compilations, etc.)
18
+ 3. Check release years carefully - note that reissues or remasters don't count as new studio albums
19
+ 4. If Wikipedia is unclear, use web_search for additional sources
20
+ 5. Count only albums released within the specified year range
21
+
22
+ Example for Mercedes Sosa question:
23
+ USER: How many studio albums were published by Mercedes Sosa between 2000 and 2009?
24
+ REASONING:
25
+ 1. Use wiki_search("Mercedes Sosa discography")
26
+ 2. Parse discography section and filter for studio albums
27
+ 3. Identify albums released between 2000-2009:
28
+ - Acústico (2002, studio)
29
+ - Corazón Libre (2005, studio)
30
+ - Cantora (2009, studio - recorded in 2009)
31
+ 4. Final count: 3 studio albums
32
+ FINAL ANSWER: 3