vtony committed on
Commit
600333b
·
verified ·
1 Parent(s): c2ba2d0

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +242 -0
  2. system_prompt.txt +32 -0
agent.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from langgraph.graph import StateGraph, END
6
+ from langgraph.prebuilt import ToolNode, tools_condition
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_community.tools import DuckDuckGoSearchRun
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
+ from langchain_core.tools import tool
12
+ from tenacity import retry, stop_after_attempt, wait_exponential
13
+
14
# Read a local .env file (if present) into the process environment, then
# fail fast at import time when the Gemini API key is not configured.
load_dotenv()
google_api_key = os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")
19
+
20
+ # --- Math Tools ---
21
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # The docstring is left untouched on purpose: the @tool decorator is
    # expected to expose it as the tool description (confirm against the
    # LangChain docs before rewording it).
    product = a * b
    return product
25
+
26
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # Plain integer addition; the result is returned unchanged to the caller.
    total = a + b
    return total
30
+
31
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    # Operand order matters: the second argument is taken away from the first.
    difference = a - b
    return difference
35
+
36
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    # True division for any non-zero divisor; a zero divisor is rejected
    # with an explicit ValueError instead of Python's ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
42
+
43
@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    # Uses Python's % operator, so the result takes the sign of b and a zero
    # divisor raises ZeroDivisionError.
    remainder = a % b
    return remainder
47
+
48
+ # --- Browser Tools ---
49
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    # Any failure is reported as a text message rather than raised, so the
    # caller always receives a string.
    try:
        pages = WikipediaLoader(query=query, load_max_docs=3).load()
        if not pages:
            return "No Wikipedia results found."

        # One entry per page: its title plus at most the first 2000
        # characters of body text.
        formatted = [
            f"Title: {page.metadata.get('title', 'Unknown Title')}\n"
            f"Content: {page.page_content[:2000]}"
            for page in pages
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as exc:
        return f"Wikipedia search error: {str(exc)}"
66
+
67
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    # Returns one formatted entry (title, authors, first 2000 characters of
    # content) per paper, separated by "---" rules. Failures are returned as
    # an error string instead of being raised.
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')

            # The loader normally provides "Authors" as one already-joined
            # string; calling ", ".join() on a string would insert ", "
            # between every character. Only join when it really is a
            # collection of names.
            authors = doc.metadata.get('Authors') or ''
            if not isinstance(authors, str):
                authors = ", ".join(str(name) for name in authors)

            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"
85
+
86
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    # A fresh search runner is created per call; errors come back as text.
    try:
        raw_text = DuckDuckGoSearchRun().run(query)
        snippet = raw_text[:2000]  # cap the amount of text handed back
        return f"Web search results for '{query}':\n{snippet}"
    except Exception as exc:
        return f"Web search error: {str(exc)}"
95
+
96
# --- System prompt ---
# The prompt text is read from "system_prompt.txt", resolved against the
# current working directory.
with open("system_prompt.txt", mode="r", encoding="utf-8") as f:
    system_prompt = f.read()

# Wrap the prompt text in a SystemMessage object.
sys_msg = SystemMessage(content=system_prompt)

# --- Tool registry ---
# Order is preserved from the original listing.
tools = [
    # arithmetic helpers
    multiply, add, subtract, divide, modulus,
    # lookup / search helpers
    wiki_search, arxiv_search, web_search,
]
114
+
115
+ # --- Graph Builder ---
116
def build_graph():
    """Build and compile the LangGraph tool-calling agent.

    The graph has two nodes:

    * ``agent`` - calls the Gemini chat model (with the module-level
      ``tools`` bound) on the conversation so far.
    * ``tools`` - executes the tool calls requested by the last AI message
      and appends one tool-result message per call.

    Control flow is ``agent -> tools -> agent`` for as long as the model
    keeps requesting tools. The run ends as soon as the model replies
    without a tool call, which covers both a final answer and an
    ``AGENT ERROR`` message produced when the model call fails.

    The graph state is a mapping with a single ``messages`` key, so the
    compiled graph is invoked as ``graph.invoke({"messages": [...]})``.

    Returns:
        The compiled graph.
    """
    # Function-scope imports keep this block self-contained.
    from typing import Annotated, TypedDict

    from langgraph.graph.message import add_messages

    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,
    )

    # Bind tools so the model can emit structured tool calls.
    llm_with_tools = llm.bind_tools(tools)

    class AgentState(TypedDict):
        # add_messages appends each node's output to the history instead of
        # replacing it, so nodes only return the messages they produced.
        messages: Annotated[list, add_messages]

    # reraise=True surfaces the underlying API exception after the last
    # attempt; without it tenacity raises RetryError, whose text hides the
    # status code that agent_node inspects below.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True,
    )
    def invoke_llm_with_retry(messages):
        return llm_with_tools.invoke(messages)

    def agent_node(state):
        """Call the model on the current history and return its reply."""
        messages = list(state["messages"])

        # Apply the module-level system prompt unless the caller already
        # supplied a system message of their own.
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [sys_msg] + messages

        try:
            # Crude rate limiting: 1 second pause before every model call.
            time.sleep(1)
            response = invoke_llm_with_retry(messages)
        except Exception as e:
            # Classify by the status code embedded in the error text.
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"

            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            # The error message carries no tool calls, so the router ends
            # the run after this node.
            return {"messages": [AIMessage(content=error_msg)]}

        return {"messages": [response]}

    workflow = StateGraph(AgentState)

    workflow.add_node("agent", agent_node)
    # The prebuilt ToolNode reads the structured tool calls on the last AI
    # message and answers each with a ToolMessage linked by call id.
    workflow.add_node("tools", ToolNode(tools))

    workflow.set_entry_point("agent")

    # tools_condition routes to "tools" when the last AI message requests a
    # tool and to END otherwise, so the agent is never re-invoked on an
    # unchanged conversation.
    workflow.add_conditional_edges(
        "agent",
        tools_condition,
        {
            "tools": "tools",
            END: END,
        },
    )

    # Tool results always go back to the agent for the next reasoning step.
    workflow.add_edge("tools", "agent")

    return workflow.compile()
