ZeroTimo committed on
Commit
7cc5531
·
verified ·
1 Parent(s): 0d276c6

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +182 -175
agent.py CHANGED
@@ -1,37 +1,29 @@
1
- """
2
- agent.py – LangGraph-Agent mit
3
- Gemini 2.0 Flash
4
- Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR)
5
- • Fehler-Retry-Logik
6
- """
7
- import os, base64, mimetypes, subprocess, json, tempfile
8
- import functools
9
- from typing import Any
10
 
11
  from langgraph.graph import START, StateGraph, MessagesState
12
- from langgraph.prebuilt import tools_condition, ToolNode
13
- from langchain_core.tools import tool
14
  from langchain_core.messages import SystemMessage, HumanMessage
 
 
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_community.tools.tavily_search import TavilySearchResults
17
 
18
- # ----------------------------------------------------------------------
19
- # 1 ── ENV / LLM
20
- # ----------------------------------------------------------------------
21
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
22
 
23
- llm = ChatGoogleGenerativeAI(
24
- model="gemini-2.0-flash",
25
- google_api_key=GOOGLE_API_KEY,
26
- temperature=0,
27
- max_output_tokens=2048,
28
- )
29
-
30
- # ----------------------------------------------------------------------
31
- # 2 ── ERROR-WRAPPER (garantiert "ERROR:"-String statt Exception)
32
- # ----------------------------------------------------------------------
33
  def error_guard(fn):
34
- @functools.wraps(fn) # ➜ übernimmt __doc__, __name__, …
35
  def wrapper(*args, **kwargs):
36
  try:
37
  return fn(*args, **kwargs)
@@ -39,198 +31,213 @@ def error_guard(fn):
39
  return f"ERROR: {e}"
40
  return wrapper
41
 
42
- # ----------------------------------------------------------------------
43
- # 3 ── BASIS-TOOLS
44
- # ----------------------------------------------------------------------
45
- @tool
46
- @error_guard
47
- def simple_calculator(operation: str, a: float, b: float) -> float:
48
- """Basic maths: add, subtract, multiply, divide."""
49
- ops = {"add": a + b, "subtract": a - b, "multiply": a * b,
50
- "divide": a / b if b else float("inf")}
51
- return ops.get(operation, "ERROR: unknown operation")
52
 
53
  @tool
54
  @error_guard
55
  def fetch_gaia_file(task_id: str) -> str:
56
- """Download attachment for current GAIA task_id; returns local file path."""
57
- import requests, pathlib, uuid
58
- url = f"https://agents-course-unit4-scoring.hf.space/file/{task_id}"
59
- r = requests.get(url, timeout=15)
60
- r.raise_for_status()
61
- suffix = pathlib.Path(url).suffix or ""
62
- fp = pathlib.Path(tempfile.gettempdir())/f"{uuid.uuid4().hex}{suffix}"
63
- fp.write_bytes(r.content)
64
- return str(fp)
 
 
 
 
 
 
 
 
65
 
66
  @tool
67
  @error_guard
68
  def parse_csv(file_path: str, query: str = "") -> str:
69
- """Load CSV & answer query using pandas.eval."""
70
- import pandas as pd
71
  df = pd.read_csv(file_path)
72
  if not query:
73
- return df.head().to_markdown()
74
- return str(pd.eval(query, local_dict={"df": df}))
75
-
 
 
 
 
 
 
 
76
  @tool
77
  @error_guard
78
  def parse_excel(file_path: str, query: str = "") -> str:
79
- """Load first sheet of Excel & answer query using pandas.eval."""
80
- import pandas as pd
81
  df = pd.read_excel(file_path)
82
  if not query:
83
- return df.head().to_markdown()
84
- return str(pd.eval(query, local_dict={"df": df}))
85
-
86
- # ----------------------------------------------------------------------
87
- # 4 ── GEMINI MULTIMODAL-TOOLS
88
- # ----------------------------------------------------------------------
 
 
 
 
89
  @tool
90
  @error_guard
91
- def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
92
- """Send a local image (base64) to Gemini Vision and return description."""
93
- mime, _ = mimetypes.guess_type(file_path)
94
- if not (mime and mime.startswith("image/")):
95
- return "ERROR: not an image."
96
  with open(file_path, "rb") as f:
97
  b64 = base64.b64encode(f.read()).decode()
