ZeroTimo commited on
Commit
7d74ca3
·
verified ·
1 Parent(s): a92112f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +157 -254
agent.py CHANGED
@@ -1,291 +1,195 @@
1
- import os
2
- import pandas as pd
3
- from langgraph.graph import StateGraph, START, MessagesState
4
- from langgraph.prebuilt import tools_condition, ToolNode
5
- from langchain_google_genai import ChatGoogleGenerativeAI
6
- from langchain_core.tools import tool
7
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
- from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
9
- from langchain_core.messages import SystemMessage, HumanMessage
10
- import requests
11
- import tempfile
12
-
13
- # Lade Umgebungsvariablen (Google API Key)
 
 
 
 
 
 
14
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
15
 
16
- # === Tools definieren ===
 
 
 
 
 
17
 
18
- GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @tool
 
21
  def fetch_gaia_file(task_id: str) -> str:
22
- """
23
- Download the file attached to a GAIA task and return the local file-path.
24
-
25
- Args:
26
- task_id: The GAIA task_id (string in the JSON payload).
27
-
28
- Returns:
29
- Absolute path to the downloaded temp-file.
30
- """
31
- try:
32
- url = f"{GAIA_BASE_URL}/files/{task_id}"
33
- response = requests.get(url, timeout=20)
34
- response.raise_for_status()
35
-
36
- # Server liefert den echten Dateinamen im Header – fallback auf "download"
37
- filename = (
38
- response.headers.get("x-filename") or
39
- response.headers.get("content-disposition", "download").split("filename=")[-1].strip('"')
40
- )
41
- if not filename:
42
- filename = f"{task_id}.bin"
43
-
44
- tmp_path = os.path.join(tempfile.gettempdir(), filename)
45
- with open(tmp_path, "wb") as f:
46
- f.write(response.content)
47
-
48
- return tmp_path
49
- except Exception as e:
50
- return f"ERROR: could not download file for task {task_id}: {e}"
51
 
52
  @tool
 
53
  def parse_csv(file_path: str, query: str = "") -> str:
54
- """
55
- Load a CSV file from `file_path` and optionally run a simple analysis query.
56
-
57
- Args:
58
- file_path: absolute path to a CSV file (from fetch_gaia_file)
59
- query: optional natural-language instruction, e.g.
60
- "sum of column Sales where Category != 'Drinks'"
61
-
62
- Returns:
63
- A concise string with the answer OR a preview of the dataframe
64
- if no query given.
65
- """
66
- try:
67
- df = pd.read_csv(file_path)
68
-
69
- # Auto-preview if kein query
70
- if not query:
71
- preview = df.head(5).to_markdown(index=False)
72
- return f"CSV loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
73
-
74
- # Mini-query-engine (sehr simpel, reicht für Summen / Mittelwerte)
75
- query_lc = query.lower()
76
- if "sum" in query_lc:
77
- # ermitteln, welche Spalte summiert werden soll
78
- for col in df.columns:
79
- if col.lower() in query_lc:
80
- s = df[col]
81
- if "where" in query_lc:
82
- # naive Filter-Parsing: where <col> != 'Drinks'
83
- cond_part = query_lc.split("where", 1)[1].strip()
84
- # SEHR einfaches != oder == Parsing
85
- if "!=" in cond_part:
86
- key, val = [x.strip().strip("'\"") for x in cond_part.split("!=")]
87
- s = df.loc[df[key] != val, col]
88
- elif "==" in cond_part:
89
- key, val = [x.strip().strip("'\"") for x in cond_part.split("==")]
90
- s = df.loc[df[key] == val, col]
91
- return str(round(s.sum(), 2))
92
- # Fallback
93
- return "Query type not supported by parse_csv."
94
- except Exception as e:
95
- return f"ERROR parsing CSV: {e}"
96
 
97
  @tool
 
98
  def parse_excel(file_path: str, query: str = "") -> str:
