vtony committed · Commit 3cb3189 · verified · 1 Parent(s): 08bddd0

Upload agent.py

Files changed (1)
  1. agent.py +154 -0
agent.py ADDED
@@ -0,0 +1,154 @@
import os
import time
from typing import TypedDict, Annotated, Sequence
import operator
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import BaseMessage, SystemMessage, AIMessage
from langchain.agents import Tool
from langchain_core.tools import tool

# Load environment variables
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")

# Load system prompt
with open("System_Prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)

# Tool Definitions
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers together."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers together."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide a by b. Returns float. Raises error if b is zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 2 relevant documents."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    if not docs:
        return "No Wikipedia results found."
    return "\n\n".join([d.page_content[:1000] for d in docs])

# Tool inventory with valid names
tools = [
    Tool(name="Math_Multiply", func=multiply, description="Multiplies two integers"),
    Tool(name="Math_Add", func=add, description="Adds two integers"),
    Tool(name="Math_Subtract", func=subtract, description="Subtracts two integers"),
    Tool(name="Math_Divide", func=divide, description="Divides two numbers"),
    Tool(name="Search_Wikipedia", func=wiki_search, description="Searches Wikipedia"),
    Tool(
        name="Search_Web",
        func=DuckDuckGoSearchRun().run,
        description="Searches the web using DuckDuckGo"
    )
]

# Graph Definition
class AgentState(TypedDict):
    """State definition for the agent workflow"""
    messages: Annotated[Sequence[BaseMessage], operator.add]

def build_graph():
    """Constructs and compiles the LangGraph workflow"""

    # Initialize LLM with Gemini 2.0 Flash and rate limiting
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-exp",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,  # enable built-in retries
        request_timeout=30  # set a request timeout
    )
    llm_with_tools = llm.bind_tools(tools)

    # Custom tool condition to avoid '__end__' issues
    def custom_tools_condition(state: AgentState):
        """Determines whether to use tools or end the workflow"""
        last_message = state["messages"][-1]

        # Check if it's an AI message with tool calls
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "use_tools"

        # Check if it's an error message
        if isinstance(last_message, AIMessage) and "ERROR" in last_message.content:
            return "end"

        # Check if it's a final answer
        if "FINAL ANSWER" in last_message.content:
            return "end"

        # No tool calls were requested, so there is nothing for the tool node
        # to execute; end the run instead of routing an empty call to ToolNode.
        return "end"

    # Node definitions with error handling
    def agent_node(state: AgentState):
        """Main agent node that processes messages with retry logic"""
        try:
            # Add rate limiting
            time.sleep(1)  # 1 second delay between requests

            response = llm_with_tools.invoke(state["messages"])
            return {"messages": [response]}

        except Exception as e:
            # Handle specific errors
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"

            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            return {"messages": [AIMessage(content=error_msg)]}

    # Graph construction
    workflow = StateGraph(AgentState)

    # Add nodes to the workflow
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", ToolNode(tools))

    # Configure graph flow with clear endpoints
    workflow.set_entry_point("agent")

    # Add conditional edges with custom condition
    workflow.add_conditional_edges(
        "agent",
        custom_tools_condition,
        {
            "use_tools": "tools",
            "end": END  # Directly to END
        }
    )

    # Add edge from tools back to agent
    workflow.add_edge("tools", "agent")

    # Ensure END has no incoming edges except through condition
    return workflow.compile()

# Initialize the agent graph
agent_graph = build_graph()
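
For reference, a minimal usage sketch (not part of this commit): it assumes GOOGLE_API_KEY is exported, that System_Prompt.txt sits next to agent.py, and that the caller seeds the conversation with sys_msg, since the graph itself never injects the system prompt. The question string is illustrative only.

    from langchain_core.messages import HumanMessage
    from agent import agent_graph, sys_msg

    # Seed the state with the system prompt plus a user question,
    # then read the last message produced by the compiled graph.
    state = {"messages": [sys_msg, HumanMessage(content="What is 6 * 7?")]}
    result = agent_graph.invoke(state)
    print(result["messages"][-1].content)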