ZeroTimo commited on
Commit
e8d1b6b
·
verified ·
1 Parent(s): ff70a29

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +224 -7
agent.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from langgraph.graph import StateGraph, START, MessagesState
3
  from langgraph.prebuilt import tools_condition, ToolNode
4
  from langchain_google_genai import ChatGoogleGenerativeAI
@@ -6,11 +7,151 @@ from langchain_core.tools import tool
6
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
7
  from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
8
  from langchain_core.messages import SystemMessage, HumanMessage
 
 
9
 
10
  # Lade Umgebungsvariablen (Google API Key)
11
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
12
 
13
  # === Tools definieren ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @tool
15
  def multiply(a: int, b: int) -> int:
16
  """Multiplies two numbers."""
@@ -59,11 +200,51 @@ def web_search(query: str) -> str:
59
 
60
  # === System Prompt definieren ===
61
  system_prompt = SystemMessage(content=(
62
- "You are an expert assistant. You MUST answer precisely, factually, and accurately. "
63
- "If you do not know the answer, use the available tools such as Wikipedia Search, Arxiv Search, "
64
- "or Web Search to find the correct information. "
65
- "If a math operation is needed, use the calculation tools. "
66
- "Do NOT invent answers. Only return answers you are confident in."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  ))
68
 
69
  # === LLM definieren ===
@@ -76,12 +257,48 @@ llm = ChatGoogleGenerativeAI(
76
  )
77
 
78
  # === Tools in LLM einbinden ===
79
- tools = [multiply, add, subtract, divide, modulo, wiki_search, arxiv_search, web_search]
 
 
 
 
 
 
 
 
 
80
  llm_with_tools = llm.bind_tools(tools)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # === Nodes für LangGraph ===
83
  def assistant(state: MessagesState):
84
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
85
 
86
  # === LangGraph bauen ===
87
  builder = StateGraph(MessagesState)
 
1
  import os
2
+ import pandas as pd
3
  from langgraph.graph import StateGraph, START, MessagesState
4
  from langgraph.prebuilt import tools_condition, ToolNode
5
  from langchain_google_genai import ChatGoogleGenerativeAI
 
7
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
  from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
+ import requests
11
+ import tempfile
12
 
13
  # Lade Umgebungsvariablen (Google API Key)
14
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
15
 
16
  # === Tools definieren ===
