ZeroTimo commited on
Commit
55bc352
·
verified ·
1 Parent(s): c329c72

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +40 -106
agent.py CHANGED
@@ -1,47 +1,38 @@
 
 
1
  import os
2
- import re
3
  import time
4
  import functools
5
- from typing import Dict, Any, List
6
-
7
  import pandas as pd
 
 
8
 
9
- # LangGraph
10
  from langgraph.graph import StateGraph, START, END, MessagesState
11
- from langgraph.prebuilt import ToolNode, tools_condition
12
-
13
- # LangChain Core
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
-
17
- # Google Gemini
18
  from langchain_google_genai import ChatGoogleGenerativeAI
19
-
20
- # Tools
21
  from langchain_community.tools.tavily_search import TavilySearchResults
22
  from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
23
 
24
- # Python REPL Tool
25
  try:
26
  from langchain_experimental.tools.python.tool import PythonAstREPLTool
27
  except ImportError:
28
  from langchain.tools.python.tool import PythonAstREPLTool
29
 
30
  # ---------------------------------------------------------------------
31
- # 0) Optionale LangSmith-Tracing (setze ENV: LANGCHAIN_API_KEY)
32
  # ---------------------------------------------------------------------
33
-
34
  if os.getenv("LANGCHAIN_API_KEY"):
35
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
36
  os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
37
  os.environ.setdefault("LANGCHAIN_PROJECT", "gaia-agent")
38
- print("📡 LangSmith tracing enabled.")
39
 
40
  # ---------------------------------------------------------------------
41
- # 1) Helfer: Fehler-Decorator + Backoff-Wrapper
42
  # ---------------------------------------------------------------------
43
  def error_guard(fn):
44
- """Fängt Tool-Fehler ab & gibt String zurück (bricht Agent nicht ab)."""
45
  @functools.wraps(fn)
46
  def wrapper(*args, **kw):
47
  try:
@@ -50,77 +41,47 @@ def error_guard(fn):
50
  return f"ERROR: {e}"
51
  return wrapper
52
 
53
-
54
- def with_backoff(fn, tries: int = 4, delay: int = 4):
55
- """Synchrones Retry-Wrapper für LLM-Aufrufe."""
56
- for t in range(tries):
57
- try:
58
- return fn()
59
- except Exception as e:
60
- if ("429" in str(e) or "RateLimit" in str(e)) and t < tries - 1:
61
- time.sleep(delay)
62
- delay *= 2
63
- continue
64
- raise
65
-
66
  # ---------------------------------------------------------------------
67
- # 2) Eigene Tools (CSV / Excel)
68
  # ---------------------------------------------------------------------
69
  @tool
70
  @error_guard
71
  def parse_csv(file_path: str, query: str = "") -> str:
72
- """Load a CSV file and (optional) run a pandas query."""
73
  df = pd.read_csv(file_path)
74
  if not query:
75
  return f"Rows={len(df)}, Cols={list(df.columns)}"
76
- try:
77
- return df.query(query).to_markdown(index=False)
78
- except Exception as e:
79
- return f"ERROR query: {e}"
80
-
81
 
82
  @tool
83
  @error_guard
84
  def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
85
- """Load an Excel sheet (name or index) and (optional) run a pandas query."""
86
  sheet_arg = int(sheet) if isinstance(sheet, str) and sheet.isdigit() else sheet or 0
87
  df = pd.read_excel(file_path, sheet_name=sheet_arg)
88
  if not query:
89
  return f"Rows={len(df)}, Cols={list(df.columns)}"
90
- try:
91
- return df.query(query).to_markdown(index=False)
92
- except Exception as e:
93
- return f"ERROR query: {e}"
94
 
95
- # ---------------------------------------------------------------------
96
- # 3) Externe Search-Tools (Tavily, Wikipedia)
97
- # ---------------------------------------------------------------------
98
  @tool
99
  @error_guard
100
  def web_search(query: str, max_results: int = 5) -> str:
