vtony commited on
Commit
0a72192
·
verified ·
1 Parent(s): dcee690

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +337 -0
agent.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import re
5
+ import calendar
6
+ from datetime import datetime
7
+ from dotenv import load_dotenv
8
+ from langgraph.graph import StateGraph, END
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from langchain_community.tools import DuckDuckGoSearchRun
11
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
13
+ from langchain_core.tools import tool
14
+ from tenacity import retry, stop_after_attempt, wait_exponential
15
+ from typing import TypedDict, Annotated, Sequence, List, Dict, Union
16
+ import operator
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
21
+ if not google_api_key:
22
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
23
+
24
+ # --- Math Tools ---
25
+ @tool
26
+ def multiply(a: int, b: int) -> int:
27
+ """Multiply two integers."""
28
+ return a * b
29
+
30
+ @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two integers."""
33
+ return a + b
34
+
35
+ @tool
36
+ def subtract(a: int, b: int) -> int:
37
+ """Subtract b from a."""
38
+ return a - b
39
+
40
+ @tool
41
+ def divide(a: int, b: int) -> float:
42
+ """Divide a by b, error on zero."""
43
+ if b == 0:
44
+ raise ValueError("Cannot divide by zero.")
45
+ return a / b
46
+
47
+ @tool
48
+ def modulus(a: int, b: int) -> int:
49
+ """Compute a mod b."""
50
+ return a % b
51
+
52
+ # --- Browser Tools ---
53
+ @tool
54
+ def wiki_search(query: str) -> str:
55
+ """Search Wikipedia and return up to 3 relevant documents."""
56
+ try:
57
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
58
+ if not docs:
59
+ return "No Wikipedia results found."
60
+
61
+ results = []
62
+ for doc in docs:
63
+ title = doc.metadata.get('title', 'Unknown Title')
64
+ content = doc.page_content[:2000] # Limit content length
65
+ results.append(f"Title: {title}\nContent: {content}")
66
+
67
+ return "\n\n---\n\n".join(results)
68
+ except Exception as e:
69
+ return f"Wikipedia search error: {str(e)}"
70
+
71
+ @tool
72
+ def arxiv_search(query: str) -> str:
73
+ """Search Arxiv and return up to 3 relevant papers."""
74
+ try:
75
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
76
+ if not docs:
77
+ return "No arXiv papers found."
78
+
79
+ results = []
80
+ for doc in docs:
81
+ title = doc.metadata.get('Title', 'Unknown Title')
82
+ authors = ", ".join(doc.metadata.get('Authors', []))
83
+ content = doc.page_content[:2000] # Limit content length
84
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
85
+
86
+ return "\n\n---\n\n".join(results)
87
+ except Exception as e:
88
+ return f"arXiv search error: {str(e)}"
89
+
90
+ @tool
91
+ def web_search(query: str) -> str:
92
+ """Search the web using DuckDuckGo and return top results."""
93
+ try:
94
+ search = DuckDuckGoSearchRun()
95
+ result = search.run(query)
96
+ return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
97
+ except Exception as e:
98
+ return f"Web search error: {str(e)}"
99
+
100
+ # --- Enhanced Tools ---
101
+ @tool
102
+ def filter_by_year(items: List[Dict], year_range: str) -> List[Dict]:
103
+ """Filter items containing year information, returning only those within specified range"""
104
+ try:
105
+ start_year, end_year = map(int, year_range.split('-'))
106
+ filtered = []
107
+ for item in items:
108
+ # Extract year from different possible keys
109
+ year = item.get('year') or item.get('release_year') or item.get('date')
110
+ if not year:
111
+ continue
112
+
113
+ # Convert to integer if possible
114
+ if isinstance(year, str) and year.isdigit():
115
+ year = int(year)
116
+
117
+ if isinstance(year, int) and start_year <= year <= end_year:
118
+ filtered.append(item)
119
+ return filtered
120
+ except Exception as e:
121
+ return f"Filter error: {str(e)}"
122
+
123
+ @tool
124
+ def extract_albums(text: str) -> List[Dict]:
125
+ """Extract album information from text, automatically detecting names and years"""
126
+ albums = []
127
+
128
+ # Pattern 1: Album Name (Year)
129
+ pattern1 = r'\"?(.+?)\"?\s*[\(\[](\d{4})[\)\]]'
130
+ # Pattern 2: Year: Album Name
131
+ pattern2 = r'(\d{4}):\s*\"?(.+?)\"?[\n\,]'
132
+
133
+ for pattern in [pattern1, pattern2]:
134
+ matches = re.findall(pattern, text)
135
+ for match in matches:
136
+ # Handle different match group orders
137
+ if len(match) == 2:
138
+ if match[0].isdigit(): # Year comes first
139
+ year, name = match
140
+ else: # Name comes first
141
+ name, year = match
142
+
143
+ try:
144
+ year = int(year)
145
+ albums.append({"name": name.strip(), "year": year})
146
+ except ValueError:
147
+ continue
148
+
149
+ return albums
150
+
151
+ @tool
152
+ def compare_values(a: Union[str, int, float], b: Union[str, int, float]) -> str:
153
+ """Compare two values with automatic type detection (number/date/string)"""
154
+ try:
155
+ # Attempt numeric comparison
156
+ a_num = float(a) if isinstance(a, str) else a
157
+ b_num = float(b) if isinstance(b, str) else b
158
+ if a_num == b_num:
159
+ return "equal"
160
+ return "greater" if a_num > b_num else "less"
161
+ except (ValueError, TypeError):
162
+ pass
163
+
164
+ # Attempt date comparison
165
+ date_formats = [
166
+ "%Y-%m-%d", "%d %B %Y", "%B %d, %Y", "%m/%d/%Y",
167
+ "%Y", "%B %Y", "%b %d, %Y", "%d/%m/%Y"
168
+ ]
169
+
170
+ for fmt in date_formats:
171
+ try:
172
+ a_date = datetime.strptime(str(a), fmt)
173
+ b_date = datetime.strptime(str(b), fmt)
174
+ if a_date == b_date:
175
+ return "equal"
176
+ return "greater" if a_date > b_date else "less"
177
+ except ValueError:
178
+ continue
179
+
180
+ # String comparison as fallback
181
+ a_str = str(a).lower().strip()
182
+ b_str = str(b).lower().strip()
183
+ if a_str == b_str:
184
+ return "equal"
185
+ return "greater" if a_str > b_str else "less"
186
+
187
+ @tool
188
+ def count_items(items: List) -> int:
189
+ """Count the number of items in a list"""
190
+ return len(items)
191
+
192
+ # --- Load system prompt ---
193
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
194
+ system_prompt = f.read()
195
+
196
+ # --- Tool Setup ---
197
+ tools = [
198
+ multiply,
199
+ add,
200
+ subtract,
201
+ divide,
202
+ modulus,
203
+ wiki_search,
204
+ arxiv_search,
205
+ web_search,
206
+ filter_by_year, # Enhanced tool
207
+ extract_albums, # Enhanced tool
208
+ compare_values, # Enhanced tool
209
+ count_items # Enhanced tool
210
+ ]
211
+
212
+ # --- Graph Builder ---
213
+ def build_graph():
214
+ # Initialize model with Gemini 2.5 Flash
215
+ llm = ChatGoogleGenerativeAI(
216
+ model="gemini-2.5-flash",
217
+ temperature=0.3,
218
+ google_api_key=google_api_key,
219
+ max_retries=3
220
+ )
221
+
222
+ # Bind tools to LLM
223
+ llm_with_tools = llm.bind_tools(tools)
224
+
225
+ # 1. Define state structure
226
+ class AgentState(TypedDict):
227
+ messages: Annotated[Sequence, operator.add]
228
+ structured_data: dict # New field for structured information
229
+
230
+ # 2. Create graph
231
+ workflow = StateGraph(AgentState)
232
+
233
+ # 3. Define node functions
234
+ def agent_node(state: AgentState):
235
+ """Main agent node"""
236
+ try:
237
+ # Remove forced delay to improve performance
238
+ # time.sleep(1) # Commented out for performance
239
+
240
+ # Call with retry mechanism
241
+ @retry(stop=stop_after_attempt(3),
242
+ wait=wait_exponential(multiplier=1, min=4, max=10))
243
+ def invoke_with_retry():
244
+ return llm_with_tools.invoke(state["messages"])
245
+
246
+ response = invoke_with_retry()
247
+ return {"messages": [response]}
248
+
249
+ except Exception as e:
250
+ error_type = "UNKNOWN"
251
+ if "429" in str(e):
252
+ error_type = "QUOTA_EXCEEDED"
253
+ elif "400" in str(e):
254
+ error_type = "INVALID_REQUEST"
255
+
256
+ error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
257
+ return {"messages": [AIMessage(content=error_msg)]}
258
+
259
+ def tool_node(state: AgentState):
260
+ """Tool execution node"""
261
+ last_msg = state["messages"][-1]
262
+ tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
263
+
264
+ responses = []
265
+ for call in tool_calls:
266
+ tool_name = call["function"]["name"]
267
+ tool_args = call["function"].get("arguments", {})
268
+
269
+ # Find the tool
270
+ tool_func = next((t for t in tools if t.name == tool_name), None)
271
+ if not tool_func:
272
+ responses.append(f"Tool {tool_name} not available")
273
+ continue
274
+
275
+ try:
276
+ # Parse arguments
277
+ if isinstance(tool_args, str):
278
+ tool_args = json.loads(tool_args)
279
+
280
+ # Execute tool
281
+ result = tool_func.invoke(tool_args)
282
+
283
+ # Store structured results
284
+ if tool_name in ["extract_albums", "filter_by_year"]:
285
+ state["structured_data"][tool_name] = result
286
+
287
+ responses.append(f"{tool_name} result: {str(result)[:1000]}") # Limit result length
288
+ except Exception as e:
289
+ responses.append(f"{tool_name} error: {str(e)}")
290
+
291
+ tool_response_content = "\n".join(responses)
292
+ return {"messages": [AIMessage(content=tool_response_content)]}
293
+
294
+ # 4. Add nodes to workflow
295
+ workflow.add_node("agent", agent_node)
296
+ workflow.add_node("tools", tool_node)
297
+
298
+ # 5. Set entry point
299
+ workflow.set_entry_point("agent")
300
+
301
+ # 6. Define conditional edges
302
+ def should_continue(state: AgentState):
303
+ last_msg = state["messages"][-1]
304
+
305
+ # End on error
306
+ if "AGENT ERROR" in last_msg.content:
307
+ return "end"
308
+
309
+ # Go to tools if there are tool calls
310
+ if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
311
+ return "tools"
312
+
313
+ # End if final answer is present
314
+ if "FINAL ANSWER" in last_msg.content:
315
+ return "end"
316
+
317
+ # Otherwise continue with agent
318
+ return "agent"
319
+
320
+ workflow.add_conditional_edges(
321
+ "agent",
322
+ should_continue,
323
+ {
324
+ "agent": "agent",
325
+ "tools": "tools",
326
+ "end": END
327
+ }
328
+ )
329
+
330
+ # 7. Define flow after tool node
331
+ workflow.add_edge("tools", "agent")
332
+
333
+ # 8. Compile graph
334
+ return workflow.compile()
335
+
336
+ # Initialize agent graph
337
+ agent_graph = build_graph()