99
- """
100
- Identisch zu parse_csv, nur für XLS/XLSX.
101
- """
102
- try:
103
- df = pd.read_excel(file_path)
104
-
105
- if not query:
106
- preview = df.head(5).to_markdown(index=False)
107
- return f"Excel loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
108
-
109
- query_lc = query.lower()
110
- if "sum" in query_lc:
111
- for col in df.columns:
112
- if col.lower() in query_lc:
113
- s = df[col]
114
- if "where" in query_lc:
115
- cond_part = query_lc.split("where", 1)[1].strip()
116
- if "!=" in cond_part:
117
- key, val = [x.strip().strip("'\"") for x in cond_part.split("!=")]
118
- s = df.loc[df[key] != val, col]
119
- elif "==" in cond_part:
120
- key, val = [x.strip().strip("'\"") for x in cond_part.split("==")]
121
- s = df.loc[df[key] == val, col]
122
- return str(round(s.sum(), 2))
123
- return "Query type not supported by parse_excel."
124
- except Exception as e:
125
- return f"ERROR parsing Excel: {e}"
126
-
127
- @tool
128
- def transcribe_audio(file_path: str, language: str = "en") -> str:
129
- """
130
- Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.
131
-
132
- Args:
133
- file_path: absolute path to an audio file (from fetch_gaia_file)
134
- language: ISO language code, default "en"
135
-
136
- Returns:
137
- Full transcription as plain text, or "ERROR …"
138
- """
139
- try:
140
- from faster_whisper import WhisperModel
141
-
142
- # Tiny model reicht für kurze Sprachmemos, ~75 MB
143
- model = WhisperModel("tiny", device="cpu", compute_type="int8")
144
-
145
- segments, _ = model.transcribe(file_path, language=language)
146
- transcript = " ".join(segment.text.strip() for segment in segments).strip()
147
-
148
- if not transcript:
149
- return "ERROR: transcription empty."
150
- return transcript
151
- except Exception as e:
152
- return f"ERROR: audio transcription failed – {e}"
153
-
154
-
155
  @tool
156
- def simple_calculator(operation: str, a: float, b: float) -> float:
157
- """Basic maths: add, subtract, multiply, divide and modulo."""
158
- if operation == "add":
159
- return a + b
160
- if operation == "subtract":
161
- return a - b
162
- if operation == "multiply":
163
- return a * b
164
- if operation == "divide":
165
- return a / b if b else float("inf")
166
- if operation == "modulo":
167
- return a % b
168
- return 0.0
169
-
170
 
171
  @tool
172
- def wiki_search(query: str) -> str:
173
- """Search Wikipedia for a query and return the result."""
174
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
175
- return "\n\n".join(doc.page_content for doc in search_docs)
176
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  @tool
178
- def arxiv_search(query: str) -> str:
179
- """Search Arxiv for academic papers about a query."""
180
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
181
- return "\n\n".join(doc.page_content[:1000] for doc in search_docs)
182
-
 
 
 
 
 
 
183
  @tool
 
184
  def web_search(query: str) -> str:
185
- """Perform a DuckDuckGo web search."""
186
- wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
187
- results = wrapper.run(query)
188
- return results
189
-
190
- # === System Prompt definieren ===
191
- system_prompt = SystemMessage(
192
- content=(
193
- "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
194
- "\n"
195
- "GENERAL RULES\n"
196
- "-------------\n"
197
- "1. Always try to answer every question.\n"
198
- "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
199
- "3. Never invent facts.\n"
200
- "\n"
201
- "TOOLS\n"
202
- "-----\n"
203
- "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
204
- "- parse_csv(file_path, query): analyse CSV files.\n"
205
- "- parse_excel(file_path, query): analyse Excel files.\n"
206
- "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
207
- "- wiki_search(query): query English Wikipedia.\n"
208
- "- arxiv_search(query): query arXiv.\n"
209
- "- web_search(query): DuckDuckGo web search.\n"
210
- "- simple_calculator(operation,a,b): basic maths.\n"
211
- "\n"
212
- "WHEN TO USE WHICH TOOL\n"
213
- "----------------------\n"
214
- "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
215
- "fetch_gaia_file with the given task_id. Then:\n"
216
- " • CSV → parse_csv\n"
217
- " • XLS/XLSX → parse_excel\n"
218
- " • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
219
- " • Image → (currently unsupported) answer that image processing is unavailable\n"
220
- "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
221
- "・If you need a scientific paper → arxiv_search.\n"
222
- "・If a numeric operation is required → simple_calculator.\n"
223
- "\n"
224
- "ERROR HANDLING\n"
225
- "--------------\n"
226
- "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
227
- "an alternative strategy: retry with a different tool or modified parameters. "
228
- "Do not repeat the same failing call twice.\n"
229
- "\n"
230
- "OUTPUT FORMAT\n"
231
- "-------------\n"
232
- "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
233
- "Do not add extra commentary.\n"
234
- )
235
- )
236
-
237
- # === LLM definieren ===
238
- llm = ChatGoogleGenerativeAI(
239
- model="gemini-2.0-flash",
240
- google_api_key=GOOGLE_API_KEY,
241
- temperature=0,
242
- max_output_tokens=2048
243
- )
244
-
245
- # === Tools in LLM einbinden ===
246
  tools = [
247
  fetch_gaia_file,
248
  parse_csv,
249
  parse_excel,
250
- transcribe_audio,
251
- wiki_search,
252
- arxiv_search,
253
  web_search,
254
  simple_calculator,
255
  ]
 