101
- """Search the web via Tavily and return markdown list of results."""
102
  api_key = os.getenv("TAVILY_API_KEY")
103
  hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
104
  if not hits:
105
  return "No results."
106
  return "\n".join(f"{h['title']} – {h['url']}" for h in hits)
107
 
108
-
109
  @tool
110
  @error_guard
111
  def wiki_search(query: str, sentences: int = 3) -> str:
112
- """Quick Wikipedia summary."""
113
  wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
114
  res = wrapper.run(query)
115
  return "\n".join(res.split(". ")[:sentences]) if res else "No article found."
116
 
117
- # ---------------------------------------------------------------------
118
- # 4) Python-REPL Tool (fertig aus LangChain)
119
- # ---------------------------------------------------------------------
120
  python_repl = PythonAstREPLTool()
121
 
122
  # ---------------------------------------------------------------------
123
- # 5) LLM – Gemini Flash, an Tools gebunden
124
  # ---------------------------------------------------------------------
125
  gemini_llm = ChatGoogleGenerativeAI(
126
  google_api_key=os.getenv("GOOGLE_API_KEY"),
@@ -129,81 +90,54 @@ gemini_llm = ChatGoogleGenerativeAI(
129
  max_output_tokens=2048,
130
  )
131
 
132
- # ---------------------------------------------------------------------
133
- # 6) System-Prompt (ReAct, keine Prefixe im Final-Output!)
134
- # ---------------------------------------------------------------------
135
  SYSTEM_PROMPT = SystemMessage(
136
  content=(
137
- "You are a helpful assistant with access to several tools.\n"
138
- "You can think step by step and use tools to find answers.\n\n"
139
- "When you want to use a tool, write it like this:\n"
140
- "Tool: <tool_name>\n"
141
- "Input: <input for the tool>\n\n"
142
- "Wait for the tool result before continuing.\n"
143
- "When you know the final answer, reply with the answer **only**.\n"
144
- "Don't include any prefix, explanation or formatting around the answer.\n"
145
- "Answer formatting:\n"
146
- "- For numbers: no units unless requested\n"
147
- "- For strings: no articles or abbreviations\n"
148
- "- For lists: comma + space separated, correct order\n"
149
  )
150
  )
151
 
 
 
152
  # ---------------------------------------------------------------------
153
- # 7) LangGraph – Planner + Tools + Router
154
  # ---------------------------------------------------------------------
155
-
156
  def planner(state: MessagesState):
157
- msgs = state["messages"]
158
- if msgs[0].type != "system":
159
- msgs = [SYSTEM_PROMPT] + msgs
160
-
161
- resp = with_backoff(lambda: gemini_llm.invoke(msgs))
162
-
163
- # WICHTIG: Gib tool_calls weiter – sie lösen im ToolNode die Ausführung aus
164
- return {
165
- "messages": msgs + [resp],
166
- "should_end": (
167
- not getattr(resp, "tool_calls", None) # kein Tool gewünscht
168
- and "\n" not in resp.content # einfache Heuristik
169
- )
170
- }
171
 
172
- def route(state):
173
- return "END" if state["should_end"] else "tools"
174
-
175
- # Tool-Knoten
176
- TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]
177
 
 
 
 
178
  graph = StateGraph(MessagesState)
179
-
180
  graph.add_node("planner", planner)
181
  graph.add_node("tools", ToolNode(TOOLS))
182
-
183
  graph.add_edge(START, "planner")
184
- graph.add_edge("tools", "planner") # 🔁 Rücksprung zum Planner nach Tool-Ausführung
185
-
186
- graph.add_conditional_edges("planner", route, {
187
- "tools": "tools",
188
- "END": END,
189
- })
190
 
191
- # compile → LangGraph-Executor
192
  agent_executor = graph.compile()
193
 
194
  # ---------------------------------------------------------------------
195
- # 8) Öffentliche Klasse – wird von app.py / logic.py verwendet
196
  # ---------------------------------------------------------------------
