vtony committed on
Commit
229c4ae
·
verified ·
1 Parent(s): f018873

Upload 2 files

Files changed (2)
  1. agent.py +264 -0
  2. requirements.txt +10 -0
agent.py ADDED
@@ -0,0 +1,264 @@
+ import os
+ import time
+ import json
+ import logging
+ from dotenv import load_dotenv
+ from langgraph.graph import StateGraph, END
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+ from langchain_core.tools import tool
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from typing import TypedDict, Annotated, Sequence
+ import operator
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("GAIA_Agent")
+
+ # Load environment variables
+ load_dotenv()
+ google_api_key = os.getenv("GOOGLE_API_KEY")
+ if not google_api_key:
+     raise ValueError("Missing GOOGLE_API_KEY environment variable")
+
+ # --- Math Tools ---
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two integers."""
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two integers."""
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract b from a."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide a by b, error on zero."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Compute a mod b."""
+     return a % b
+
+ # --- Browser Tools ---
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia and return up to 3 relevant documents."""
+     try:
+         # Ensure query contains "discography" keyword
+         if "discography" not in query.lower():
+             query = f"{query} discography"
+
+         docs = WikipediaLoader(query=query, load_max_docs=3).load()
+         if not docs:
+             return "No Wikipedia results found."
+
+         results = []
+         for doc in docs:
+             title = doc.metadata.get('title', 'Unknown Title')
+             content = doc.page_content[:2000]  # Limit content length
+             results.append(f"Title: {title}\nContent: {content}")
+
+         return "\n\n---\n\n".join(results)
+     except Exception as e:
+         return f"Wikipedia search error: {str(e)}"
+
+ @tool
+ def arxiv_search(query: str) -> str:
+     """Search Arxiv and return up to 3 relevant papers."""
+     try:
+         docs = ArxivLoader(query=query, load_max_docs=3).load()
+         if not docs:
+             return "No arXiv papers found."
+
+         results = []
+         for doc in docs:
+             title = doc.metadata.get('Title', 'Unknown Title')
+             # ArxivLoader exposes Authors as a pre-joined string
+             authors = doc.metadata.get('Authors', 'Unknown Authors')
+             content = doc.page_content[:2000]  # Limit content length
+             results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
+
+         return "\n\n---\n\n".join(results)
+     except Exception as e:
+         return f"arXiv search error: {str(e)}"
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search the web using DuckDuckGo and return top results."""
+     try:
+         search = DuckDuckGoSearchRun()
+         result = search.run(query)
+         return f"Web search results for '{query}':\n{result[:2000]}"  # Limit content length
+     except Exception as e:
+         return f"Web search error: {str(e)}"
+
+ # --- Load system prompt ---
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # --- Tool Setup ---
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     arxiv_search,
+     web_search,
+ ]
+
+ # --- Graph Builder ---
+ def build_graph():
+     # Initialize model with Gemini 2.5 Flash
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-2.5-flash",
+         temperature=0.3,
+         google_api_key=google_api_key,
+         max_retries=5,
+         request_timeout=60
+     )
+
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # 1. Define state structure
+     class AgentState(TypedDict):
+         messages: Annotated[Sequence, operator.add]
+         retry_count: int
+
+     # 2. Create graph
+     workflow = StateGraph(AgentState)
+
+     # 3. Define node functions
+     def agent_node(state: AgentState):
+         """Main agent node"""
+         try:
+             # Add request delay to avoid rate limiting
+             time.sleep(2)
+
+             # Retry mechanism for API calls
+             @retry(stop=stop_after_attempt(5),
+                    wait=wait_exponential(multiplier=1, min=4, max=30))
+             def invoke_with_retry():
+                 return llm_with_tools.invoke(state["messages"])
+
+             response = invoke_with_retry()
+             return {"messages": [response], "retry_count": 0}
+
+         except Exception as e:
+             # Detailed error logging
+             error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
+             logger.error(error_details)
+
+             error_type = "UNKNOWN"
+             if "429" in str(e):
+                 error_type = "QUOTA_EXCEEDED"
+             elif "400" in str(e):
+                 error_type = "INVALID_REQUEST"
+             elif "503" in str(e):
+                 error_type = "SERVICE_UNAVAILABLE"
+
+             new_retry_count = state.get("retry_count", 0) + 1
+             error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
+
+             if new_retry_count < 3:
+                 error_msg += "\n\nWill retry after delay..."
+             else:
+                 error_msg += "\n\nMax retries exceeded. Please try again later."
+
+             return {"messages": [AIMessage(content=error_msg)], "retry_count": new_retry_count}
+
+     def tool_node(state: AgentState):
+         """Tool execution node"""
+         last_msg = state["messages"][-1]
+         # Use the normalized tool_calls attribute (same field checked in should_continue)
+         tool_calls = getattr(last_msg, "tool_calls", []) or []
+
+         responses = []
+         for call in tool_calls:
+             tool_name = call["name"]
+             tool_args = call.get("args", {})
+
+             tool_func = next((t for t in tools if t.name == tool_name), None)
+             if not tool_func:
+                 responses.append(f"Tool {tool_name} not available")
+                 continue
+
+             try:
+                 # Parse arguments; fall back to treating a raw string as the query
+                 if isinstance(tool_args, str):
+                     try:
+                         tool_args = json.loads(tool_args)
+                     except json.JSONDecodeError:
+                         tool_args = {"query": tool_args}
+
+                 # Execute tool
+                 result = tool_func.invoke(tool_args)
+                 responses.append(f"{tool_name} result: {str(result)[:1000]}")
+             except Exception as e:
+                 responses.append(f"{tool_name} error: {str(e)}")
+
+         tool_response_content = "\n".join(responses)
+         return {"messages": [AIMessage(content=tool_response_content)], "retry_count": 0}
+
+     # 4. Add nodes to workflow
+     workflow.add_node("agent", agent_node)
+     workflow.add_node("tools", tool_node)
+
+     # 5. Set entry point
+     workflow.set_entry_point("agent")
+
+     # 6. Define conditional edges
+     def should_continue(state: AgentState):
+         last_msg = state["messages"][-1]
+         retry_count = state.get("retry_count", 0)
+
+         # Handle error cases
+         if "AGENT ERROR" in last_msg.content:
+             if retry_count < 3:
+                 return "agent"
+             return "end"
+
+         # Route to tools if tool calls exist
+         if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
+             return "tools"
+
+         # End if final answer is present
+         if "FINAL ANSWER" in last_msg.content:
+             return "end"
+
+         # Continue to agent otherwise
+         return "agent"
+
+     workflow.add_conditional_edges(
+         "agent",
+         should_continue,
+         {
+             "agent": "agent",
+             "tools": "tools",
+             "end": END
+         }
+     )
+
+     # 7. Define flow after tool node
+     workflow.add_edge("tools", "agent")
+
+     # 8. Compile graph
+     return workflow.compile()
+
+ # Initialize agent graph
+ agent_graph = build_graph()
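
For reference, a minimal sketch of how the compiled agent_graph might be driven once this module is imported; the sample question and the use of system_prompt as the leading message are illustrative assumptions, not part of the commit:

# Illustrative usage sketch (hypothetical, not included in agent.py)
from agent import agent_graph, system_prompt
from langchain_core.messages import SystemMessage, HumanMessage

question = "What is 12 * 7?"  # placeholder question
final_state = agent_graph.invoke({
    "messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)],
    "retry_count": 0,
})
# The system prompt is expected to elicit a "FINAL ANSWER: ..." line, which ends the graph run
print(final_state["messages"][-1].content)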
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ langchain-google-genai
+ langchain
+ langchain-community
+ langgraph
+ duckduckgo-search
+ wikipedia
+ arxiv
+ tenacity
+ python-dotenv
+ google-generativeai