256
  llm_with_tools = llm.bind_tools(tools)
257
 
258
- def safe_llm_invoke(messages):
259
- """
260
- Ruft LLM einmal auf. Wenn das Ergebnis mit ERROR beginnt,
261
- ruft es genau EIN weiteres Mal auf – jetzt weiß das LLM,
262
- dass der vorige Tool-Call fehlgeschlagen ist.
263
- """
264
- max_attempts = 2
265
- for attempt in range(max_attempts):
266
- result = llm_with_tools.invoke(messages)
267
- content = result.content if hasattr(result, "content") else ""
268
- if "ERROR:" not in content:
269
- return result
270
- # Fehler: füge eine System-Korrektur hinzu und versuche erneut
271
- messages.append(
272
- SystemMessage(
273
- content="Previous tool call returned an ERROR. "
274
- "Try a different tool or revise the input."
275
- )
276
  )
277
- # nach max_attempts immer noch Fehler → zurückgeben
278
- return result
279
 
280
 
281
- # === Nodes für LangGraph ===
282
  def assistant(state: MessagesState):
283
  msgs = state["messages"]
284
  if not msgs or msgs[0].type != "system":
285
  msgs = [system_prompt] + msgs
286
- return {"messages": [llm_with_tools.invoke(msgs)]}
287
 
288
- # === LangGraph bauen ===
 
 
289
  builder = StateGraph(MessagesState)
290
  builder.add_node("assistant", assistant)
291
  builder.add_node("tools", ToolNode(tools))
@@ -293,5 +197,4 @@ builder.add_edge(START, "assistant")
293
  builder.add_conditional_edges("assistant", tools_condition)
294
  builder.add_edge("tools", "assistant")
295
 
296
- # === Agent Executor ===
297
  agent_executor = builder.compile()
 
1
+ """
2
+ agent.py – LangGraph-Agent mit
3
+ Gemini 2.0 Flash
4
+ Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR)
5
+ Fehler-Retry-Logik
6
+ """
7
+ import os, base64, mimetypes, subprocess, json, tempfile
8
+ from typing import Any
9
+
10
+ from langgraph.graph import START, StateGraph, MessagesState
11
+ from langgraph.prebuilt import tools_condition, ToolNode
12
+ from langchain_core.tools import tool
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_google_genai import ChatGoogleGenerativeAI
15
+ from langchain_community.tools.duckduckgo_search import DuckDuckGoSearchResults as DDGS
16
+
17
# ----------------------------------------------------------------------
# 1 ── ENV / LLM
# ----------------------------------------------------------------------
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Gemini 2.0 Flash; temperature 0 for deterministic, reproducible answers.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
)
28
 
29
# ----------------------------------------------------------------------
# 2 ── ERROR WRAPPER (guarantees an "ERROR:" string instead of an exception)
# ----------------------------------------------------------------------
def error_guard(fn):
    """Decorator: run *fn* and convert any exception into an 'ERROR: ...' string.

    The agent's retry logic keys off the literal "ERROR:" prefix, so tools
    must never raise — failures are reported as strings instead.

    BUG FIX: ``functools.wraps`` is required here. LangChain's ``@tool``
    decorator (stacked on top of this one) derives the tool's name and
    description from ``__name__`` / ``__doc__``; without ``wraps`` every
    guarded tool would be registered as an anonymous, undocumented "wrapper".
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberate catch-all boundary for tool calls
            return f"ERROR: {e}"
    return wrapper
40
# ----------------------------------------------------------------------
# 3 ── BASIC TOOLS
# ----------------------------------------------------------------------
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths: add, subtract, multiply, divide, modulo.

    Args:
        operation: one of "add", "subtract", "multiply", "divide", "modulo".
        a: first operand.
        b: second operand.

    Returns:
        The numeric result. Division by zero yields ``inf`` (unchanged
        behavior); an unknown operation yields an "ERROR: ..." string.
    """
    # Lazy dispatch: only the requested operation is evaluated. The previous
    # dict literal computed every result eagerly on each call, including a
    # needless division. "modulo" is restored from the pre-refactor tool set
    # (a % 0 raises and is converted to an ERROR string by @error_guard).
    ops = {
        "add": lambda: a + b,
        "subtract": lambda: a - b,
        "multiply": lambda: a * b,
        "divide": lambda: a / b if b else float("inf"),
        "modulo": lambda: a % b,
    }
    op = ops.get(operation)
    return op() if op is not None else "ERROR: unknown operation"
