vtony commited on
Commit
c1b8918
·
verified ·
1 Parent(s): 697c132

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +243 -0
agent.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import StateGraph, END
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
10
+ from langchain_core.tools import tool
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
16
+ if not google_api_key:
17
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
18
+
19
+ # --- Math Tools ---
20
+ @tool
21
+ def multiply(a: int, b: int) -> int:
22
+ """Multiply two integers."""
23
+ return a * b
24
+
25
+ @tool
26
+ def add(a: int, b: int) -> int:
27
+ """Add two integers."""
28
+ return a + b
29
+
30
+ @tool
31
+ def subtract(a: int, b: int) -> int:
32
+ """Subtract b from a."""
33
+ return a - b
34
+
35
+ @tool
36
+ def divide(a: int, b: int) -> float:
37
+ """Divide a by b, error on zero."""
38
+ if b == 0:
39
+ raise ValueError("Cannot divide by zero.")
40
+ return a / b
41
+
42
+ @tool
43
+ def modulus(a: int, b: int) -> int:
44
+ """Compute a mod b."""
45
+ return a % b
46
+
47
+ # --- Browser Tools ---
48
+ @tool
49
+ def wiki_search(query: str) -> str:
50
+ """Search Wikipedia and return up to 3 relevant documents."""
51
+ try:
52
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
53
+ if not docs:
54
+ return "No Wikipedia results found."
55
+
56
+ results = []
57
+ for doc in docs:
58
+ title = doc.metadata.get('title', 'Unknown Title')
59
+ content = doc.page_content[:2000] # Limit content length
60
+ results.append(f"Title: {title}\nContent: {content}")
61
+
62
+ return "\n\n---\n\n".join(results)
63
+ except Exception as e:
64
+ return f"Wikipedia search error: {str(e)}"
65
+
66
+ @tool
67
+ def arxiv_search(query: str) -> str:
68
+ """Search Arxiv and return up to 3 relevant papers."""
69
+ try:
70
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
71
+ if not docs:
72
+ return "No arXiv papers found."
73
+
74
+ results = []
75
+ for doc in docs:
76
+ title = doc.metadata.get('Title', 'Unknown Title')
77
+ authors = ", ".join(doc.metadata.get('Authors', []))
78
+ content = doc.page_content[:2000] # Limit content length
79
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
80
+
81
+ return "\n\n---\n\n".join(results)
82
+ except Exception as e:
83
+ return f"arXiv search error: {str(e)}"
84
+
85
+ @tool
86
+ def web_search(query: str) -> str:
87
+ """Search the web using DuckDuckGo and return top results."""
88
+ try:
89
+ search = DuckDuckGoSearchRun()
90
+ result = search.run(query)
91
+ return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
92
+ except Exception as e:
93
+ return f"Web search error: {str(e)}"
94
+
95
+ # --- Load system prompt ---
96
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
97
+ system_prompt = f.read()
98
+
99
+ # --- System message ---
100
+ sys_msg = SystemMessage(content=system_prompt)
101
+
102
+ # --- Tool Setup ---
103
+ tools = [
104
+ multiply,
105
+ add,
106
+ subtract,
107
+ divide,
108
+ modulus,
109
+ wiki_search,
110
+ arxiv_search,
111
+ web_search,
112
+ ]
113
+
114
+ # --- Graph Builder ---
115
+ def build_graph():
116
+ # Initialize model (Gemini 2.0 Flash)
117
+ llm = ChatGoogleGenerativeAI(
118
+ model="gemini-2.0-flash-exp",
119
+ temperature=0.3,
120
+ google_api_key=google_api_key,
121
+ max_retries=3
122
+ )
123
+
124
+ # Bind tools to LLM
125
+ llm_with_tools = llm.bind_tools(tools)
126
+
127
+ # Define state
128
+ class AgentState:
129
+ def __init__(self, messages):
130
+ self.messages = messages
131
+
132
+ # Node definitions with error handling
133
+ def agent_node(state: AgentState):
134
+ """Main agent node that processes messages with retry logic"""
135
+ try:
136
+ # Add rate limiting
137
+ time.sleep(1) # 1 second delay between requests
138
+
139
+ # Add retry logic for API quota issues
140
+ @retry(stop=stop_after_attempt(3),
141
+ wait=wait_exponential(multiplier=1, min=4, max=10))
142
+ def invoke_llm_with_retry():
143
+ return llm_with_tools.invoke(state.messages)
144
+
145
+ response = invoke_llm_with_retry()
146
+ return AgentState(state.messages + [response])
147
+
148
+ except Exception as e:
149
+ # Handle specific errors
150
+ error_type = "UNKNOWN"
151
+ if "429" in str(e):
152
+ error_type = "QUOTA_EXCEEDED"
153
+ elif "400" in str(e):
154
+ error_type = "INVALID_REQUEST"
155
+
156
+ error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
157
+ return AgentState(state.messages + [AIMessage(content=error_msg)])
158
+
159
+ # Tool node
160
+ def tool_node(state: AgentState):
161
+ """Execute tools based on agent's request"""
162
+ last_message = state.messages[-1]
163
+ tool_calls = last_message.additional_kwargs.get("tool_calls", [])
164
+
165
+ tool_responses = []
166
+ for tool_call in tool_calls:
167
+ tool_name = tool_call["function"]["name"]
168
+ tool_args = tool_call["function"].get("arguments", {})
169
+
170
+ # Find the tool
171
+ tool_func = next((t for t in tools if t.name == tool_name), None)
172
+ if not tool_func:
173
+ tool_responses.append(f"Tool {tool_name} not found")
174
+ continue
175
+
176
+ try:
177
+ # Execute the tool
178
+ if isinstance(tool_args, str):
179
+ # Parse JSON if arguments are in string format
180
+ import json
181
+ tool_args = json.loads(tool_args)
182
+
183
+ result = tool_func.invoke(tool_args)
184
+ tool_responses.append(f"Tool {tool_name} result: {result}")
185
+ except Exception as e:
186
+ tool_responses.append(f"Tool {tool_name} error: {str(e)}")
187
+
188
+ # 修复括号错误
189
+ tool_response_content = "\n".join(tool_responses)
190
+ return AgentState(state.messages + [AIMessage(content=tool_response_content)])
191
+
192
+ # Custom condition function
193
+ def should_continue(state: AgentState):
194
+ last_message = state.messages[-1]
195
+
196
+ # If there was an error, end
197
+ if "AGENT ERROR" in last_message.content:
198
+ return "end"
199
+
200
+ # Check for tool calls
201
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
202
+ return "tools"
203
+
204
+ # Check for final answer
205
+ if "FINAL ANSWER" in last_message.content:
206
+ return "end"
207
+
208
+ # Otherwise, continue to agent
209
+ return "agent"
210
+
211
+ # Build the graph
212
+ workflow = StateGraph(AgentState)
213
+
214
+ # Add nodes
215
+ workflow.add_node("agent", agent_node)
216
+ workflow.add_node("tools", tool_node)
217
+
218
+ # Set entry point
219
+ workflow.set_entry_point("agent")
220
+
221
+ # Define edges
222
+ workflow.add_conditional_edges(
223
+ "agent",
224
+ should_continue,
225
+ {
226
+ "agent": "agent",
227
+ "tools": "tools",
228
+ "end": END
229
+ }
230
+ )
231
+
232
+ workflow.add_conditional_edges(
233
+ "tools",
234
+ lambda state: "agent",
235
+ {
236
+ "agent": "agent"
237
+ }
238
+ )
239
+
240
+ return workflow.compile()
241
+
242
+ # Initialize the agent graph
243
+ agent_graph = build_graph()