197
  class GaiaAgent:
198
- """LangChain·LangGraph-Agent für GAIA Level 1."""
199
-
200
  def __init__(self):
201
  print("✅ GaiaAgent initialised (LangGraph)")
202
 
203
  def __call__(self, task_id: str, question: str) -> str:
204
- """Run the agent on a single GAIA question → exact answer string."""
205
- start_state = {"messages": [HumanMessage(content=question)]}
206
- final_state = agent_executor.invoke(start_state)
207
- # letze Message enthält Antwort
208
- answer = final_state["messages"][-1].content
209
- return answer.strip()
 
1
+ # agent.py
2
+
3
  import os
 
4
  import time
5
  import functools
 
 
6
  import pandas as pd
7
+ from typing import Dict, Any, List
8
+ import re
9
 
 
10
  from langgraph.graph import StateGraph, START, END, MessagesState
11
+ from langgraph.prebuilt import ToolNode
 
 
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langchain_core.tools import tool
 
 
14
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
17
 
 
18
  try:
19
  from langchain_experimental.tools.python.tool import PythonAstREPLTool
20
  except ImportError:
21
  from langchain.tools.python.tool import PythonAstREPLTool
22
 
23
# ---------------------------------------------------------------------
# Optional LangSmith tracing (activated by the LANGCHAIN_API_KEY env var)
# ---------------------------------------------------------------------
if os.getenv("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    # setdefault: keep a project name the user already configured.
    os.environ.setdefault("LANGCHAIN_PROJECT", "gaia-agent")
    print("📱 LangSmith tracing enabled.")
31
 
32
  # ---------------------------------------------------------------------
33
+ # Fehler-Wrapper
34
  # ---------------------------------------------------------------------
35
  def error_guard(fn):
 
36
  @functools.wraps(fn)
37
  def wrapper(*args, **kw):
38
  try:
 
41
  return f"ERROR: {e}"
42
  return wrapper
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ---------------------------------------------------------------------
45
+ # Eigene Tools
46
  # ---------------------------------------------------------------------
47
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and optionally run a pandas ``DataFrame.query`` on it.

    Args:
        file_path: Path to the CSV file.
        query: Optional pandas query expression. When empty, only a
            row-count / column-name summary is returned.

    Returns:
        A markdown table of the query result, or the shape summary.
        (Exceptions are converted to an ``ERROR: ...`` string by error_guard.)
    """
    # NOTE: the @tool decorator requires a docstring (used as the tool
    # description); without one, module import fails with a ValueError.
    df = pd.read_csv(file_path)
    if not query:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    return df.query(query).to_markdown(index=False)
 
 
 
 
54
 
55
@tool
@error_guard
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
    """Load an Excel sheet (by name or index) and optionally run a pandas query.

    Args:
        file_path: Path to the Excel workbook.
        sheet: Sheet name, sheet index, or a numeric string ("0", "1", ...);
            defaults to the first sheet when None or empty.
        query: Optional pandas query expression. When empty, only a
            row-count / column-name summary is returned.

    Returns:
        A markdown table of the query result, or the shape summary.
        (Exceptions are converted to an ``ERROR: ...`` string by error_guard.)
    """
    # Numeric strings are treated as sheet indices; falsy values fall back to 0.
    sheet_arg = int(sheet) if isinstance(sheet, str) and sheet.isdigit() else sheet or 0
    df = pd.read_excel(file_path, sheet_name=sheet_arg)
    if not query:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    return df.query(query).to_markdown(index=False)
 
 
 
63
 
 
 
 
64
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return one "title – url" line per hit.

    Args:
        query: Search query string.
        max_results: Maximum number of results to request.

    Returns:
        Newline-separated results, or "No results." when nothing was found.
    """
    # The @tool decorator needs this docstring as the tool description;
    # the previous revision removed it, which breaks import of this module.
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    return "\n".join(f"{h['title']} – {h['url']}" for h in hits)
72
 
 
73
@tool
@error_guard
def wiki_search(query: str, sentences: int = 3) -> str:
    """Return a short Wikipedia summary (first *sentences* sentences) for a query.

    Args:
        query: Topic to look up.
        sentences: Number of leading sentences to keep.

    Returns:
        The truncated summary, or "No article found." when the lookup is empty.
    """
    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    res = wrapper.run(query)
    # Crude sentence split on ". " — good enough for a preview snippet.
    return "\n".join(res.split(". ")[:sentences]) if res else "No article found."
79
 
80
# Prebuilt Python REPL tool from LangChain (evaluates Python via AST parsing).
python_repl = PythonAstREPLTool()
82
 
83
  # ---------------------------------------------------------------------
84
+ # Gemini LLM
85
  # ---------------------------------------------------------------------
86
  gemini_llm = ChatGoogleGenerativeAI(
87
  google_api_key=os.getenv("GOOGLE_API_KEY"),
 
90
  max_output_tokens=2048,
91
  )
92
 
 
 
 
93
# System prompt: prepended by the planner node when the conversation has no
# system message yet; steers the model toward tool calls and terse answers.
SYSTEM_PROMPT = SystemMessage(
    content=(
        "You are a helpful assistant with access to tools.\n"
        "Use tools when appropriate using tool calls.\n"
        "If the answer is clear, return it directly without explanation."
    )
)
100
 
101
# All tools exposed to the LLM and executed by the graph's ToolNode.
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]
102
+
103
  # ---------------------------------------------------------------------