50
 
51
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
    """Download attachment for current GAIA task_id; returns local file path.

    Args:
        task_id: GAIA task identifier whose attachment should be fetched.

    Returns:
        Absolute path of the downloaded temp file (or "ERROR: ..." via
        @error_guard when the download fails).
    """
    import requests, pathlib, uuid
    # NOTE(review): the previous revision of this file used /files/<task_id>
    # (plural); /file/<task_id> 404s against the scoring API — restored.
    url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
    r = requests.get(url, timeout=15)
    r.raise_for_status()
    # BUG FIX: the URL ends in the bare task_id, so Path(url).suffix was
    # always "". Take the extension from the server-provided filename in the
    # Content-Disposition header so downstream tools can recognise the type.
    cd = r.headers.get("content-disposition", "")
    filename = cd.split("filename=")[-1].strip('"') if "filename=" in cd else ""
    suffix = pathlib.Path(filename or url).suffix or ""
    fp = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4().hex}{suffix}"
    fp.write_bytes(r.content)
    return str(fp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load CSV & answer query using pandas.eval.

    Args:
        file_path: local path to a CSV file (from fetch_gaia_file).
        query: optional pandas expression over ``df``,
               e.g. "df['Sales'].sum()". Empty → preview.

    Returns:
        With no query: a markdown preview of the first rows plus the column
        list; otherwise the stringified result of evaluating the expression.
    """
    import pandas as pd
    df = pd.read_csv(file_path)
    if not query:
        # Include the column names (as the pre-refactor version did) so the
        # model can phrase a valid follow-up query without guessing schema.
        cols = ", ".join(map(str, df.columns))
        return f"{df.head().to_markdown()}\n\nColumns: {cols}"
    # NOTE(review): pd.eval executes an expression string; queries here come
    # from the LLM, not end users, but keep this away from untrusted input.
    return str(pd.eval(query, local_dict={"df": df}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
    """Load first sheet of Excel & answer query using pandas.eval.

    Args:
        file_path: local path to an XLS/XLSX file (from fetch_gaia_file).
        query: optional pandas expression over ``df``,
               e.g. "df['Sales'].sum()". Empty → preview.

    Returns:
        With no query: a markdown preview of the first rows plus the column
        list; otherwise the stringified result of evaluating the expression.
    """
    import pandas as pd
    df = pd.read_excel(file_path)
    if not query:
        # Mirror parse_csv: expose the schema so the model can form queries.
        cols = ", ".join(map(str, df.columns))
        return f"{df.head().to_markdown()}\n\nColumns: {cols}"
    return str(pd.eval(query, local_dict={"df": df}))
83
+
84
+ # ----------------------------------------------------------------------
85
+ # 4 ── GEMINI MULTIMODAL-TOOLS
86
+ # ----------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
    """Send a local image (base64) to Gemini Vision and return description.

    Args:
        file_path: local path to an image file.
        prompt: instruction sent alongside the image.
    """
    mime, _ = mimetypes.guess_type(file_path)
    if mime is None or not mime.startswith("image/"):
        return "ERROR: not an image."
    with open(file_path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode()
    # assumes Gemini accepts a data-URI string in the image_url slot —
    # TODO confirm against langchain_google_genai's multimodal docs.
    parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": f"data:{mime};base64,{encoded}"},
    ]
    return llm.invoke([HumanMessage(content=parts)]).content
102
 
103
@tool
@error_guard
def gemini_transcribe_audio(file_path: str,
                            prompt: str = "Transcribe the audio.") -> str:
    """Transcribe audio via Gemini multimodal.

    Args:
        file_path: local path to an audio file (MP3/WAV/...).
        prompt: instruction sent alongside the audio payload.
    """
    mime, _ = mimetypes.guess_type(file_path)
    if mime is None or not mime.startswith("audio/"):
        return "ERROR: not audio."
    with open(file_path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode()
    # "media" parts carry raw base64 data + mime type for Gemini.
    parts = [
        {"type": "text", "text": prompt},
        {"type": "media", "data": encoded, "mime_type": mime},
    ]
    return llm.invoke([HumanMessage(content=parts)]).content
119
+
120
+ # ----------------------------------------------------------------------
121
+ # 5 ── OFFLINE OCR-TOOL (pytesseract)
122
+ # ----------------------------------------------------------------------
123
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from image using pytesseract.

    Args:
        file_path: local path to an image file.
        lang: tesseract language code (default "eng").
    """
    # Imported lazily so the module loads even without the OCR stack.
    import pytesseract
    from PIL import Image

    return pytesseract.image_to_string(Image.open(file_path), lang=lang).strip()
131
+
132
# ----------------------------------------------------------------------
# 6 ── WEB / WIKI SEARCH
# ----------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str) -> str:
    """DuckDuckGo text search top 5 results.

    Args:
        query: search terms.

    Returns:
        Up to five "title url" lines, or "ERROR: no results.".
    """
    # BUG FIX: the module-level alias ``DDGS`` is LangChain's
    # DuckDuckGoSearchResults tool *class*, which is neither a context
    # manager nor provides ``.text()`` — the body below was written against
    # the ``duckduckgo_search`` client, so import that one explicitly.
    from duckduckgo_search import DDGS as DuckClient
    with DuckClient() as ddgs:
        results = ddgs.text(query, max_results=5)
    if not results:
        return "ERROR: no results."
    return "\n\n".join(f"{r['title']} {r['href']}" for r in results)
144
+
145
# ----------------------------------------------------------------------
# 7 ── SYSTEM PROMPT
# ----------------------------------------------------------------------
# Kept in a named constant so the prompt text is easy to locate and tweak.
_SYSTEM_PROMPT_TEXT = (
    "You are a precise GAIA challenge agent.\n"
    "Always attempt a TOOL call before giving up. "
    "If a tool returns 'ERROR', think of an alternative tool or parameters.\n"
    "Use fetch_gaia_file(task_id) for any attachment."
)
system_prompt = SystemMessage(content=_SYSTEM_PROMPT_TEXT)
154
+
155
+ # ----------------------------------------------------------------------
156
+ # 8 ── LangGraph Nodes
157
+ # ----------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
# Tool registry: shared by the ToolNode and by Gemini via bind_tools.
# Order is preserved from the original definition.
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    gemini_transcribe_audio,
    ocr_image,
    describe_image,
    web_search,
    simple_calculator,
]