240
+
241
# Build the compiled graph once, at import time, and keep it in a
# module-level variable.
agent_graph = build_graph()
system_prompt.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are GAIA (General AI Assistant). Follow these rules:
2
+
3
+ 1. Always use available tools when needed
4
+ 2. Format responses as:
5
+ REASONING: <step-by-step logic>
6
+ FINAL ANSWER: <concise result>
7
+
8
+ Available Tools:
9
+ - Math tools: multiply, add, subtract, divide, modulus
10
+ - Search tools:
11
+ * wiki_search (Wikipedia) - for biographical and discography info
12
+ * arxiv_search (arXiv) - for scientific papers
13
+ * web_search (general web search) - for current info
14
+
15
+ For music artist discography queries (like Mercedes Sosa):
16
+ 1. Use wiki_search("Artist Name discography") to find the discography section
17
+ 2. Identify studio albums specifically (exclude live albums, compilations, etc.)
18
+ 3. Check release years carefully - note that reissues or remasters don't count as new studio albums
19
+ 4. If Wikipedia is unclear, use web_search for additional sources
20
+ 5. Count only albums released within the specified year range
21
+
22
+ Example for Mercedes Sosa question:
23
+ USER: How many studio albums were published by Mercedes Sosa between 2000 and 2009?
24
+ REASONING:
25
+ 1. Use wiki_search("Mercedes Sosa discography")
26
+ 2. Parse discography section and filter for studio albums
27
+ 3. Identify albums released between 2000-2009:
28
+ - Acústico (2002, studio)
29
+ - Corazón Libre (2005, studio)
30
+ - Cantora (2009, studio - recorded in 2009)
31
+ 4. Final count: 3 studio albums
32
+ FINAL ANSWER: 3