Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -1,241 +1,197 @@
|
|
1 |
-
# agent.py –
|
2 |
-
#
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
from langgraph.
|
|
|
8 |
|
|
|
9 |
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
from langchain_core.tools import tool
|
11 |
-
|
12 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
|
|
|
|
14 |
|
15 |
# ---------------------------------------------------------------------
|
16 |
-
#
|
17 |
# ---------------------------------------------------------------------
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
|
21 |
# ---------------------------------------------------------------------
|
22 |
-
# Fehler-
|
23 |
# ---------------------------------------------------------------------
|
24 |
-
import functools
|
25 |
def error_guard(fn):
|
|
|
26 |
@functools.wraps(fn)
|
27 |
-
def wrapper(*args, **
|
28 |
try:
|
29 |
-
return fn(*args, **
|
30 |
except Exception as e:
|
31 |
return f"ERROR: {e}"
|
32 |
return wrapper
|
33 |
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
response = requests.get(url, timeout=30)
|
47 |
-
response.raise_for_status()
|
48 |
-
file_name = response.headers.get("x-gaia-filename", f"{task_id}")
|
49 |
-
tmp_path = tempfile.gettempdir() + "/" + file_name
|
50 |
-
with open(tmp_path, "wb") as f:
|
51 |
-
f.write(response.content)
|
52 |
-
return tmp_path
|
53 |
-
except Exception as e:
|
54 |
-
return f"ERROR: could not fetch file – {e}"
|
55 |
|
56 |
# ---------------------------------------------------------------------
|
57 |
-
# 2) CSV
|
58 |
# ---------------------------------------------------------------------
|
59 |
-
import pandas as pd
|
60 |
-
|
61 |
@tool
|
62 |
@error_guard
|
63 |
def parse_csv(file_path: str, query: str = "") -> str:
|
64 |
-
"""Load a CSV file and
|
65 |
df = pd.read_csv(file_path)
|
66 |
if not query:
|
67 |
-
return f"
|
68 |
try:
|
69 |
-
|
70 |
-
return result.to_markdown()
|
71 |
except Exception as e:
|
72 |
-
return f"ERROR
|
|
|
73 |
|
74 |
-
# ---------------------------------------------------------------------
|
75 |
-
# 3) Excel-Parser
|
76 |
-
# ---------------------------------------------------------------------
|
77 |
@tool
|
78 |
@error_guard
|
79 |
-
def parse_excel(file_path: str, query: str = "") -> str:
|
80 |
-
"""Load an Excel
|
81 |
-
|
|
|
82 |
if not query:
|
83 |
-
return f"
|
84 |
try:
|
85 |
-
|
86 |
-
return result.to_markdown()
|
87 |
except Exception as e:
|
88 |
-
return f"ERROR
|
89 |
|
90 |
# ---------------------------------------------------------------------
|
91 |
-
#
|
92 |
# ---------------------------------------------------------------------
|
93 |
@tool
|
94 |
@error_guard
|
95 |
-
def
|
96 |
-
"""
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
{"type": "text", "text": prompt},
|
103 |
-
{"type": "media", "data": b64, "mime_type": mime},
|
104 |
-
]
|
105 |
-
)
|
106 |
-
resp = asyncio.run(safe_invoke([message]))
|
107 |
-
return resp.content if hasattr(resp, "content") else str(resp)
|
108 |
-
|
109 |
-
# ---------------------------------------------------------------------
|
110 |
-
# 5) Bild-Beschreibung
|
111 |
-
# ---------------------------------------------------------------------
|
112 |
-
@tool
|
113 |
-
@error_guard
|
114 |
-
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
|
115 |
-
"""Gemini vision – Bild beschreiben."""
|
116 |
-
from PIL import Image
|
117 |
-
img = Image.open(file_path)
|
118 |
-
message = HumanMessage(
|
119 |
-
content=[
|
120 |
-
{"type": "text", "text": prompt},
|
121 |
-
img, # langchain übernimmt Encoding
|
122 |
-
]
|
123 |
-
)
|
124 |
-
resp = asyncio.run(safe_invoke([message]))
|
125 |
-
return resp.content
|
126 |
|
127 |
-
# ---------------------------------------------------------------------
|
128 |
-
# 6) OCR-Tool
|
129 |
-
# ---------------------------------------------------------------------
|
130 |
-
@tool
|
131 |
-
@error_guard
|
132 |
-
def ocr_image(file_path: str, lang: str = "eng") -> str:
|
133 |
-
"""Extract text from an image via pytesseract."""
|
134 |
-
try:
|
135 |
-
import pytesseract
|
136 |
-
from PIL import Image
|
137 |
-
text = pytesseract.image_to_string(Image.open(file_path), lang=lang)
|
138 |
-
return text.strip() or "No text found."
|
139 |
-
except Exception as e:
|
140 |
-
return f"ERROR: {e}"
|
141 |
|
142 |
-
# ---------------------------------------------------------------------
|
143 |
-
# 7) Tavily-Web-Suche
|
144 |
-
# ---------------------------------------------------------------------
|
145 |
@tool
|
146 |
@error_guard
|
147 |
-
def
|
148 |
-
"""
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
return "\n\n".join(f"{h['title']} – {h['url']}" for h in hits)
|
153 |
|
154 |
# ---------------------------------------------------------------------
|
155 |
-
#
|
156 |
# ---------------------------------------------------------------------
|
157 |
-
|
158 |
-
@error_guard
|
159 |
-
def simple_calculator(operation: str, a: float, b: float) -> float:
|
160 |
-
"""Basic maths (add, subtract, multiply, divide)."""
|
161 |
-
ops = {
|
162 |
-
"add": a + b,
|
163 |
-
"subtract": a - b,
|
164 |
-
"multiply": a * b,
|
165 |
-
"divide": a / b if b else float("inf"),
|
166 |
-
}
|
167 |
-
return ops.get(operation, f"ERROR: unknown op '{operation}'")
|
168 |
|
169 |
# ---------------------------------------------------------------------
|
170 |
-
# LLM
|
171 |
# ---------------------------------------------------------------------
|
172 |
gemini_llm = ChatGoogleGenerativeAI(
|
|
|
173 |
model="gemini-2.0-flash",
|
174 |
-
google_api_key=GOOGLE_API_KEY,
|
175 |
temperature=0,
|
176 |
max_output_tokens=2048,
|
177 |
-
).bind_tools(
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
#
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
# ---------------------------------------------------------------------
|
201 |
-
# System-Prompt
|
202 |
-
# ---------------------------------------------------------------------
|
203 |
-
system_prompt = SystemMessage(content="""
|
204 |
-
You are a helpful assistant tasked with answering questions using a set of tools.
|
205 |
-
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
|
206 |
-
FINAL ANSWER: [YOUR FINAL ANSWER].
|
207 |
-
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
208 |
-
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
|
209 |
-
""")
|
210 |
|
211 |
# ---------------------------------------------------------------------
|
212 |
-
# LangGraph –
|
213 |
# ---------------------------------------------------------------------
|
214 |
-
def
|
|
|
215 |
msgs = state["messages"]
|
216 |
if msgs[0].type != "system":
|
217 |
-
msgs = [
|
218 |
-
resp =
|
219 |
-
finished =
|
|
|
|
|
|
|
220 |
return {"messages": [resp], "should_end": finished}
|
221 |
|
222 |
def route(state):
|
223 |
return "END" if state["should_end"] else "tools"
|
224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
# ---------------------------------------------------------------------
|
226 |
-
#
|
227 |
# ---------------------------------------------------------------------
|
228 |
-
|
229 |
-
|
230 |
-
gemini_transcribe_audio, describe_image, ocr_image,
|
231 |
-
web_search, simple_calculator,
|
232 |
-
]
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
builder.add_node("tools", ToolNode(tools))
|
237 |
-
builder.add_edge(START, "assistant")
|
238 |
-
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
|
239 |
|
240 |
-
|
241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# agent.py – LangChain · LangGraph · Gemini Flash
|
2 |
+
# ================================================
|
3 |
+
|
4 |
+
"""
|
5 |
+
Abhängigkeiten (requirements.txt):
|
6 |
+
----------------------------------
|
7 |
+
langchain==0.1.*
|
8 |
+
langgraph
|
9 |
+
google-generativeai
|
10 |
+
tavily-python
|
11 |
+
wikipedia-api
|
12 |
+
pandas
|
13 |
+
openpyxl
|
14 |
+
tabulate
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os, re, time, functools
|
18 |
+
from typing import Dict, Any, List
|
19 |
|
20 |
+
import pandas as pd
|
21 |
+
from langgraph.graph import StateGraph, START, END, MessagesState
|
22 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
23 |
|
24 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
25 |
from langchain_core.messages import SystemMessage, HumanMessage
|
26 |
from langchain_core.tools import tool
|
|
|
|
|
27 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
28 |
+
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
|
29 |
+
from langchain.tools.python.tool import PythonAstREPLTool
|
30 |
|
31 |
# ---------------------------------------------------------------------
|
32 |
+
# 0) Optionales LangSmith-Tracing (ENV setzen: LANGCHAIN_API_KEY)
|
33 |
# ---------------------------------------------------------------------
|
34 |
+
if os.getenv("LANGCHAIN_API_KEY"):
|
35 |
+
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
|
36 |
+
from langchain_community.utils import configure_langsmith
|
37 |
+
configure_langsmith(project_name="gaia-agent")
|
38 |
|
39 |
# ---------------------------------------------------------------------
|
40 |
+
# 1) Helfer: Fehler-Decorator + Backoff-Wrapper
|
41 |
# ---------------------------------------------------------------------
|
|
|
42 |
def error_guard(fn):
    """Wrap *fn* so that any exception is returned as an ``"ERROR: ..."`` string.

    A decorated callable never raises ``Exception`` subclasses to its caller;
    the failure is returned as an ordinary string result instead.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper carrying *fn*'s metadata (via ``functools.wraps``) that
        returns *fn*'s result, or ``"ERROR: <ExceptionType>: <message>"`` when
        *fn* raised.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kw):
        try:
            return fn(*args, **kw)
        except Exception as e:
            # Include the exception type: the message alone is often empty or
            # ambiguous (``KeyError('title')`` renders as just ``'title'``).
            return f"ERROR: {type(e).__name__}: {e}"
    return wrapper
|
51 |
|
52 |
|
53 |
+
def with_backoff(fn, tries: int = 4, delay: float = 4):
    """Call ``fn()`` and retry with exponential backoff on rate-limit errors.

    An error counts as a rate limit when ``"429"`` or ``"RateLimit"`` appears
    in the exception's class name or message.

    Args:
        fn: Zero-argument callable to invoke.
        tries: Maximum number of attempts; must be at least 1.
        delay: Seconds to sleep before the first retry; doubled after every
            retried attempt.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        ValueError: If ``tries`` is smaller than 1.
        Exception: The exception raised by ``fn`` when it is not a rate-limit
            error, or when the last attempt fails.
    """
    if tries < 1:
        # Without this guard the loop never runs and None is returned silently.
        raise ValueError("tries must be at least 1")
    for attempt in range(tries):
        try:
            return fn()
        except Exception as e:
            # Some clients signal throttling only through the class name
            # (e.g. ``RateLimitError``), so inspect both name and message.
            signature = f"{type(e).__name__} {e}"
            retryable = "429" in signature or "RateLimit" in signature
            if not retryable or attempt == tries - 1:
                raise
            time.sleep(delay)
            delay *= 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
# ---------------------------------------------------------------------
|
66 |
+
# 2) Eigene Tools (CSV / Excel)
|
67 |
# ---------------------------------------------------------------------
|
|
|
|
|
68 |
# Tool: summarise a CSV file, or filter it with a pandas query expression.
# Without `query` the row count and column names are reported; with `query`
# the matching rows are rendered as a markdown table. A failing query is
# reported as an "ERROR query: ..." string instead of being raised.
# The docstring is left untouched: `@tool` presumably uses it as the tool
# description shown to the model.
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and (optional) run a pandas query."""
    frame = pd.read_csv(file_path)
    if query:
        try:
            matching_rows = frame.query(query)
            return matching_rows.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    column_names = list(frame.columns)
    return f"Rows={len(frame)}, Cols={column_names}"
|
79 |
+
|
80 |
|
|
|
|
|
|
|
81 |
# Tool: summarise one sheet of an Excel workbook, or filter it with a pandas
# query expression. `sheet` may be a sheet name, an index, or a digit string
# (converted to an index); a missing/empty value selects the first sheet.
# The docstring is left untouched: `@tool` presumably uses it as the tool
# description shown to the model.
@tool
@error_guard
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
    """Load an Excel sheet (name or index) and (optional) run a pandas query."""
    if isinstance(sheet, str) and sheet.isdigit():
        target_sheet = int(sheet)
    elif sheet:
        target_sheet = sheet
    else:
        # None, "" and 0 all select the first sheet.
        target_sheet = 0
    frame = pd.read_excel(file_path, sheet_name=target_sheet)
    if query:
        try:
            matching_rows = frame.query(query)
            return matching_rows.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    column_names = list(frame.columns)
    return f"Rows={len(frame)}, Cols={column_names}"
|
93 |
|
94 |
# ---------------------------------------------------------------------
|
95 |
+
# 3) Externe Search-Tools (Tavily, Wikipedia)
|
96 |
# ---------------------------------------------------------------------
|
97 |
# Tool: run a Tavily web search and return one "<label> – <url>" line per hit.
# The label is the hit's title when the wrapper provides one, otherwise a
# shortened content snippet. Returns "No results." when nothing was found.
# The docstring is left untouched: `@tool` presumably uses it as the tool
# description shown to the model.
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return markdown list of results."""
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    if isinstance(hits, str):
        # NOTE(review): the Tavily tool appears to return repr(exception) as a
        # plain string when the request fails — surface it as an error rather
        # than iterating over its characters.
        return f"ERROR: {hits}"
    lines = []
    for h in hits:
        if not isinstance(h, dict):
            lines.append(str(h))
            continue
        # Not every wrapper version includes "title"; fall back to a snippet
        # of the content so a missing key no longer aborts the whole search.
        snippet = " ".join(str(h.get("content", "")).split())[:200]
        label = h.get("title") or snippet or "(untitled)"
        lines.append(f"{label} – {h.get('url', '')}")
    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
|
|
|
|
|
|
108 |
# Tool: fetch the top Wikipedia result for `query` and return roughly its
# first `sentences` sentences, one per line. Returns "No article found." when
# the wrapper yields nothing.
# The docstring is left untouched: `@tool` presumably uses it as the tool
# description shown to the model.
@tool
@error_guard
def wiki_search(query: str, sentences: int = 3) -> str:
    """Quick Wikipedia summary."""
    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    res = wrapper.run(query)
    if not res:
        return "No article found."
    parts = res.split(". ")
    # Always keep at least one sentence; zero or negative counts used to
    # produce an empty or oddly truncated summary.
    kept = parts[:max(1, sentences)]
    # Splitting on ". " removes the full stops, so put them back when joining.
    summary = ".\n".join(kept)
    if len(kept) < len(parts):
        # The last kept sentence lost its terminator to the split as well.
        summary += "."
    return summary
|
|
|
115 |
|
116 |
# ---------------------------------------------------------------------
|
117 |
+
# 4) Python-REPL Tool (fertig aus LangChain)
|
118 |
# ---------------------------------------------------------------------
|
119 |
+
# Ready-made Python REPL tool, instantiated with its defaults and handed to
# the agent alongside the custom tools.
# NOTE(review): the import path `langchain.tools.python.tool` may no longer
# exist in the pinned LangChain release (the tool appears to have moved to
# `langchain_experimental`) — confirm, since a failing import aborts startup.
python_repl = PythonAstREPLTool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
# ---------------------------------------------------------------------
|
122 |
+
# 5) LLM – Gemini Flash, an Tools gebunden
|
123 |
# ---------------------------------------------------------------------
|
124 |
# Gemini Flash chat model with the agent's tools bound for native function
# calling. The API key is read from the GOOGLE_API_KEY environment variable.
# Only the tool list is passed to bind_tools(): `return_named_tools` is not a
# bind_tools parameter, and unknown keyword arguments are presumably forwarded
# to every model call, where they would fail.
gemini_llm = ChatGoogleGenerativeAI(
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
).bind_tools(
    [web_search, wiki_search, parse_csv, parse_excel, python_repl],
)
|
133 |
+
|
134 |
+
# ---------------------------------------------------------------------
|
135 |
+
# 6) System-Prompt (ReAct, keine Prefixe im Final-Output!)
|
136 |
+
# ---------------------------------------------------------------------
|
137 |
+
# System prompt prepended to every conversation by the planner.
# Tools are bound to the model for native function calling, so the prompt must
# not ask for tool calls written out as JSON text: such replies carry no
# `tool_calls`, and the JSON would be returned as the final answer.
SYSTEM_PROMPT = SystemMessage(
    content=(
        "You are a helpful assistant with access to Python tools.\n"
        "• Think step by step.\n"
        "• Call a tool when needed – use the provided function-calling "
        "interface, never write the call as text.\n"
        "• When you have the answer, reply with the answer **only** "
        "– no prefix, no explanations.\n"
        "Answer format rules:\n"
        "  • Single number → no separators / units unless required.\n"
        "  • Single string → no articles/abbrev.\n"
        "  • List → comma + single space separated, keep required order.\n"
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
# ---------------------------------------------------------------------
|
153 |
+
# 7) LangGraph – Planner + Tools + Router
|
154 |
# ---------------------------------------------------------------------
|
155 |
+
def planner(state: MessagesState):
    """Ask the LLM for the next step: either tool calls or the final answer.

    The system prompt is prepended for the model call only; it is not written
    back into the graph state.

    Args:
        state: Graph state holding the conversation under ``"messages"``.

    Returns:
        A state update containing only the model's reply. ``MessagesState``
        declares no other key, so routing is derived from that reply instead
        of from an extra ``should_end`` flag.
    """
    msgs = state["messages"]
    if not msgs or msgs[0].type != "system":
        msgs = [SYSTEM_PROMPT] + list(msgs)
    resp = with_backoff(lambda: gemini_llm.invoke(msgs))
    return {"messages": [resp]}


def route(state):
    """Return ``"tools"`` when the last message requests tool calls, else ``"END"``.

    Deciding on ``tool_calls`` alone means a multi-line final answer ends the
    run instead of being sent to the tool node with nothing to execute.
    """
    last = state["messages"][-1]
    return "tools" if getattr(last, "tool_calls", None) else "END"


# Tool node
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]

# Upper bound on planner rounds per question.
_MAX_PLANNER_ROUNDS = 8

graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_edge(START, "planner")
graph.add_conditional_edges("planner", route, {"tools": "tools", "END": END})
# Feed tool results back to the planner so it can use them; without this edge
# the run stops right after the first tool call.
graph.add_edge("tools", "planner")

# compile() takes no iteration cap; the limit is the runnable config's
# `recursion_limit`, counted in graph steps (planner + tools per round, plus
# the final planner step).
agent_executor = graph.compile().with_config(
    {"recursion_limit": 2 * _MAX_PLANNER_ROUNDS + 1}
)
|
181 |
+
|
182 |
# ---------------------------------------------------------------------
|
183 |
+
# 8) Öffentliche Klasse – wird von app.py / logic.py verwendet
|
184 |
# ---------------------------------------------------------------------
|
185 |
+
class GaiaAgent:
    """LangChain / LangGraph agent for GAIA level 1 questions."""

    def __init__(self):
        print("✅ GaiaAgent initialised (LangGraph)")

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent on a single GAIA question → exact answer string.

        Args:
            task_id: Identifier of the GAIA task; not used by this method.
            question: The question text sent to the agent as a human message.

        Returns:
            The content of the last message in the final graph state, with
            surrounding whitespace removed.
        """
        start_state = {"messages": [HumanMessage(content=question)]}
        final_state = agent_executor.invoke(start_state)
        # The last message holds the answer.
        answer = final_state["messages"][-1].content
        if not isinstance(answer, str):
            # Message content may be a list of parts (strings or dicts with a
            # "text" entry); flatten it so .strip() cannot fail.
            answer = "".join(
                part if isinstance(part, str)
                else str(part.get("text", "")) if isinstance(part, dict)
                else str(part)
                for part in answer
            )
        return answer.strip()
|