17
+
18
+ GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
19
+
20
+ @tool
21
+ def fetch_gaia_file(task_id: str) -> str:
22
+ """
23
+ Download the file attached to a GAIA task and return the local file-path.
24
+
25
+ Args:
26
+ task_id: The GAIA task_id (string in the JSON payload).
27
+
28
+ Returns:
29
+ Absolute path to the downloaded temp-file.
30
+ """
31
+ try:
32
+ url = f"{GAIA_BASE_URL}/files/{task_id}"
33
+ response = requests.get(url, timeout=20)
34
+ response.raise_for_status()
35
+
36
+ # Server liefert den echten Dateinamen im Header – fallback auf "download"
37
+ filename = (
38
+ response.headers.get("x-filename") or
39
+ response.headers.get("content-disposition", "download").split("filename=")[-1].strip('"')
40
+ )
41
+ if not filename:
42
+ filename = f"{task_id}.bin"
43
+
44
+ tmp_path = os.path.join(tempfile.gettempdir(), filename)
45
+ with open(tmp_path, "wb") as f:
46
+ f.write(response.content)
47
+
48
+ return tmp_path
49
+ except Exception as e:
50
+ return f"ERROR: could not download file for task {task_id}: {e}"
51
+
52
+ @tool
53
+ def parse_csv(file_path: str, query: str = "") -> str:
54
+ """
55
+ Load a CSV file from `file_path` and optionally run a simple analysis query.
56
+
57
+ Args:
58
+ file_path: absolute path to a CSV file (from fetch_gaia_file)
59
+ query: optional natural-language instruction, e.g.
60
+ "sum of column Sales where Category != 'Drinks'"
61
+
62
+ Returns:
63
+ A concise string with the answer OR a preview of the dataframe
64
+ if no query given.
65
+ """
66
+ try:
67
+ df = pd.read_csv(file_path)
68
+
69
+ # Auto-preview if kein query
70
+ if not query:
71
+ preview = df.head(5).to_markdown(index=False)
72
+ return f"CSV loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
73
+
74
+ # Mini-query-engine (sehr simpel, reicht für Summen / Mittelwerte)
75
+ query_lc = query.lower()
76
+ if "sum" in query_lc:
77
+ # ermitteln, welche Spalte summiert werden soll
78
+ for col in df.columns:
79
+ if col.lower() in query_lc:
80
+ s = df[col]
81
+ if "where" in query_lc:
82
+ # naive Filter-Parsing: where <col> != 'Drinks'
83
+ cond_part = query_lc.split("where", 1)[1].strip()
84
+ # SEHR einfaches != oder == Parsing
85
+ if "!=" in cond_part:
86
+ key, val = [x.strip().strip("'\"") for x in cond_part.split("!=")]
87
+ s = df.loc[df[key] != val, col]
88
+ elif "==" in cond_part:
89
+ key, val = [x.strip().strip("'\"") for x in cond_part.split("==")]
90
+ s = df.loc[df[key] == val, col]
91
+ return str(round(s.sum(), 2))
92
+ # Fallback
93
+ return "Query type not supported by parse_csv."
94
+ except Exception as e:
95
+ return f"ERROR parsing CSV: {e}"
96
+
97
+ @tool
98
+ def parse_excel(file_path: str, query: str = "") -> str:
99
+ """
100
+ Identisch zu parse_csv, nur für XLS/XLSX.
101
+ """
102
+ try:
103
+ df = pd.read_excel(file_path)
104
+
105
+ if not query:
106
+ preview = df.head(5).to_markdown(index=False)
107
+ return f"Excel loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
108
+
109
+ query_lc = query.lower()
110
+ if "sum" in query_lc:
111
+ for col in df.columns:
112
+ if col.lower() in query_lc:
113
+ s = df[col]
114
+ if "where" in query_lc:
115
+ cond_part = query_lc.split("where", 1)[1].strip()
116
+ if "!=" in cond_part:
117
+ key, val = [x.strip().strip("'\"") for x in cond_part.split("!=")]
118
+ s = df.loc[df[key] != val, col]
119
+ elif "==" in cond_part:
120
+ key, val = [x.strip().strip("'\"") for x in cond_part.split("==")]
121
+ s = df.loc[df[key] == val, col]
122
+ return str(round(s.sum(), 2))
123
+ return "Query type not supported by parse_excel."
124
+ except Exception as e:
125
+ return f"ERROR parsing Excel: {e}"
126
+
127
+ @tool
128
+ def transcribe_audio(file_path: str, language: str = "en") -> str:
129
+ """
130
+ Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.
131
+
132
+ Args:
133
+ file_path: absolute path to an audio file (from fetch_gaia_file)
134
+ language: ISO language code, default "en"
135
+
136
+ Returns:
137
+ Full transcription as plain text, or "ERROR …"
138
+ """
139
+ try:
140
+ from faster_whisper import WhisperModel
141
+
142
+ # Tiny model reicht für kurze Sprachmemos, ~75 MB
143
+ model = WhisperModel("tiny", device="cpu", compute_type="int8")
144
+
145
+ segments, _ = model.transcribe(file_path, language=language)
146
+ transcript = " ".join(segment.text.strip() for segment in segments).strip()
147
+
148
+ if not transcript:
149
+ return "ERROR: transcription empty."
150
+ return transcript
151
+ except Exception as e:
152
+ return f"ERROR: audio transcription failed – {e}"
153
+
154
+
155
  @tool
