vtony commited on
Commit
35195fa
·
verified ·
1 Parent(s): ee97348

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -337
agent.py DELETED
@@ -1,337 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- import re
5
- import calendar
6
- from datetime import datetime
7
- from dotenv import load_dotenv
8
- from langgraph.graph import StateGraph, END
9
- from langchain_google_genai import ChatGoogleGenerativeAI
10
- from langchain_community.tools import DuckDuckGoSearchRun
11
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
13
- from langchain_core.tools import tool
14
- from tenacity import retry, stop_after_attempt, wait_exponential
15
- from typing import TypedDict, Annotated, Sequence, List, Dict, Union
16
- import operator
17
-
18
- # Load environment variables
19
- load_dotenv()
20
- google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
21
- if not google_api_key:
22
- raise ValueError("Missing GOOGLE_API_KEY environment variable")
23
-
24
- # --- Math Tools ---
25
- @tool
26
- def multiply(a: int, b: int) -> int:
27
- """Multiply two integers."""
28
- return a * b
29
-
30
- @tool
31
- def add(a: int, b: int) -> int:
32
- """Add two integers."""
33
- return a + b
34
-
35
- @tool
36
- def subtract(a: int, b: int) -> int:
37
- """Subtract b from a."""
38
- return a - b
39
-
40
- @tool
41
- def divide(a: int, b: int) -> float:
42
- """Divide a by b, error on zero."""
43
- if b == 0:
44
- raise ValueError("Cannot divide by zero.")
45
- return a / b
46
-
47
- @tool
48
- def modulus(a: int, b: int) -> int:
49
- """Compute a mod b."""
50
- return a % b
51
-
52
- # --- Browser Tools ---
53
- @tool
54
- def wiki_search(query: str) -> str:
55
- """Search Wikipedia and return up to 3 relevant documents."""
56
- try:
57
- docs = WikipediaLoader(query=query, load_max_docs=3).load()
58
- if not docs:
59
- return "No Wikipedia results found."
60
-
61
- results = []
62
- for doc in docs:
63
- title = doc.metadata.get('title', 'Unknown Title')
64
- content = doc.page_content[:2000] # Limit content length
65
- results.append(f"Title: {title}\nContent: {content}")
66
-
67
- return "\n\n---\n\n".join(results)
68
- except Exception as e:
69
- return f"Wikipedia search error: {str(e)}"
70
-
71
- @tool
72
- def arxiv_search(query: str) -> str:
73
- """Search Arxiv and return up to 3 relevant papers."""
74
- try:
75
- docs = ArxivLoader(query=query, load_max_docs=3).load()
76
- if not docs:
77
- return "No arXiv papers found."
78
-
79
- results = []
80
- for doc in docs:
81
- title = doc.metadata.get('Title', 'Unknown Title')
82
- authors = ", ".join(doc.metadata.get('Authors', []))
83
- content = doc.page_content[:2000] # Limit content length
84
- results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
85
-
86
- return "\n\n---\n\n".join(results)
87
- except Exception as e:
88
- return f"arXiv search error: {str(e)}"
89
-
90
- @tool
91
- def web_search(query: str) -> str:
92
- """Search the web using DuckDuckGo and return top results."""
93
- try:
94
- search = DuckDuckGoSearchRun()
95
- result = search.run(query)
96
- return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
97
- except Exception as e:
98
- return f"Web search error: {str(e)}"
99
-
100
- # --- Enhanced Tools ---
101
- @tool
102
- def filter_by_year(items: List[Dict], year_range: str) -> List[Dict]:
103
- """Filter items containing year information, returning only those within specified range"""
104
- try:
105
- start_year, end_year = map(int, year_range.split('-'))
106
- filtered = []
107
- for item in items:
108
- # Extract year from different possible keys
109
- year = item.get('year') or item.get('release_year') or item.get('date')
110
- if not year:
111
- continue
112
-
113
- # Convert to integer if possible
114
- if isinstance(year, str) and year.isdigit():
115
- year = int(year)
116
-
117
- if isinstance(year, int) and start_year <= year <= end_year:
118
- filtered.append(item)
119
- return filtered
120
- except Exception as e:
121
- return f"Filter error: {str(e)}"
122
-
123
- @tool
124
- def extract_albums(text: str) -> List[Dict]:
125
- """Extract album information from text, automatically detecting names and years"""
126
- albums = []
127
-
128
- # Pattern 1: Album Name (Year)
129
- pattern1 = r'\"?(.+?)\"?\s*[\(\[](\d{4})[\)\]]'
130
- # Pattern 2: Year: Album Name
131
- pattern2 = r'(\d{4}):\s*\"?(.+?)\"?[\n\,]'
132
-
133
- for pattern in [pattern1, pattern2]:
134
- matches = re.findall(pattern, text)
135
- for match in matches:
136
- # Handle different match group orders
137
- if len(match) == 2:
138
- if match[0].isdigit(): # Year comes first
139
- year, name = match
140
- else: # Name comes first
141
- name, year = match
142
-
143
- try:
144
- year = int(year)
145
- albums.append({"name": name.strip(), "year": year})
146
- except ValueError:
147
- continue
148
-
149
- return albums
150
-
151
- @tool
152
- def compare_values(a: Union[str, int, float], b: Union[str, int, float]) -> str:
153
- """Compare two values with automatic type detection (number/date/string)"""
154
- try:
155
- # Attempt numeric comparison
156
- a_num = float(a) if isinstance(a, str) else a
157
- b_num = float(b) if isinstance(b, str) else b
158
- if a_num == b_num:
159
- return "equal"
160
- return "greater" if a_num > b_num else "less"
161
- except (ValueError, TypeError):
162
- pass
163
-
164
- # Attempt date comparison
165
- date_formats = [
166
- "%Y-%m-%d", "%d %B %Y", "%B %d, %Y", "%m/%d/%Y",
167
- "%Y", "%B %Y", "%b %d, %Y", "%d/%m/%Y"
168
- ]
169
-
170
- for fmt in date_formats:
171
- try:
172
- a_date = datetime.strptime(str(a), fmt)
173
- b_date = datetime.strptime(str(b), fmt)
174
- if a_date == b_date:
175
- return "equal"
176
- return "greater" if a_date > b_date else "less"
177
- except ValueError:
178
- continue
179
-
180
- # String comparison as fallback
181
- a_str = str(a).lower().strip()
182
- b_str = str(b).lower().strip()
183
- if a_str == b_str:
184
- return "equal"
185
- return "greater" if a_str > b_str else "less"
186
-
187
- @tool
188
- def count_items(items: List) -> int:
189
- """Count the number of items in a list"""
190
- return len(items)
191
-
192
- # --- Load system prompt ---
193
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
194
- system_prompt = f.read()
195
-
196
- # --- Tool Setup ---
197
- tools = [
198
- multiply,
199
- add,
200
- subtract,
201
- divide,
202
- modulus,
203
- wiki_search,
204
- arxiv_search,
205
- web_search,
206
- filter_by_year, # Enhanced tool
207
- extract_albums, # Enhanced tool
208
- compare_values, # Enhanced tool
209
- count_items # Enhanced tool
210
- ]
211
-
212
- # --- Graph Builder ---
213
- def build_graph():
214
- # Initialize model with Gemini 2.5 Flash
215
- llm = ChatGoogleGenerativeAI(
216
- model="gemini-2.5-flash",
217
- temperature=0.3,
218
- google_api_key=google_api_key,
219
- max_retries=3
220
- )
221
-
222
- # Bind tools to LLM
223
- llm_with_tools = llm.bind_tools(tools)
224
-
225
- # 1. Define state structure
226
- class AgentState(TypedDict):
227
- messages: Annotated[Sequence, operator.add]
228
- structured_data: dict # New field for structured information
229
-
230
- # 2. Create graph
231
- workflow = StateGraph(AgentState)
232
-
233
- # 3. Define node functions
234
- def agent_node(state: AgentState):
235
- """Main agent node"""
236
- try:
237
- # Remove forced delay to improve performance
238
- # time.sleep(1) # Commented out for performance
239
-
240
- # Call with retry mechanism
241
- @retry(stop=stop_after_attempt(3),
242
- wait=wait_exponential(multiplier=1, min=4, max=10))
243
- def invoke_with_retry():
244
- return llm_with_tools.invoke(state["messages"])
245
-
246
- response = invoke_with_retry()
247
- return {"messages": [response]}
248
-
249
- except Exception as e:
250
- error_type = "UNKNOWN"
251
- if "429" in str(e):
252
- error_type = "QUOTA_EXCEEDED"
253
- elif "400" in str(e):
254
- error_type = "INVALID_REQUEST"
255
-
256
- error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
257
- return {"messages": [AIMessage(content=error_msg)]}
258
-
259
- def tool_node(state: AgentState):
260
- """Tool execution node"""
261
- last_msg = state["messages"][-1]
262
- tool_calls = last_msg.additional_kwargs.get("tool_calls", [])
263
-
264
- responses = []
265
- for call in tool_calls:
266
- tool_name = call["function"]["name"]
267
- tool_args = call["function"].get("arguments", {})
268
-
269
- # Find the tool
270
- tool_func = next((t for t in tools if t.name == tool_name), None)
271
- if not tool_func:
272
- responses.append(f"Tool {tool_name} not available")
273
- continue
274
-
275
- try:
276
- # Parse arguments
277
- if isinstance(tool_args, str):
278
- tool_args = json.loads(tool_args)
279
-
280
- # Execute tool
281
- result = tool_func.invoke(tool_args)
282
-
283
- # Store structured results
284
- if tool_name in ["extract_albums", "filter_by_year"]:
285
- state["structured_data"][tool_name] = result
286
-
287
- responses.append(f"{tool_name} result: {str(result)[:1000]}") # Limit result length
288
- except Exception as e:
289
- responses.append(f"{tool_name} error: {str(e)}")
290
-
291
- tool_response_content = "\n".join(responses)
292
- return {"messages": [AIMessage(content=tool_response_content)]}
293
-
294
- # 4. Add nodes to workflow
295
- workflow.add_node("agent", agent_node)
296
- workflow.add_node("tools", tool_node)
297
-
298
- # 5. Set entry point
299
- workflow.set_entry_point("agent")
300
-
301
- # 6. Define conditional edges
302
- def should_continue(state: AgentState):
303
- last_msg = state["messages"][-1]
304
-
305
- # End on error
306
- if "AGENT ERROR" in last_msg.content:
307
- return "end"
308
-
309
- # Go to tools if there are tool calls
310
- if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
311
- return "tools"
312
-
313
- # End if final answer is present
314
- if "FINAL ANSWER" in last_msg.content:
315
- return "end"
316
-
317
- # Otherwise continue with agent
318
- return "agent"
319
-
320
- workflow.add_conditional_edges(
321
- "agent",
322
- should_continue,
323
- {
324
- "agent": "agent",
325
- "tools": "tools",
326
- "end": END
327
- }
328
- )
329
-
330
- # 7. Define flow after tool node
331
- workflow.add_edge("tools", "agent")
332
-
333
- # 8. Compile graph
334
- return workflow.compile()
335
-
336
- # Initialize agent graph
337
- agent_graph = build_graph()