98
- content = [
99
- {"type": "text", "text": prompt},
100
- {"type": "image_url", "image_url": f"data:{mime};base64,{b64}"},
101
- ]
102
- resp = llm.invoke([HumanMessage(content=content)])
103
- return resp.content
104
-
 
 
 
 
 
 
105
  @tool
106
  @error_guard
107
- def gemini_transcribe_audio(file_path: str,
108
- prompt: str = "Transcribe the audio.") -> str:
109
- """Transcribe audio via Gemini multimodal."""
110
- mime, _ = mimetypes.guess_type(file_path)
111
- if not (mime and mime.startswith("audio/")):
112
- return "ERROR: not audio."
113
- with open(file_path, "rb") as f:
114
- b64 = base64.b64encode(f.read()).decode()
115
- content = [
116
- {"type": "text", "text": prompt},
117
- {"type": "media", "data": b64, "mime_type": mime},
118
- ]
119
- resp = llm.invoke([HumanMessage(content=content)])
120
  return resp.content
121
 
122
- # ----------------------------------------------------------------------
123
- # 5 ── OFFLINE OCR-TOOL (pytesseract)
124
- # ----------------------------------------------------------------------
125
  @tool
126
  @error_guard
127
  def ocr_image(file_path: str, lang: str = "eng") -> str:
128
- """Extract text from image using pytesseract."""
129
- from PIL import Image
130
- import pytesseract
131
- img = Image.open(file_path)
132
- return pytesseract.image_to_string(img, lang=lang).strip()
133
-
134
- # ----------------------------------------------------------------------
135
- # 6 ── WEB / WIKI SEARCH
136
- # ----------------------------------------------------------------------
 
 
 
137
  @tool
138
  @error_guard
139
  def web_search(query: str, max_results: int = 5) -> str:
140
- """Tavily web search returns markdown list of results."""
141
- search = TavilySearchResults(max_results=max_results)
142
- hits = search.invoke(query)
143
  if not hits:
144
- return "ERROR: no results."
145
- return "\n\n".join(f"{hit['title']} – {hit['url']}" for hit in hits)
146
-
147
-
148
- # ----------------------------------------------------------------------
149
- # 7 ── SYSTEM-PROMPT
150
- # ----------------------------------------------------------------------
151
- system_prompt = SystemMessage(content=(
152
- """"
153
- You are GAIA-Assist, an accurate, tool-using agent.
154
-
155
- TOOLS YOU CAN CALL
156
- ------------------
157
- fetch_gaia_file(task_id) – download the current task’s attachment
158
- parse_csv(file_path, query="")
159
- • parse_excel(file_path, query="")
160
- gemini_transcribe_audio(file_path[, prompt])
161
- • describe_image(file_path[, prompt])
162
- ocr_image(file_path[, lang="eng"])
163
- web_search(query [, max_results=5])
164
- simple_calculator(operation, a, b)
165
-
166
- WORKFLOW RULES
167
- --------------
168
- 1. **If** the question mentions an attachment, first call
169
- fetch_gaia_file(task_id).
170
- – After it returns a path, choose exactly one specialised parser.
171
-
172
- 2. **Otherwise**, think whether a web_search or calculator is needed.
173
-
174
- 3. **NEVER** call the same tool twice in a row with the same input.
175
-
176
- ANSWER FORMAT
177
- -------------
178
- *If a tool is needed*
179
- Thought: Do I need to use a tool? **Yes**
180
- Action: <tool name>
181
- Action Input: <JSON-encoded arguments>
182
-
183
- *If no tool is needed*
184
- Thought: Do I need to use a tool? **No**
185
- Final Answer: <your concise answer here>
186
-
187
- Once you have written **Final Answer:** you are done – do **not** call any further tool.
188
- """
189
- ))
190
-
191
- # ----------------------------------------------------------------------
192
- # 8 ── LangGraph Nodes
193
- # ----------------------------------------------------------------------
194
- tools = [
195
- fetch_gaia_file,
196
- parse_csv,
197
- parse_excel,
198
- gemini_transcribe_audio,
199
- ocr_image,
200
- describe_image,
201
- web_search,
202
- simple_calculator,
203
- ]
204
 
205
- llm_with_tools = llm.bind_tools(tools)
 
 
206
 
 
207
 