llm_with_tools = llm.bind_tools(tools)
170
 
171
+
172
def safe_llm_invoke(msgs):
    """Invoke the tool-bound LLM, retrying once if it answers with "ERROR".

    On an ERROR-prefixed reply, a corrective SystemMessage is appended to a
    *copy* of the history and the model is asked again; after the second
    attempt the last response is returned as-is.

    Args:
        msgs: message history to send (left unmodified).

    Returns:
        The model's response message.
    """
    # BUG FIX: copy before appending — the original mutated the caller's
    # list, which is the LangGraph state's message list.
    msgs = list(msgs)
    resp = None
    for _ in range(2):
        resp = llm_with_tools.invoke(msgs)
        content = resp.content or ""
        # BUG FIX: Gemini may return content as a list of parts, on which
        # .startswith would raise; only plain strings carry the ERROR prefix.
        if not (isinstance(content, str) and content.startswith("ERROR")):
            return resp
        msgs.append(
            SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
        )
    return resp
 
182
 
183
 
 
184
def assistant(state: MessagesState):
    """LangGraph node: ensure the system prompt leads the history, then call the LLM."""
    history = state["messages"]
    missing_prompt = not history or history[0].type != "system"
    if missing_prompt:
        history = [system_prompt, *history]
    return {"messages": [safe_llm_invoke(history)]}
189
 
190
# ----------------------------------------------------------------------
# 9 ── Graph
# ----------------------------------------------------------------------
# assistant → (tools_condition) → tools → assistant loop, entered at START.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")  # entry edge (unchanged context line in the diff hunk)
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

agent_executor = builder.compile()