156
  def multiply(a: int, b: int) -> int:
157
  """Multiplies two numbers."""
 
200
 
201
  # === System Prompt definieren ===
202
  system_prompt = SystemMessage(content=(
203
+ system_prompt = SystemMessage(
204
+ content=(
205
+ "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
206
+ "\n"
207
+ "GENERAL RULES\n"
208
+ "-------------\n"
209
+ "1. Always try to answer every question.\n"
210
+ "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
211
+ "3. Never invent facts.\n"
212
+ "\n"
213
+ "TOOLS\n"
214
+ "-----\n"
215
+ "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
216
+ "- parse_csv(file_path, query): analyse CSV files.\n"
217
+ "- parse_excel(file_path, query): analyse Excel files.\n"
218
+ "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
219
+ "- wiki_search(query): query English Wikipedia.\n"
220
+ "- arxiv_search(query): query arXiv.\n"
221
+ "- web_search(query): DuckDuckGo web search.\n"
222
+ "- simple_calculator(operation,a,b): basic maths.\n"
223
+ "\n"
224
+ "WHEN TO USE WHICH TOOL\n"
225
+ "----------------------\n"
226
+ "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
227
+ "fetch_gaia_file with the given task_id. Then:\n"
228
+ " • CSV → parse_csv\n"
229
+ " • XLS/XLSX → parse_excel\n"
230
+ " • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
231
+ " • Image → (currently unsupported) answer that image processing is unavailable\n"
232
+ "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
233
+ "・If you need a scientific paper → arxiv_search.\n"
234
+ "・If a numeric operation is required → simple_calculator.\n"
235
+ "\n"
236
+ "ERROR HANDLING\n"
237
+ "--------------\n"
238
+ "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
239
+ "an alternative strategy: retry with a different tool or modified parameters. "
240
+ "Do not repeat the same failing call twice.\n"
241
+ "\n"
242
+ "OUTPUT FORMAT\n"
243
+ "-------------\n"
244
+ "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
245
+ "Do not add extra commentary.\n"
246
+ )
247
+ )
248
  ))
249
 
250
  # === LLM definieren ===
 
257
  )
258
 
259
  # === Tools in LLM einbinden ===
260
+ tools = [
261
+ fetch_gaia_file,
262
+ parse_csv,
263
+ parse_excel,
264
+ transcribe_audio,
265
+ wiki_search,
266
+ arxiv_search,
267
+ web_search,
268
+ simple_calculator,
269
+ ]
270
  llm_with_tools = llm.bind_tools(tools)
271
 
272
+ def safe_llm_invoke(messages):
273
+ """
274
+ Ruft LLM einmal auf. Wenn das Ergebnis mit ERROR beginnt,
275
+ ruft es genau EIN weiteres Mal auf – jetzt weiß das LLM,
276
+ dass der vorige Tool-Call fehlgeschlagen ist.
277
+ """
278
+ max_attempts = 2
279
+ for attempt in range(max_attempts):
280
+ result = llm_with_tools.invoke(messages)
281
+ content = result.content if hasattr(result, "content") else ""
282
+ if "ERROR:" not in content:
283
+ return result
284
+ # Fehler: füge eine System-Korrektur hinzu und versuche erneut
285
+ messages.append(
286
+ SystemMessage(
287
+ content="Previous tool call returned an ERROR. "
288
+ "Try a different tool or revise the input."
289
+ )
290
+ )
291
+ # nach max_attempts immer noch Fehler → zurückgeben
292
+ return result
293
+
294
+
295
  # === Nodes für LangGraph ===
296
  def assistant(state: MessagesState):
297
+ """
298
+ Assistant node mit eingebautem Retry bei Tool-Fehlern.
299
+ """
300
+ result_msg = safe_llm_invoke(state["messages"])
301
+ return {"messages": [result_msg]}
302
 
303
  # === LangGraph bauen ===
304
  builder = StateGraph(MessagesState)