vtony commited on
Commit
1f94a07
·
verified ·
1 Parent(s): a51d8cd

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -231
agent.py DELETED
@@ -1,231 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- from dotenv import load_dotenv
5
- from langgraph.graph import StateGraph, END
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_community.tools import DuckDuckGoSearchRun
8
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
10
- from langchain_core.tools import tool
11
- from tenacity import retry, stop_after_attempt, wait_exponential
12
- from typing import TypedDict, Annotated, Sequence
13
- import operator
14
-
15
- # Load environment variables
16
- load_dotenv()
17
- google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
18
- if not google_api_key:
19
- raise ValueError("Missing GOOGLE_API_KEY environment variable")
20
-
21
- # --- Math Tools ---
22
- @tool
23
- def multiply(a: int, b: int) -> int:
24
- """Multiply two integers."""
25
- return a * b
26
-
27
- @tool
28
- def add(a: int, b: int) -> int:
29
- """Add two integers."""
30
- return a + b
31
-
32
- @tool
33
- def subtract(a: int, b: int) -> int:
34
- """Subtract b from a."""
35
- return a - b
36
-
37
- @tool
38
- def divide(a: int, b: int) -> float:
39
- """Divide a by b, error on zero."""
40
- if b == 0:
41
- raise ValueError("Cannot divide by zero.")
42
- return a / b
43
-
44
- @tool
45
- def modulus(a: int, b: int) -> int:
46
- """Compute a mod b."""
47
- return a % b
48
-
49
- # --- Browser Tools ---
50
- @tool
51
- def wiki_search(query: str) -> str:
52
- """Search Wikipedia and return up to 3 relevant documents."""
53
- try:
54
- docs = WikipediaLoader(query=query, load_max_docs=3).load()
55
- if not docs:
56
- return "No Wikipedia results found."
57
-
58
- results = []
59
- for doc in docs:
60
- title = doc.metadata.get('title', 'Unknown Title')
61
- content = doc.page_content[:2000] # Limit content length
62
- results.append(f"Title: {title}\nContent: {content}")
63
-
64
- return "\n\n---\n\n".join(results)
65
- except Exception as e:
66
- return f"Wikipedia search error: {str(e)}"
67
-
68
- @tool
69
- def arxiv_search(query: str) -> str:
70
- """Search Arxiv and return up to 3 relevant papers."""
71
- try:
72
- docs = ArxivLoader(query=query, load_max_docs=3).load()
73
- if not docs:
74
- return "No arXiv papers found."
75
-
76
- results = []
77
- for doc in docs:
78
- title = doc.metadata.get('Title', 'Unknown Title')
79
- authors = ", ".join(doc.metadata.get('Authors', []))
80
- content = doc.page_content[:2000] # Limit content length
81
- results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
82
-
83
- return "\n\n---\n\n".join(results)
84
- except Exception as e:
85
- return f"arXiv search error: {str(e)}"
86
-
87
- @tool
88
- def web_search(query: str) -> str:
89
- """Search the web using DuckDuckGo and return top results."""
90
- try:
91
- search = DuckDuckGoSearchRun()
92
- result = search.run(query)
93
- return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
94
- except Exception as e:
95
- return f"Web search error: {str(e)}"
96
-
97
- # --- Load system prompt ---
98
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
99
- system_prompt = f.read()
100
-
101
- # --- Tool Setup ---
102
- tools = [
103
- multiply,
104
- add,
105
- subtract,
106
- divide,
107
- modulus,
108
- wiki_search,
109
- arxiv_search,
110
- web_search,
111
- ]
112
-
113
- # --- Graph Builder ---
114
- def build_graph():
115
- # Initialize model with Gemini 2.5 Flash
116
- llm = ChatGoogleGenerativeAI(
117
- model="gemini-2.5-flash",
118
- temperature=0.3,
119
- google_api_key=google_api_key,
120
- max_retries=3
121
- )
122
-
123
- # Bind tools to LLM
124
- llm_with_tools = llm.bind_tools(tools)
125
-
126
- # 1. 定义状态结构
127
- class AgentState(TypedDict):
128
- messages: Annotated[Sequence, operator.add]
129
-
130
- # 2. 创建图
131
- workflow = StateGraph(AgentState)
132
-
133
- # 3. 定义节点函数
134
- def agent_node(state: AgentState):
135
- """主代理节点"""
136
- try:
137
- # 添加请求间隔
138
- time.sleep(1)
139
-
140
- # 带重试的调用
141
- @retry(stop=stop_after_attempt(3),
142
- wait=wait_exponential(multiplier=1, min=4, max=10))
143
- def invoke_with_retry():
144
- return llm_with_tools.invoke(state["messages"])
145
-
146
- response = invoke_with_retry()
147
- return {"messages": [response]}
148
-
149
- except Exception as e:
150
- error_type = "UNKNOWN"
151
- if "429" in str(e):
152
- error_type = "QUOTA_EXCEEDED"
153
- elif "400" in str(e):
154
- error_type = "INVALID_REQUEST"
155
-
156
- error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
157
- return {"messages": [AIMessage(content=error_msg)]}
158
-
159
- def tool_node(state: AgentState):
160
- """工具执行节点"""
161
- last_msg = state["messages"][-1]
162
- tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
163
-
164
- responses = []
165
- for call in tool_calls:
166
- tool_name = call["function"]["name"]
167
- tool_args = call["function"].get("arguments", {})
168
-
169
- # 查找工具
170
- tool_func = next((t for t in tools if t.name == tool_name), None)
171
- if not tool_func:
172
- responses.append(f"Tool {tool_name} not available")
173
- continue
174
-
175
- try:
176
- # 解析参数
177
- if isinstance(tool_args, str):
178
- tool_args = json.loads(tool_args)
179
-
180
- # 执行工具
181
- result = tool_func.invoke(tool_args)
182
- responses.append(f"{tool_name} result: {result[:1000]}") # 限制结果长度
183
- except Exception as e:
184
- responses.append(f"{tool_name} error: {str(e)}")
185
-
186
- return {"messages": [AIMessage(content="\n".join(responses)]}
187
-
188
- # 4. 添加节点到工作流
189
- workflow.add_node("agent", agent_node)
190
- workflow.add_node("tools", tool_node)
191
-
192
- # 5. 设置入口点
193
- workflow.set_entry_point("agent")
194
-
195
- # 6. 定义条件边
196
- def should_continue(state: AgentState):
197
- last_msg = state["messages"][-1]
198
-
199
- # 错误情况直接结束
200
- if "AGENT ERROR" in last_msg.content:
201
- return "end"
202
-
203
- # 有工具调用则转到工具节点
204
- if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
205
- return "tools"
206
-
207
- # 包含最终答案则结束
208
- if "FINAL ANSWER" in last_msg.content:
209
- return "end"
210
-
211
- # 其他情况继续代理处理
212
- return "agent"
213
-
214
- workflow.add_conditional_edges(
215
- "agent",
216
- should_continue,
217
- {
218
- "agent": "agent",
219
- "tools": "tools",
220
- "end": END
221
- }
222
- )
223
-
224
- # 7. 定义工具节点后的流向
225
- workflow.add_edge("tools", "agent")
226
-
227
- # 8. 编译图
228
- return workflow.compile()
229
-
230
- # 初始化代理图
231
- agent_graph = build_graph()