vtony committed on
Commit
f446307
·
verified ·
1 Parent(s): 230aeb3

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -244
agent.py DELETED
@@ -1,244 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- from dotenv import load_dotenv
5
- from langgraph.graph import StateGraph, END
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_community.tools import DuckDuckGoSearchRun
8
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
10
- from langchain_core.tools import tool
11
- from tenacity import retry, stop_after_attempt, wait_exponential
12
-
13
# Load environment variables from a local .env file, if present.
load_dotenv()
# os.getenv already reads os.environ, so the previous
# `or os.environ.get("GOOGLE_API_KEY")` fallback was redundant.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")
18
-
19
# --- Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    product = a * b
    return product
24
-
25
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    total = a + b
    return total
29
-
30
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    difference = a - b
    return difference
34
-
35
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    # Guard clause: reject zero divisors before computing.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    quotient = a / b
    return quotient
41
-
42
@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    remainder = a % b
    return remainder
46
-
47
# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    try:
        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."

        # Truncate each page to 2000 chars to keep the LLM context small.
        entries = [
            f"Title: {d.metadata.get('title', 'Unknown Title')}\n"
            f"Content: {d.page_content[:2000]}"
            for d in docs
        ]
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        return f"Wikipedia search error: {str(e)}"
65
-
66
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers.

    Returns a formatted string with title, authors, and a truncated
    abstract/body for each paper, or an error message on failure.
    """
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            # FIX: ArxivLoader exposes 'Authors' as a single comma-separated
            # string; the old `", ".join(...)` over a string interleaved a
            # comma between every character. Join only when it's a sequence.
            authors = doc.metadata.get('Authors', '')
            if not isinstance(authors, str):
                authors = ", ".join(authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"
84
-
85
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        engine = DuckDuckGoSearchRun()
        hits = engine.run(query)
        # Cap the result text to keep the LLM context small.
        return f"Web search results for '{query}':\n{hits[:2000]}"
    except Exception as e:
        return f"Web search error: {str(e)}"
94
-
95
# --- Load system prompt ---
# Read once at import time; the file is expected next to this script.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# --- System message ---
sys_msg = SystemMessage(content=system_prompt)
101
-
102
# --- Tool Setup ---
# Every tool exposed to the LLM: arithmetic helpers plus search backends.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, arxiv_search, web_search,
]
113
-
114
# --- Graph Builder ---
def build_graph():
    """Build and compile the LangGraph agent workflow.

    The graph has two nodes: "agent" (Gemini with bound tools) and
    "tools" (executes the tool calls the model requested). Control
    loops between them until the agent emits a final answer or an
    error occurs.

    Returns:
        The compiled LangGraph workflow.
    """
    # Gemini 2.5 Flash; max_retries gives client-side retries on top of
    # the explicit backoff below.
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,
    )

    # Expose tool schemas to the model so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # State as a TypedDict; operator.add concatenates the message lists
    # returned by each node onto the running history.
    from typing import TypedDict, Annotated, Sequence
    import operator

    class AgentState(TypedDict):
        messages: Annotated[Sequence[dict], operator.add]

    def agent_node(state: AgentState):
        """Invoke the LLM on the current history, with retry/backoff."""
        try:
            # Crude rate limiting between successive Gemini requests.
            time.sleep(1)

            # Exponential backoff mainly for 429 quota errors.
            @retry(stop=stop_after_attempt(3),
                   wait=wait_exponential(multiplier=1, min=4, max=10))
            def invoke_llm_with_retry():
                return llm_with_tools.invoke(state["messages"])

            response = invoke_llm_with_retry()
            return {"messages": [response]}

        except Exception as e:
            # Classify common HTTP failures so they surface clearly.
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"

            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            return {"messages": [AIMessage(content=error_msg)]}

    def tool_node(state: AgentState):
        """Execute the tool calls requested by the agent's last message."""
        last_message = state["messages"][-1]

        # FIX: should_continue routes here based on `last_message.tool_calls`
        # (LangChain's standard attribute), but this node previously read only
        # `additional_kwargs["tool_calls"]` (OpenAI wire format), so Gemini
        # tool calls were silently dropped. Accept both representations.
        tool_calls = getattr(last_message, "tool_calls", None) or \
            last_message.additional_kwargs.get("tool_calls", [])

        tool_responses = []
        for tool_call in tool_calls:
            # OpenAI format: {"function": {"name": ..., "arguments": "..."}}
            # LangChain format: {"name": ..., "args": {...}}
            if "function" in tool_call:
                tool_name = tool_call["function"]["name"]
                tool_args = tool_call["function"].get("arguments", {})
            else:
                tool_name = tool_call["name"]
                tool_args = tool_call.get("args", {})

            # Resolve the tool by name from the registered tool list.
            tool_func = next((t for t in tools if t.name == tool_name), None)
            if not tool_func:
                tool_responses.append(f"Tool {tool_name} not found")
                continue

            try:
                if isinstance(tool_args, str):
                    # OpenAI-style arguments arrive JSON-encoded.
                    tool_args = json.loads(tool_args)

                result = tool_func.invoke(tool_args)
                tool_responses.append(f"Tool {tool_name} result: {result}")
            except Exception as e:
                tool_responses.append(f"Tool {tool_name} error: {str(e)}")

        tool_response_content = "\n".join(tool_responses)
        return {"messages": [AIMessage(content=tool_response_content)]}

    def should_continue(state: AgentState):
        """Route from the agent node: "tools", "agent", or "end"."""
        last_message = state["messages"][-1]

        # Stop on agent-level failures.
        if "AGENT ERROR" in last_message.content:
            return "end"

        # The model requested tool execution.
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"

        # The model produced its final answer.
        if "FINAL ANSWER" in last_message.content:
            return "end"

        # Otherwise keep reasoning.
        return "agent"

    # Wire up the two-node graph.
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END,
        },
    )

    # Tools always hand control back to the agent.
    workflow.add_conditional_edges(
        "tools",
        lambda state: "agent",
        {"agent": "agent"},
    )

    return workflow.compile()
242
-
243
# Initialize the agent graph
# Compiled once at import time so `from agent import agent_graph` works.
agent_graph = build_graph()