208
- def safe_llm_invoke(msgs):
209
- for attempt in range(2):
210
- resp = llm_with_tools.invoke(msgs)
211
- content = resp.content or ""
212
- if not content.startswith("ERROR"):
213
- return resp
214
- msgs.append(
215
- SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
216
- )
217
- return resp
218
 
 
 
 
219
 
 
 
 
 
 
 
220
  def assistant(state: MessagesState):
221
  msgs = state["messages"]
222
- if not msgs or msgs[0].type != "system":
223
  msgs = [system_prompt] + msgs
224
- return {"messages": [safe_llm_invoke(msgs)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- # ----------------------------------------------------------------------
227
- # 9 ── Graph
228
- # ----------------------------------------------------------------------
229
  builder = StateGraph(MessagesState)
230
  builder.add_node("assistant", assistant)
231
  builder.add_node("tools", ToolNode(tools))
232
  builder.add_edge(START, "assistant")
233
- builder.add_conditional_edges("assistant", tools_condition)
234
- builder.add_edge("tools", "assistant")
235
 
 
236
  agent_executor = builder.compile()
 
1
+ # agent.py – Gemini 2.0 Flash · LangGraph · Mehrere Tools
2
+ # =========================================================
3
import os, asyncio, base64, mimetypes, tempfile, functools, json
import requests
from typing import Dict, Any, List, Optional

from langgraph.graph import START, END, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode

from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
14
 
15
+ # ---------------------------------------------------------------------
16
+ # Konstanten / API-Keys
17
+ # ---------------------------------------------------------------------
18
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
19
+ TAVILY_KEY = os.getenv("TAVILY_API_KEY")
20
 
21
+ # ---------------------------------------------------------------------
22
+ # Fehler-Wrapper – behält Doc-String dank wraps
23
+ # ---------------------------------------------------------------------
24
+ import functools
 
 
 
 
 
 
25
  def error_guard(fn):
26
+ @functools.wraps(fn)
27
  def wrapper(*args, **kwargs):
28
  try:
29
  return fn(*args, **kwargs)
 
31
  return f"ERROR: {e}"
32
  return wrapper
33
 
34
+
35
+ # ---------------------------------------------------------------------
36
+ # 1) fetch_gaia_file – Datei vom GAIA-Server holen
37
+ # ---------------------------------------------------------------------
38
+ GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"
 
 
 
 
 
39
 
40
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
    """Download the attachment for the given GAIA task_id and return local path.

    Args:
        task_id: GAIA task identifier used to build the download URL.

    Returns:
        Local filesystem path of the saved attachment, or an "ERROR: ..."
        string on failure (per the agent's error convention).
    """
    url = f"{GAIA_FILE_ENDPOINT}/{task_id}"
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # The server may suggest a filename via this header.  Strip any
        # directory components so a malicious header value (e.g. "../x")
        # cannot escape the temp directory.
        raw_name = response.headers.get("x-gaia-filename", f"{task_id}")
        file_name = os.path.basename(raw_name) or f"{task_id}"
        # os.path.join is platform-safe, unlike manual "/" concatenation.
        tmp_path = os.path.join(tempfile.gettempdir(), file_name)
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        return tmp_path
    except Exception as e:
        return f"ERROR: could not fetch file – {e}"
55
+
56
+ # ---------------------------------------------------------------------
57
+ # 2) CSV-Parser
58
+ # ---------------------------------------------------------------------
59
+ import pandas as pd
60
 
61
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and answer a quick pandas query (optional).

    Args:
        file_path: Path to the CSV file.
        query: Optional ``DataFrame.query`` expression; when empty, a short
            summary of the frame (row/column counts and column names) is
            returned instead.

    Returns:
        Markdown/text rendering of the query result, the summary string, or
        an "ERROR ..." string when the query fails.
    """
    df = pd.read_csv(file_path)
    if not query:
        return f"Loaded CSV with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
    try:
        result = df.query(query)
        # to_markdown() requires the optional 'tabulate' package; fall back
        # to plain text instead of erroring when it is not installed.
        try:
            return result.to_markdown()
        except ImportError:
            return result.to_string()
    except Exception as e:
        return f"ERROR in pandas query: {e}"
73
+
74
+ # ---------------------------------------------------------------------
75
+ # 3) Excel-Parser
76
+ # ---------------------------------------------------------------------
77
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
    """Load an Excel file (first sheet) and answer a pandas query (optional).

    Args:
        file_path: Path to the Excel workbook (first sheet is read).
        query: Optional ``DataFrame.query`` expression; when empty, a short
            summary of the frame is returned instead.

    Returns:
        Markdown/text rendering of the query result, the summary string, or
        an "ERROR ..." string when the query fails.
    """
    df = pd.read_excel(file_path)
    if not query:
        return f"Loaded Excel with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
    try:
        result = df.query(query)
        # Same fallback as parse_csv: to_markdown() needs optional 'tabulate'.
        try:
            return result.to_markdown()
        except ImportError:
            return result.to_string()
    except Exception as e:
        return f"ERROR in pandas query: {e}"
89
+
90
+ # ---------------------------------------------------------------------
91
+ # 4) Gemini-Audio-Transkription
92
+ # ---------------------------------------------------------------------
93
@tool
@error_guard
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
    """Use Gemini to transcribe an audio file.

    Args:
        file_path: Path to a local audio file.
        prompt: Instruction sent alongside the audio.

    Returns:
        The model's transcription text, or an "ERROR: ..." string via
        error_guard on failure.
    """
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    mime = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "media", "data": b64, "mime_type": mime},
        ]
    )
    # Bug fix: .invoke() is synchronous and returns a message, not a
    # coroutine — asyncio.run() on it raises "a coroutine was expected".
    resp = gemini_llm.invoke([message])
    return resp.content if hasattr(resp, "content") else str(resp)
108
+
109
+ # ---------------------------------------------------------------------
110
+ # 5) Bild-Beschreibung
111
+ # ---------------------------------------------------------------------
112
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
    """Describe a local image with Gemini vision.

    Args:
        file_path: Path to a local image file.
        prompt: Instruction sent alongside the image.

    Returns:
        The model's description text, or an "ERROR: ..." string via
        error_guard on failure.
    """
    mime = mimetypes.guess_type(file_path)[0] or "image/png"
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            # Bug fix: a raw PIL.Image object is not a valid content part for
            # LangChain messages; send the standard base64 data-URL form.
            {"type": "image_url", "image_url": f"data:{mime};base64,{b64}"},
        ]
    )
    # Bug fix: .invoke() is synchronous — do not wrap it in asyncio.run().
    resp = gemini_llm.invoke([message])
    return resp.content
126
 
127
+ # ---------------------------------------------------------------------
128
+ # 6) OCR-Tool
129
+ # ---------------------------------------------------------------------
130
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from an image via pytesseract.

    Args:
        file_path: Path to a local image file.
        lang: Tesseract language code (default "eng").

    Returns:
        The stripped OCR text, "No text found." for empty output, or an
        "ERROR: ..." string on failure.
    """
    try:
        from PIL import Image
        import pytesseract

        extracted = pytesseract.image_to_string(Image.open(file_path), lang=lang)
        stripped = extracted.strip()
        return stripped if stripped else "No text found."
    except Exception as e:
        return f"ERROR: {e}"
141
+
142
+ # ---------------------------------------------------------------------
143
+ # 7) Tavily-Web-Suche
144
+ # ---------------------------------------------------------------------
145
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return a markdown list of results.

    Args:
        query: Search query string.
        max_results: Maximum number of hits to return (default 5).

    Returns:
        A blank-line-separated list of "title – url" entries, or
        "No results." when the search comes back empty.
    """
    searcher = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY)
    hits = searcher.invoke(query)
    if not hits:
        return "No results."
    entries = [f"{h['title']} – {h['url']}" for h in hits]
    return "\n\n".join(entries)
153
+
154
+ # ---------------------------------------------------------------------
155
+ # 8) Kleiner Rechner
156
+ # ---------------------------------------------------------------------
157
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths (add, subtract, multiply, divide).

    Args:
        operation: One of "add", "subtract", "multiply", "divide".
        a: First operand.
        b: Second operand.

    Returns:
        The numeric result (divide-by-zero yields inf, matching the original
        behavior), or an "ERROR: ..." string for an unknown operation.
    """
    # Lazy dispatch: the original dict literal eagerly computed all four
    # operations on every call; lambdas evaluate only the one requested.
    ops = {
        "add": lambda: a + b,
        "subtract": lambda: a - b,
        "multiply": lambda: a * b,
        "divide": lambda: a / b if b else float("inf"),
    }
    op = ops.get(operation)
    if op is None:
        return f"ERROR: unknown op '{operation}'"
    return op()
168
+
169
+ # ---------------------------------------------------------------------
170
+ # LLM + Semaphore-Throttle (Gemini 2.0 Flash)
171
+ # ---------------------------------------------------------------------
172
# Gemini 2.0 Flash, with every agent tool bound for function calling.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
).bind_tools(
    [
        fetch_gaia_file,
        parse_csv,
        parse_excel,
        gemini_transcribe_audio,
        describe_image,
        ocr_image,
        web_search,
        simple_calculator,
    ]
)

# Throttle: at most 3 concurrent Gemini requests (≈ under 15/min).
LLM_SEMA = asyncio.Semaphore(3)
184
+
185
async def safe_invoke(msgs: List[Any]):
    """Call Gemini with at most 3 concurrent requests (LLM_SEMA throttle).

    Bug fix: the original called the blocking .invoke() inside ``async with``,
    which stalls the event loop and makes the semaphore pointless; use the
    async client so concurrent callers actually overlap and are throttled.
    """
    async with LLM_SEMA:
        return await gemini_llm.ainvoke(msgs)
188
+
189
+ # ---------------------------------------------------------------------
190
+ # System-Prompt
191
+ # ---------------------------------------------------------------------
192
+ system_prompt = SystemMessage(content="""
193
+ You are GAIA-Assist, a precise, tool-using agent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ If a question mentions an attachment:
196
+ 1. Call fetch_gaia_file(task_id)
197
+ 2. Use exactly one specialised parser tool on the returned path.
198
 
199
+ Otherwise decide between web_search or simple_calculator.
200
 
201
+ Format for a tool call:
202
+ Thought: Do I need to use a tool? Yes
203
+ Action: <tool name>
204
+ Action Input: <JSON arguments>
 
 
 
 
 
 
205
 
206
+ Format for final answer:
207
+ Thought: Do I need to use a tool? No
208
+ Final Answer: <your answer>
209
 
210
+ Stop once you output "Final Answer:".
211
+ """)
212
+
213
+ # ---------------------------------------------------------------------
214
+ # LangGraph – Assistant-Node
215
+ # ---------------------------------------------------------------------
216
def assistant(state: MessagesState):
    """LLM node: prepend the system prompt once, call Gemini, flag completion.

    Returns a dict with the model reply appended to "messages" and a
    "should_end" flag consumed by route().
    """
    msgs = list(state["messages"])
    # Bug fix: guard the empty list — msgs[0] raised IndexError when the
    # state started with no messages.
    if not msgs or msgs[0].type != "system":
        msgs = [system_prompt] + msgs
    resp = asyncio.run(safe_invoke(msgs))
    # Bug fix: multimodal replies can carry list-typed content; normalize to
    # str before .lower(), and guard tool_calls with getattr.
    text = resp.content if isinstance(resp.content, str) else str(resp.content)
    finished = text.lstrip().lower().startswith("final answer") or not getattr(resp, "tool_calls", None)
    # NOTE(review): "should_end" is not declared on MessagesState, so LangGraph
    # may drop it from the graph state; route() defends with .get() — confirm
    # the state schema actually carries this key.
    return {"messages": [resp], "should_end": finished}
223
+
224
def route(state) -> str:
    """Conditional edge: "END" when the assistant flagged completion, else "tools".

    Bug fix: uses .get() because MessagesState does not declare "should_end",
    so the key may be absent from the state mapping — direct indexing raised
    KeyError in that case; absent now safely means "keep going".
    """
    return "END" if state.get("should_end", False) else "tools"
226
+
227
+ # ---------------------------------------------------------------------
228
+ # Tools-Liste & Graph
229
+ # ---------------------------------------------------------------------
230
# All tools exposed to the ToolNode (same set bound to gemini_llm above).
tools = [
    fetch_gaia_file, parse_csv, parse_excel,
    gemini_transcribe_audio, describe_image, ocr_image,
    web_search, simple_calculator,
]

# Assemble the graph: assistant decides, tools execute, loop until END.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
# Bug fix: without this edge the graph dead-ends after a tool call and the
# assistant never sees the tool output; loop back as in a ReAct agent.
builder.add_edge("tools", "assistant")

# Compile
agent_executor = builder.compile()