Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -1,37 +1,29 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
• Fehler-Retry-Logik
|
6 |
-
"""
|
7 |
-
import os, base64, mimetypes, subprocess, json, tempfile
|
8 |
-
import functools
|
9 |
-
from typing import Any
|
10 |
|
11 |
from langgraph.graph import START, StateGraph, MessagesState
|
12 |
-
from langgraph.prebuilt import tools_condition, ToolNode
|
13 |
-
|
14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
|
|
|
|
15 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
16 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
17 |
|
18 |
-
#
|
19 |
-
#
|
20 |
-
#
|
21 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
max_output_tokens=2048,
|
28 |
-
)
|
29 |
-
|
30 |
-
# ----------------------------------------------------------------------
|
31 |
-
# 2 ── ERROR-WRAPPER (garantiert "ERROR:"-String statt Exception)
|
32 |
-
# ----------------------------------------------------------------------
|
33 |
def error_guard(fn):
|
34 |
-
@functools.wraps(fn)
|
35 |
def wrapper(*args, **kwargs):
|
36 |
try:
|
37 |
return fn(*args, **kwargs)
|
@@ -39,198 +31,213 @@ def error_guard(fn):
|
|
39 |
return f"ERROR: {e}"
|
40 |
return wrapper
|
41 |
|
42 |
-
|
43 |
-
#
|
44 |
-
#
|
45 |
-
|
46 |
-
|
47 |
-
def simple_calculator(operation: str, a: float, b: float) -> float:
|
48 |
-
"""Basic maths: add, subtract, multiply, divide."""
|
49 |
-
ops = {"add": a + b, "subtract": a - b, "multiply": a * b,
|
50 |
-
"divide": a / b if b else float("inf")}
|
51 |
-
return ops.get(operation, "ERROR: unknown operation")
|
52 |
|
53 |
@tool
|
54 |
@error_guard
|
55 |
def fetch_gaia_file(task_id: str) -> str:
|
56 |
-
"""Download attachment for
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
@tool
|
67 |
@error_guard
|
68 |
def parse_csv(file_path: str, query: str = "") -> str:
|
69 |
-
"""Load CSV
|
70 |
-
import pandas as pd
|
71 |
df = pd.read_csv(file_path)
|
72 |
if not query:
|
73 |
-
return df.
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
@tool
|
77 |
@error_guard
|
78 |
def parse_excel(file_path: str, query: str = "") -> str:
|
79 |
-
"""Load first sheet
|
80 |
-
import pandas as pd
|
81 |
df = pd.read_excel(file_path)
|
82 |
if not query:
|
83 |
-
return df.
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
89 |
@tool
|
90 |
@error_guard
|
91 |
-
def
|
92 |
-
"""
|
93 |
-
mime, _ = mimetypes.guess_type(file_path)
|
94 |
-
if not (mime and mime.startswith("image/")):
|
95 |
-
return "ERROR: not an image."
|
96 |
with open(file_path, "rb") as f:
|
97 |
b64 = base64.b64encode(f.read()).decode()
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
@tool
|
106 |
@error_guard
|
107 |
-
def
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
]
|
119 |
-
resp = llm.invoke([HumanMessage(content=content)])
|
120 |
return resp.content
|
121 |
|
122 |
-
#
|
123 |
-
#
|
124 |
-
#
|
125 |
@tool
|
126 |
@error_guard
|
127 |
def ocr_image(file_path: str, lang: str = "eng") -> str:
|
128 |
-
"""Extract text from image
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
@tool
|
138 |
@error_guard
|
139 |
def web_search(query: str, max_results: int = 5) -> str:
|
140 |
-
"""
|
141 |
-
|
142 |
-
hits = search.invoke(query)
|
143 |
if not hits:
|
144 |
-
return "
|
145 |
-
return "\n\n".join(f"{
|
146 |
-
|
147 |
-
|
148 |
-
#
|
149 |
-
#
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
Once you have written **Final Answer:** you are done – do **not** call any further tool.
|
188 |
-
"""
|
189 |
-
))
|
190 |
-
|
191 |
-
# ----------------------------------------------------------------------
|
192 |
-
# 8 ── LangGraph Nodes
|
193 |
-
# ----------------------------------------------------------------------
|
194 |
-
tools = [
|
195 |
-
fetch_gaia_file,
|
196 |
-
parse_csv,
|
197 |
-
parse_excel,
|
198 |
-
gemini_transcribe_audio,
|
199 |
-
ocr_image,
|
200 |
-
describe_image,
|
201 |
-
web_search,
|
202 |
-
simple_calculator,
|
203 |
-
]
|
204 |
|
205 |
-
|
|
|
|
|
206 |
|
|
|
207 |
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
if not content.startswith("ERROR"):
|
213 |
-
return resp
|
214 |
-
msgs.append(
|
215 |
-
SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
|
216 |
-
)
|
217 |
-
return resp
|
218 |
|
|
|
|
|
|
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
def assistant(state: MessagesState):
|
221 |
msgs = state["messages"]
|
222 |
-
if
|
223 |
msgs = [system_prompt] + msgs
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
-
# ----------------------------------------------------------------------
|
227 |
-
# 9 ── Graph
|
228 |
-
# ----------------------------------------------------------------------
|
229 |
builder = StateGraph(MessagesState)
|
230 |
builder.add_node("assistant", assistant)
|
231 |
builder.add_node("tools", ToolNode(tools))
|
232 |
builder.add_edge(START, "assistant")
|
233 |
-
builder.add_conditional_edges("assistant",
|
234 |
-
builder.add_edge("tools", "assistant")
|
235 |
|
|
|
236 |
agent_executor = builder.compile()
|
|
|
1 |
+
# agent.py – Gemini 2.0 Flash · LangGraph · Mehrere Tools
|
2 |
+
# =========================================================
|
3 |
+
import os, asyncio, base64, mimetypes, tempfile, functools, json
|
4 |
+
from typing import Dict, Any, List, Optional
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from langgraph.graph import START, StateGraph, MessagesState
|
7 |
+
from langgraph.prebuilt import tools_condition, ToolNode, END
|
8 |
+
|
9 |
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
+
from langchain_core.tools import tool
|
11 |
+
|
12 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
14 |
|
15 |
+
# ---------------------------------------------------------------------
|
16 |
+
# Konstanten / API-Keys
|
17 |
+
# ---------------------------------------------------------------------
|
18 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
19 |
+
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
|
20 |
|
21 |
+
# ---------------------------------------------------------------------
|
22 |
+
# Fehler-Wrapper – behält Doc-String dank wraps
|
23 |
+
# ---------------------------------------------------------------------
|
24 |
+
import functools
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def error_guard(fn):
|
26 |
+
@functools.wraps(fn)
|
27 |
def wrapper(*args, **kwargs):
|
28 |
try:
|
29 |
return fn(*args, **kwargs)
|
|
|
31 |
return f"ERROR: {e}"
|
32 |
return wrapper
|
33 |
|
34 |
+
|
35 |
+
# ---------------------------------------------------------------------
|
36 |
+
# 1) fetch_gaia_file – Datei vom GAIA-Server holen
|
37 |
+
# ---------------------------------------------------------------------
|
38 |
+
GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
@tool
|
41 |
@error_guard
|
42 |
def fetch_gaia_file(task_id: str) -> str:
|
43 |
+
"""Download the attachment for the given GAIA task_id and return local path."""
|
44 |
+
url = f"{GAIA_FILE_ENDPOINT}/{task_id}"
|
45 |
+
try:
|
46 |
+
response = requests.get(url, timeout=30)
|
47 |
+
response.raise_for_status()
|
48 |
+
file_name = response.headers.get("x-gaia-filename", f"{task_id}")
|
49 |
+
tmp_path = tempfile.gettempdir() + "/" + file_name
|
50 |
+
with open(tmp_path, "wb") as f:
|
51 |
+
f.write(response.content)
|
52 |
+
return tmp_path
|
53 |
+
except Exception as e:
|
54 |
+
return f"ERROR: could not fetch file – {e}"
|
55 |
+
|
56 |
+
# ---------------------------------------------------------------------
|
57 |
+
# 2) CSV-Parser
|
58 |
+
# ---------------------------------------------------------------------
|
59 |
+
import pandas as pd
|
60 |
|
61 |
@tool
|
62 |
@error_guard
|
63 |
def parse_csv(file_path: str, query: str = "") -> str:
|
64 |
+
"""Load a CSV file and answer a quick pandas query (optional)."""
|
|
|
65 |
df = pd.read_csv(file_path)
|
66 |
if not query:
|
67 |
+
return f"Loaded CSV with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
|
68 |
+
try:
|
69 |
+
result = df.query(query)
|
70 |
+
return result.to_markdown()
|
71 |
+
except Exception as e:
|
72 |
+
return f"ERROR in pandas query: {e}"
|
73 |
+
|
74 |
+
# ---------------------------------------------------------------------
|
75 |
+
# 3) Excel-Parser
|
76 |
+
# ---------------------------------------------------------------------
|
77 |
@tool
|
78 |
@error_guard
|
79 |
def parse_excel(file_path: str, query: str = "") -> str:
|
80 |
+
"""Load an Excel file (first sheet) and answer a pandas query (optional)."""
|
|
|
81 |
df = pd.read_excel(file_path)
|
82 |
if not query:
|
83 |
+
return f"Loaded Excel with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
|
84 |
+
try:
|
85 |
+
result = df.query(query)
|
86 |
+
return result.to_markdown()
|
87 |
+
except Exception as e:
|
88 |
+
return f"ERROR in pandas query: {e}"
|
89 |
+
|
90 |
+
# ---------------------------------------------------------------------
|
91 |
+
# 4) Gemini-Audio-Transkription
|
92 |
+
# ---------------------------------------------------------------------
|
93 |
@tool
|
94 |
@error_guard
|
95 |
+
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
|
96 |
+
"""Use Gemini to transcribe an audio file."""
|
|
|
|
|
|
|
97 |
with open(file_path, "rb") as f:
|
98 |
b64 = base64.b64encode(f.read()).decode()
|
99 |
+
mime = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
|
100 |
+
message = HumanMessage(
|
101 |
+
content=[
|
102 |
+
{"type": "text", "text": prompt},
|
103 |
+
{"type": "media", "data": b64, "mime_type": mime},
|
104 |
+
]
|
105 |
+
)
|
106 |
+
resp = asyncio.run(gemini_llm.invoke([message]))
|
107 |
+
return resp.content if hasattr(resp, "content") else str(resp)
|
108 |
+
|
109 |
+
# ---------------------------------------------------------------------
|
110 |
+
# 5) Bild-Beschreibung
|
111 |
+
# ---------------------------------------------------------------------
|
112 |
@tool
|
113 |
@error_guard
|
114 |
+
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
|
115 |
+
"""Gemini vision – Bild beschreiben."""
|
116 |
+
from PIL import Image
|
117 |
+
img = Image.open(file_path)
|
118 |
+
message = HumanMessage(
|
119 |
+
content=[
|
120 |
+
{"type": "text", "text": prompt},
|
121 |
+
img, # langchain übernimmt Encoding
|
122 |
+
]
|
123 |
+
)
|
124 |
+
resp = asyncio.run(gemini_llm.invoke([message]))
|
|
|
|
|
125 |
return resp.content
|
126 |
|
127 |
+
# ---------------------------------------------------------------------
|
128 |
+
# 6) OCR-Tool
|
129 |
+
# ---------------------------------------------------------------------
|
130 |
@tool
|
131 |
@error_guard
|
132 |
def ocr_image(file_path: str, lang: str = "eng") -> str:
|
133 |
+
"""Extract text from an image via pytesseract."""
|
134 |
+
try:
|
135 |
+
import pytesseract
|
136 |
+
from PIL import Image
|
137 |
+
text = pytesseract.image_to_string(Image.open(file_path), lang=lang)
|
138 |
+
return text.strip() or "No text found."
|
139 |
+
except Exception as e:
|
140 |
+
return f"ERROR: {e}"
|
141 |
+
|
142 |
+
# ---------------------------------------------------------------------
|
143 |
+
# 7) Tavily-Web-Suche
|
144 |
+
# ---------------------------------------------------------------------
|
145 |
@tool
|
146 |
@error_guard
|
147 |
def web_search(query: str, max_results: int = 5) -> str:
|
148 |
+
"""Search the web via Tavily and return a markdown list of results."""
|
149 |
+
hits = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY).invoke(query)
|
|
|
150 |
if not hits:
|
151 |
+
return "No results."
|
152 |
+
return "\n\n".join(f"{h['title']} – {h['url']}" for h in hits)
|
153 |
+
|
154 |
+
# ---------------------------------------------------------------------
|
155 |
+
# 8) Kleiner Rechner
|
156 |
+
# ---------------------------------------------------------------------
|
157 |
+
@tool
|
158 |
+
@error_guard
|
159 |
+
def simple_calculator(operation: str, a: float, b: float) -> float:
|
160 |
+
"""Basic maths (add, subtract, multiply, divide)."""
|
161 |
+
ops = {
|
162 |
+
"add": a + b,
|
163 |
+
"subtract": a - b,
|
164 |
+
"multiply": a * b,
|
165 |
+
"divide": a / b if b else float("inf"),
|
166 |
+
}
|
167 |
+
return ops.get(operation, f"ERROR: unknown op '{operation}'")
|
168 |
+
|
169 |
+
# ---------------------------------------------------------------------
|
170 |
+
# LLM + Semaphore-Throttle (Gemini 2.0 Flash)
|
171 |
+
# ---------------------------------------------------------------------
|
172 |
+
gemini_llm = ChatGoogleGenerativeAI(
|
173 |
+
model="gemini-2.0-flash",
|
174 |
+
google_api_key=GOOGLE_API_KEY,
|
175 |
+
temperature=0,
|
176 |
+
max_output_tokens=2048,
|
177 |
+
).bind_tools([
|
178 |
+
fetch_gaia_file, parse_csv, parse_excel,
|
179 |
+
gemini_transcribe_audio, describe_image, ocr_image,
|
180 |
+
web_search, simple_calculator,
|
181 |
+
])
|
182 |
+
|
183 |
+
LLM_SEMA = asyncio.Semaphore(3) # 3 gleichz. Anfragen ≈ < 15/min
|
184 |
+
|
185 |
+
async def safe_invoke(msgs: List[Any]):
|
186 |
+
async with LLM_SEMA:
|
187 |
+
return gemini_llm.invoke(msgs)
|
188 |
+
|
189 |
+
# ---------------------------------------------------------------------
|
190 |
+
# System-Prompt
|
191 |
+
# ---------------------------------------------------------------------
|
192 |
+
system_prompt = SystemMessage(content="""
|
193 |
+
You are GAIA-Assist, a precise, tool-using agent.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
+
If a question mentions an attachment:
|
196 |
+
1. Call fetch_gaia_file(task_id)
|
197 |
+
2. Use exactly one specialised parser tool on the returned path.
|
198 |
|
199 |
+
Otherwise decide between web_search or simple_calculator.
|
200 |
|
201 |
+
Format for a tool call:
|
202 |
+
Thought: Do I need to use a tool? Yes
|
203 |
+
Action: <tool name>
|
204 |
+
Action Input: <JSON arguments>
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
+
Format for final answer:
|
207 |
+
Thought: Do I need to use a tool? No
|
208 |
+
Final Answer: <your answer>
|
209 |
|
210 |
+
Stop once you output "Final Answer:".
|
211 |
+
""")
|
212 |
+
|
213 |
+
# ---------------------------------------------------------------------
|
214 |
+
# LangGraph – Assistant-Node
|
215 |
+
# ---------------------------------------------------------------------
|
216 |
def assistant(state: MessagesState):
|
217 |
msgs = state["messages"]
|
218 |
+
if msgs[0].type != "system":
|
219 |
msgs = [system_prompt] + msgs
|
220 |
+
resp = asyncio.run(safe_invoke(msgs))
|
221 |
+
finished = resp.content.lower().lstrip().startswith("final answer") or not resp.tool_calls
|
222 |
+
return {"messages": [resp], "should_end": finished}
|
223 |
+
|
224 |
+
def route(state):
|
225 |
+
return "END" if state["should_end"] else "tools"
|
226 |
+
|
227 |
+
# ---------------------------------------------------------------------
|
228 |
+
# Tools-Liste & Graph
|
229 |
+
# ---------------------------------------------------------------------
|
230 |
+
tools = [
|
231 |
+
fetch_gaia_file, parse_csv, parse_excel,
|
232 |
+
gemini_transcribe_audio, describe_image, ocr_image,
|
233 |
+
web_search, simple_calculator,
|
234 |
+
]
|
235 |
|
|
|
|
|
|
|
236 |
builder = StateGraph(MessagesState)
|
237 |
builder.add_node("assistant", assistant)
|
238 |
builder.add_node("tools", ToolNode(tools))
|
239 |
builder.add_edge(START, "assistant")
|
240 |
+
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
|
|
|
241 |
|
242 |
+
# Compile
|
243 |
agent_executor = builder.compile()
|