104
+ # LangGraph Nodes
105
  # ---------------------------------------------------------------------
 
106
def planner(state: MessagesState):
    """LLM planning node: call Gemini on the conversation, append its reply.

    The system prompt is prepended to the LLM input only when the history
    does not already contain a system message.
    """
    messages = state["messages"]
    if not any(m.type == "system" for m in messages):
        messages = [SYSTEM_PROMPT] + messages
    resp = gemini_llm.invoke(messages)
    # Return ONLY the new message. MessagesState merges updates through the
    # add_messages reducer (merge-by-id), so returning the whole list would
    # re-append SYSTEM_PROMPT — which carries no stable id — at the END of
    # the history on every merge instead of keeping it first.
    return {"messages": [resp]}
 
 
 
 
 
 
 
 
 
112
 
113
def should_end(state: MessagesState) -> bool:
    """Return True when the newest message requests no tool calls.

    A reply without tool calls is treated as the model's final answer.
    """
    newest = state["messages"][-1]
    requested_calls = getattr(newest, "tool_calls", None)
    return not requested_calls
 
 
116
 
117
# ---------------------------------------------------------------------
# Build Graph: planner -> (tools -> planner loop) until no tool calls
# ---------------------------------------------------------------------
graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_edge(START, "planner")
# Finish when the planner's last message carries no tool calls; otherwise
# execute the requested tools and hand the results back to the planner.
graph.add_conditional_edges(
    "planner",
    lambda state: "END" if should_end(state) else "tools",
    {"tools": "tools", "END": END},
)
graph.add_edge("tools", "planner")

# Compiled LangGraph executor; invoked by GaiaAgent.__call__.
agent_executor = graph.compile()
132
 
133
# ---------------------------------------------------------------------
# Public class — consumed by app.py / logic.py
# ---------------------------------------------------------------------
class GaiaAgent:
    """LangGraph-backed agent answering single GAIA questions."""

    def __init__(self):
        # No per-instance state: the compiled graph lives at module level.
        print("✅ GaiaAgent initialised (LangGraph)")

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent on one question and return the final answer string.

        task_id is accepted to match the caller's interface but is not
        used here.
        """
        state = {"messages": [HumanMessage(content=question)]}
        final = agent_executor.invoke(state)
        # The last message of the final state holds the model's answer.
        return final["messages"][-1].content.strip()