vtony committed on
Commit
69ec645
·
verified ·
1 Parent(s): 03f2355

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -243
agent.py DELETED
@@ -1,243 +0,0 @@
1
import json
import os
import time
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from tenacity import retry, stop_after_attempt, wait_exponential
12
-
13
# Load environment variables (a local .env file is honored via python-dotenv).
load_dotenv()

# os.getenv is an alias for os.environ.get, so a single lookup suffices;
# the original `os.getenv(...) or os.environ.get(...)` queried the same
# mapping twice and could never produce a different result.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")
18
-
19
# --- Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of the two integers a and b."""
    product = a * b
    return product
24
-
25
@tool
def add(a: int, b: int) -> int:
    """Return the sum of the two integers a and b."""
    total = a + b
    return total
29
-
30
@tool
def subtract(a: int, b: int) -> int:
    """Return a minus b."""
    difference = a - b
    return difference
34
-
35
@tool
def divide(a: int, b: int) -> float:
    """Return the quotient a / b; raises ValueError when b is zero."""
    if not b:
        # Explicit, descriptive error instead of a bare ZeroDivisionError.
        raise ValueError("Cannot divide by zero.")
    return a / b
41
-
42
@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b.

    Raises:
        ValueError: if b is zero — mirrors divide()'s explicit error
        instead of leaking an opaque ZeroDivisionError to the agent.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a % b
46
-
47
# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    try:
        documents = WikipediaLoader(query=query, load_max_docs=3).load()
        if not documents:
            return "No Wikipedia results found."

        snippets = [
            f"Title: {doc.metadata.get('title', 'Unknown Title')}\n"
            f"Content: {doc.page_content[:2000]}"  # cap each document's length
            for doc in documents
        ]
        return "\n\n---\n\n".join(snippets)
    except Exception as exc:
        # Report the failure as text so the agent can react to it.
        return f"Wikipedia search error: {str(exc)}"
65
-
66
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            # ArxivLoader typically reports 'Authors' as one comma-separated
            # string; str.join on a string would interleave ", " between
            # CHARACTERS, so only join when we actually receive a sequence.
            authors = doc.metadata.get('Authors', '')
            if isinstance(authors, (list, tuple)):
                authors = ", ".join(authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"
84
-
85
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        raw = DuckDuckGoSearchRun().run(query)
        trimmed = raw[:2000]  # cap the snippet length
        return f"Web search results for '{query}':\n{trimmed}"
    except Exception as exc:
        # Surface the failure as text so the agent can recover.
        return f"Web search error: {str(exc)}"
94
-
95
# --- Load system prompt ---
# The prompt text lives next to this module; read it once at import time.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# --- System message ---
# Wrap the prompt so it can be prepended to the chat history.
sys_msg = SystemMessage(content=system_prompt)
101
-
102
# --- Tool Setup ---
# Every tool exposed to the LLM: five arithmetic helpers plus three
# search backends (Wikipedia, arXiv, DuckDuckGo).
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]
113
-
114
# --- Graph Builder ---
def build_graph():
    """Build and compile the LangGraph agent workflow.

    Returns:
        A compiled graph whose state carries the running message list.
        Invoke it with ``{"messages": [HumanMessage(...)]}``.
    """
    # Initialize model (Gemini 2.0 Flash)
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-exp",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,
    )

    # Bind tools so the model can emit structured tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # StateGraph needs an *annotated* state schema to build its channels;
    # the original plain class with only __init__ carried no annotations
    # and could not be used as a graph schema.  add_messages appends each
    # node's returned messages to the accumulated list.
    class AgentState(TypedDict):
        messages: Annotated[list, add_messages]

    def agent_node(state: AgentState):
        """Invoke the LLM with retry/backoff; surface failures as messages."""
        try:
            time.sleep(1)  # crude rate limiting between requests

            # Retry transient API failures (e.g. quota) with exponential backoff.
            @retry(stop=stop_after_attempt(3),
                   wait=wait_exponential(multiplier=1, min=4, max=10))
            def invoke_llm_with_retry():
                # Prepend the system prompt on every turn — the original
                # built sys_msg at module level but never sent it.
                return llm_with_tools.invoke([sys_msg] + list(state["messages"]))

            response = invoke_llm_with_retry()
            return {"messages": [response]}

        except Exception as e:
            # Classify the most common API failures for a readable error.
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"

            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            return {"messages": [AIMessage(content=error_msg)]}

    def tool_node(state: AgentState):
        """Execute every tool call requested by the last AI message."""
        last_message = state["messages"][-1]
        # LangChain chat models expose parsed calls on `.tool_calls`
        # (dicts with "name"/"args"); additional_kwargs holds only the
        # raw provider payload and is not guaranteed to be populated.
        tool_calls = getattr(last_message, "tool_calls", None) or []

        by_name = {t.name: t for t in tools}
        tool_responses = []
        for tool_call in tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call.get("args", {})

            tool_func = by_name.get(tool_name)
            if tool_func is None:
                tool_responses.append(f"Tool {tool_name} not found")
                continue

            try:
                if isinstance(tool_args, str):
                    # Some providers hand back JSON-encoded argument strings.
                    tool_args = json.loads(tool_args)
                result = tool_func.invoke(tool_args)
                tool_responses.append(f"Tool {tool_name} result: {result}")
            except Exception as e:
                tool_responses.append(f"Tool {tool_name} error: {str(e)}")

        # (translated) fixed bracket error in the original join
        tool_response_content = "\n".join(tool_responses)
        return {"messages": [AIMessage(content=tool_response_content)]}

    def should_continue(state: AgentState):
        """Route after the agent node: run tools, loop, or finish."""
        last_message = state["messages"][-1]
        # .content may be a list of content parts; only str supports `in`.
        content = last_message.content if isinstance(last_message.content, str) else ""

        if "AGENT ERROR" in content:
            return "end"
        if getattr(last_message, "tool_calls", None):
            return "tools"
        if "FINAL ANSWER" in content:
            return "end"
        # Otherwise keep reasoning until a final answer appears.
        return "agent"

    # Build the graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"agent": "agent", "tools": "tools", "end": END},
    )
    # Tool results always return to the agent; the original expressed this
    # as a conditional edge with a constant lambda, which is just an edge.
    workflow.add_edge("tools", "agent")

    return workflow.compile()
241
-
242
# Initialize the agent graph
# Compiled once at import time so consumers can `from agent import agent_graph`.
agent_graph = build_graph()