ZeroTimo committed on
Commit
400e97a
·
verified ·
1 Parent(s): 909bf64

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +129 -173
agent.py CHANGED
@@ -1,241 +1,197 @@
1
- # agent.py – Gemini 2.0 Flash · LangGraph · Mehrere Tools
2
- # =========================================================
3
- import os, asyncio, base64, mimetypes, tempfile, functools, json
4
- from typing import Dict, Any, List, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- from langgraph.graph import START, StateGraph, MessagesState, END
7
- from langgraph.prebuilt import tools_condition, ToolNode
 
8
 
 
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
11
-
12
- from langchain_google_genai import ChatGoogleGenerativeAI
13
  from langchain_community.tools.tavily_search import TavilySearchResults
 
 
14
 
15
  # ---------------------------------------------------------------------
16
- # Konstanten / API-Keys
17
  # ---------------------------------------------------------------------
18
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
19
- TAVILY_KEY = os.getenv("TAVILY_API_KEY")
 
 
20
 
21
  # ---------------------------------------------------------------------
22
- # Fehler-Wrapper behält Doc-String dank wraps
23
  # ---------------------------------------------------------------------
24
- import functools
25
  def error_guard(fn):
 
26
  @functools.wraps(fn)
27
- def wrapper(*args, **kwargs):
28
  try:
29
- return fn(*args, **kwargs)
30
  except Exception as e:
31
  return f"ERROR: {e}"
32
  return wrapper
33
 
34
 
35
- # ---------------------------------------------------------------------
36
- # 1) fetch_gaia_file – Datei vom GAIA-Server holen
37
- # ---------------------------------------------------------------------
38
- GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"
39
-
40
- @tool
41
- @error_guard
42
- def fetch_gaia_file(task_id: str) -> str:
43
- """Download the attachment for the given GAIA task_id and return local path."""
44
- url = f"{GAIA_FILE_ENDPOINT}/{task_id}"
45
- try:
46
- response = requests.get(url, timeout=30)
47
- response.raise_for_status()
48
- file_name = response.headers.get("x-gaia-filename", f"{task_id}")
49
- tmp_path = tempfile.gettempdir() + "/" + file_name
50
- with open(tmp_path, "wb") as f:
51
- f.write(response.content)
52
- return tmp_path
53
- except Exception as e:
54
- return f"ERROR: could not fetch file – {e}"
55
 
56
  # ---------------------------------------------------------------------
57
- # 2) CSV-Parser
58
  # ---------------------------------------------------------------------
59
- import pandas as pd
60
-
61
  @tool
62
  @error_guard
63
  def parse_csv(file_path: str, query: str = "") -> str:
64
- """Load a CSV file and answer a quick pandas query (optional)."""
65
  df = pd.read_csv(file_path)
66
  if not query:
67
- return f"Loaded CSV with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
68
  try:
69
- result = df.query(query)
70
- return result.to_markdown()
71
  except Exception as e:
72
- return f"ERROR in pandas query: {e}"
 
73
 
74
- # ---------------------------------------------------------------------
75
- # 3) Excel-Parser
76
- # ---------------------------------------------------------------------
77
  @tool
78
  @error_guard
79
- def parse_excel(file_path: str, query: str = "") -> str:
80
- """Load an Excel file (first sheet) and answer a pandas query (optional)."""
81
- df = pd.read_excel(file_path)
 
82
  if not query:
83
- return f"Loaded Excel with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
84
  try:
85
- result = df.query(query)
86
- return result.to_markdown()
87
  except Exception as e:
88
- return f"ERROR in pandas query: {e}"
89
 
90
  # ---------------------------------------------------------------------
91
- # 4) Gemini-Audio-Transkription
92
  # ---------------------------------------------------------------------
93
  @tool
94
  @error_guard
95
- def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
96
- """Use Gemini to transcribe an audio file."""
97
- with open(file_path, "rb") as f:
98
- b64 = base64.b64encode(f.read()).decode()
99
- mime = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
100
- message = HumanMessage(
101
- content=[
102
- {"type": "text", "text": prompt},
103
- {"type": "media", "data": b64, "mime_type": mime},
104
- ]
105
- )
106
- resp = asyncio.run(safe_invoke([message]))
107
- return resp.content if hasattr(resp, "content") else str(resp)
108
-
109
- # ---------------------------------------------------------------------
110
- # 5) Bild-Beschreibung
111
- # ---------------------------------------------------------------------
112
- @tool
113
- @error_guard
114
- def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
115
- """Gemini vision – Bild beschreiben."""
116
- from PIL import Image
117
- img = Image.open(file_path)
118
- message = HumanMessage(
119
- content=[
120
- {"type": "text", "text": prompt},
121
- img, # langchain übernimmt Encoding
122
- ]
123
- )
124
- resp = asyncio.run(safe_invoke([message]))
125
- return resp.content
126
 
127
- # ---------------------------------------------------------------------
128
- # 6) OCR-Tool
129
- # ---------------------------------------------------------------------
130
- @tool
131
- @error_guard
132
- def ocr_image(file_path: str, lang: str = "eng") -> str:
133
- """Extract text from an image via pytesseract."""
134
- try:
135
- import pytesseract
136
- from PIL import Image
137
- text = pytesseract.image_to_string(Image.open(file_path), lang=lang)
138
- return text.strip() or "No text found."
139
- except Exception as e:
140
- return f"ERROR: {e}"
141
 
142
- # ---------------------------------------------------------------------
143
- # 7) Tavily-Web-Suche
144
- # ---------------------------------------------------------------------
145
  @tool
146
  @error_guard
147
- def web_search(query: str, max_results: int = 5) -> str:
148
- """Search the web via Tavily and return a markdown list of results."""
149
- hits = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY).invoke(query)
150
- if not hits:
151
- return "No results."
152
- return "\n\n".join(f"{h['title']} – {h['url']}" for h in hits)
153
 
154
  # ---------------------------------------------------------------------
155
- # 8) Kleiner Rechner
156
  # ---------------------------------------------------------------------
157
- @tool
158
- @error_guard
159
- def simple_calculator(operation: str, a: float, b: float) -> float:
160
- """Basic maths (add, subtract, multiply, divide)."""
161
- ops = {
162
- "add": a + b,
163
- "subtract": a - b,
164
- "multiply": a * b,
165
- "divide": a / b if b else float("inf"),
166
- }
167
- return ops.get(operation, f"ERROR: unknown op '{operation}'")
168
 
169
  # ---------------------------------------------------------------------
170
- # LLM + Semaphore-Throttle (Gemini 2.0 Flash)
171
  # ---------------------------------------------------------------------
172
  gemini_llm = ChatGoogleGenerativeAI(
 
173
  model="gemini-2.0-flash",
174
- google_api_key=GOOGLE_API_KEY,
175
  temperature=0,
176
  max_output_tokens=2048,
177
- ).bind_tools([
178
- fetch_gaia_file, parse_csv, parse_excel,
179
- gemini_transcribe_audio, describe_image, ocr_image,
180
- web_search, simple_calculator,] ,return_named_tools=True)
181
-
182
- LLM_SEMA = asyncio.Semaphore(2) # 3 gleichz. Anfragen ≈ < 15/min
183
-
184
- # safe_invoke neu (ersetzt die alte Funktion)
185
- async def safe_invoke(msgs, tries: int = 4):
186
- """Gemini-Aufruf mit Semaphor + Exponential-Back-off bei 429 / Netzfehlern."""
187
- delay = 4
188
- for t in range(tries):
189
- async with LLM_SEMA:
190
- try:
191
- return await gemini_llm.ainvoke(msgs)
192
- except Exception as e:
193
- # nur bei Rate-Limit oder Netzwerk erneut versuchen
194
- if ("429" in str(e) or "RateLimit" in str(e)) and t < tries - 1:
195
- await asyncio.sleep(delay)
196
- delay *= 2 # 4 s, 8 s, 16 s
197
- continue
198
- raise
199
-
200
- # ---------------------------------------------------------------------
201
- # System-Prompt
202
- # ---------------------------------------------------------------------
203
- system_prompt = SystemMessage(content="""
204
- You are a helpful assistant tasked with answering questions using a set of tools.
205
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
206
- FINAL ANSWER: [YOUR FINAL ANSWER].
207
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
208
- Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
209
- """)
210
 
211
  # ---------------------------------------------------------------------
212
- # LangGraph – Assistant-Node
213
  # ---------------------------------------------------------------------
214
- def assistant(state: MessagesState):
 
215
  msgs = state["messages"]
216
  if msgs[0].type != "system":
217
- msgs = [system_prompt] + msgs
218
- resp = asyncio.run(safe_invoke(msgs))
219
- finished = resp.content.lower().lstrip().startswith("final answer") or not resp.tool_calls
 
 
 
220
  return {"messages": [resp], "should_end": finished}
221
 
222
  def route(state):
223
  return "END" if state["should_end"] else "tools"
224
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  # ---------------------------------------------------------------------
226
- # Tools-Liste & Graph
227
  # ---------------------------------------------------------------------
228
- tools = [
229
- fetch_gaia_file, parse_csv, parse_excel,
230
- gemini_transcribe_audio, describe_image, ocr_image,
231
- web_search, simple_calculator,
232
- ]
233
 
234
- builder = StateGraph(MessagesState)
235
- builder.add_node("assistant", assistant)
236
- builder.add_node("tools", ToolNode(tools))
237
- builder.add_edge(START, "assistant")
238
- builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
239
 
240
- # Compile
241
- agent_executor = builder.compile()
 
 
 
 
 
 
1
+ # agent.py – LangChain · LangGraph · Gemini Flash
2
+ # ================================================
3
+
4
+ """
5
+ Abhängigkeiten (requirements.txt):
6
+ ----------------------------------
7
+ langchain==0.1.*
8
+ langgraph
9
+ langchain-google-genai  (provides langchain_google_genai; pulls in google-generativeai)
10
+ tavily-python
11
+ wikipedia
12
+ pandas
13
+ openpyxl
14
+ tabulate
15
+ """
16
+
17
+ import os, re, time, functools
18
+ from typing import Dict, Any, List
19
 
20
+ import pandas as pd
21
+ from langgraph.graph import StateGraph, START, END, MessagesState
22
+ from langgraph.prebuilt import ToolNode, tools_condition
23
 
24
+ from langchain_google_genai import ChatGoogleGenerativeAI
25
  from langchain_core.messages import SystemMessage, HumanMessage
26
  from langchain_core.tools import tool
 
 
27
  from langchain_community.tools.tavily_search import TavilySearchResults
28
+ from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
29
+ from langchain.tools.python.tool import PythonAstREPLTool
30
 
31
  # ---------------------------------------------------------------------
32
+ # 0) Optionale LangSmith-Tracing (setze ENV: LANGCHAIN_API_KEY)
33
  # ---------------------------------------------------------------------
34
+ if os.getenv("LANGCHAIN_API_KEY"):
35
+ os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
36
+ from langchain_community.utils import configure_langsmith
37
+ configure_langsmith(project_name="gaia-agent")
38
 
39
  # ---------------------------------------------------------------------
40
+ # 1) Helfer: Fehler-Decorator + Backoff-Wrapper
41
  # ---------------------------------------------------------------------
 
def error_guard(fn):
    """Decorator that turns exceptions raised by a tool into an error string.

    The wrapped callable returns whatever ``fn`` returns; if ``fn`` raises,
    the wrapper returns ``"ERROR: <ExceptionType>: <message>"`` instead, so a
    failing tool does not abort the agent run.  The exception type is part of
    the text because many exceptions carry little or no message of their own
    (``KeyError('title')`` would otherwise read ``ERROR: 'title'``).

    ``functools.wraps`` keeps the original name, docstring and signature.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kw):
        try:
            return fn(*args, **kw)
        except Exception as e:  # deliberate catch-all at the tool boundary
            return f"ERROR: {type(e).__name__}: {e}"
    return wrapper
51
 
52
 
53
def with_backoff(fn, tries: int = 4, delay: int = 4):
    """Call ``fn()`` synchronously, retrying with exponential back-off on rate limits.

    Args:
        fn: Zero-argument callable (typically a lambda around an LLM call).
        tries: Maximum number of attempts; must be at least 1.
        delay: Seconds to sleep before the first retry; doubled after each retry.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        ValueError: If ``tries`` is smaller than 1 (previously this silently
            returned ``None`` without ever calling ``fn``).
        Exception: The exception raised by ``fn`` when it is not a rate-limit
            error, or when the last attempt fails as well.
    """
    if tries < 1:
        raise ValueError("tries must be >= 1")
    for attempt in range(tries):
        try:
            return fn()
        except Exception as e:
            message = str(e)
            # Only rate-limit errors are worth waiting for; anything else
            # (and the final attempt) propagates to the caller unchanged.
            retryable = "429" in message or "RateLimit" in message
            if retryable and attempt < tries - 1:
                time.sleep(delay)
                delay *= 2
                continue
            raise
 
 
 
 
 
 
 
 
 
64
 
65
  # ---------------------------------------------------------------------
66
+ # 2) Eigene Tools (CSV / Excel)
67
  # ---------------------------------------------------------------------
 
 
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and (optional) run a pandas query."""
    # NOTE: the docstring above is kept verbatim because the @tool decorator
    # exposes it to the model as the tool description.
    frame = pd.read_csv(file_path)
    if query:
        # A bad query expression is reported as text rather than raised, so
        # the model can correct it and retry.
        try:
            filtered = frame.query(query)
            return filtered.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    # No query given: summarise the table (row count and column names).
    column_names = list(frame.columns)
    return f"Rows={len(frame)}, Cols={column_names}"
79
+
80
 
 
 
 
@tool
@error_guard
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
    """Load an Excel sheet (name or index) and (optional) run a pandas query."""
    # NOTE: the docstring above is kept verbatim because the @tool decorator
    # exposes it to the model as the tool description.
    if isinstance(sheet, str) and sheet.isdigit():
        # A purely numeric string is treated as a sheet index.
        sheet_arg = int(sheet)
    else:
        # None / "" / 0 all fall back to the first sheet (index 0).
        sheet_arg = sheet or 0
    frame = pd.read_excel(file_path, sheet_name=sheet_arg)
    if query:
        # A bad query expression is reported as text rather than raised.
        try:
            filtered = frame.query(query)
            return filtered.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    column_names = list(frame.columns)
    return f"Rows={len(frame)}, Cols={column_names}"
93
 
94
  # ---------------------------------------------------------------------
95
+ # 3) Externe Search-Tools (Tavily, Wikipedia)
96
  # ---------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return markdown list of results."""
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    if isinstance(hits, str):
        # The Tavily tool can hand back a plain string (e.g. an error text);
        # iterating it would walk single characters, so pass it through.
        return hits
    lines = []
    for hit in hits:
        if not isinstance(hit, dict):
            lines.append(str(hit))
            continue
        # Not every hit carries a "title"; fall back to a content snippet so
        # one incomplete hit does not discard the whole result list.
        label = hit.get("title") or str(hit.get("content", ""))[:80] or "(untitled)"
        lines.append(f"{label} – {hit.get('url', '')}")
    return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
 
 
 
@tool
@error_guard
def wiki_search(query: str, sentences: int = 3) -> str:
    """Quick Wikipedia summary."""
    # NOTE: the docstring above is kept verbatim because the @tool decorator
    # exposes it to the model as the tool description.
    wiki = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    text = wiki.run(query)
    if not text:
        return "No article found."
    # Split on ". " as a cheap sentence boundary and keep the first
    # `sentences` pieces, one per line.
    pieces = text.split(". ")
    return "\n".join(pieces[:sentences])
 
115
 
116
  # ---------------------------------------------------------------------
117
+ # 4) Python-REPL Tool (fertig aus LangChain)
118
  # ---------------------------------------------------------------------
119
+ python_repl = PythonAstREPLTool()
 
 
 
 
 
 
 
 
 
 
120
 
121
  # ---------------------------------------------------------------------
122
+ # 5) LLM Gemini Flash, an Tools gebunden
123
  # ---------------------------------------------------------------------
# Gemini Flash chat model with the agent's tools bound for native tool calling.
# Only the tool list is passed to bind_tools(): the former
# `return_named_tools=True` argument is not a documented bind_tools option and
# would be forwarded as an unknown keyword on every model call
# (NOTE(review): confirm against the installed langchain-google-genai version).
gemini_llm = ChatGoogleGenerativeAI(
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
).bind_tools(
    [web_search, wiki_search, parse_csv, parse_excel, python_repl],
)
133
+
134
+ # ---------------------------------------------------------------------
135
+ # 6) System-Prompt (ReAct, keine Prefixe im Final-Output!)
136
+ # ---------------------------------------------------------------------
137
# System prompt for the planner model.
# Tools are bound to the model via bind_tools(), so the model must use the
# native tool-calling interface; asking it to hand-write a JSON envelope
# produces plain text that the ToolNode cannot execute.  The three answer
# format rules use the same bullet/arrow layout.
SYSTEM_PROMPT = SystemMessage(
    content=(
        "You are a helpful assistant with access to Python tools.\n"
        "• Think step by step.\n"
        "• Call a tool when needed – use the provided tool-calling interface.\n"
        "• When you have the answer, reply with the answer **only** "
        "– no prefix, no explanations.\n"
        "Answer format rules:\n"
        " • Single number → no separators / units unless required.\n"
        " • Single string → no articles/abbrev.\n"
        " • List → comma + single space separated, keep required order.\n"
    )
)
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  # ---------------------------------------------------------------------
153
+ # 7) LangGraph – Planner + Tools + Router
154
  # ---------------------------------------------------------------------
155
def planner(state: MessagesState):
    """Planner node: ask the LLM for the next step.

    Prepends ``SYSTEM_PROMPT`` for this call when the history does not start
    with a system message, invokes the model through ``with_backoff`` and
    returns the reply as a state update.  Whether the run continues is decided
    by ``route`` from the reply itself; the update no longer carries a
    ``should_end`` key because ``MessagesState`` declares only ``messages``.
    """
    msgs = state["messages"]
    if not msgs or msgs[0].type != "system":
        msgs = [SYSTEM_PROMPT] + list(msgs)
    resp = with_backoff(lambda: gemini_llm.invoke(msgs))
    return {"messages": [resp]}


def route(state):
    """Return ``"tools"`` if the last message requests tool calls, else ``"END"``.

    Routing on the presence of tool calls (instead of a newline heuristic)
    guarantees the tool node only ever receives messages it can execute.
    """
    last = state["messages"][-1]
    return "tools" if getattr(last, "tool_calls", None) else "END"


# Tool node
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]

# Upper bound on planner/tool round trips per question.
MAX_ITERATIONS = 8

graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_edge(START, "planner")
graph.add_conditional_edges("planner", route, {"tools": "tools", "END": END})
# Feed tool results back to the planner; without this edge the run stops
# right after the first tool call and the model never sees the result.
graph.add_edge("tools", "planner")

# compile() takes no iteration argument; the step budget is a run-time config
# value.  Each round trip costs two graph steps (planner + tools), plus one
# final planner step.  Exceeding the budget raises LangGraph's recursion error.
agent_executor = graph.compile().with_config(
    {"recursion_limit": 2 * MAX_ITERATIONS + 1}
)
181
+
182
  # ---------------------------------------------------------------------
183
+ # 8) Öffentliche Klasse – wird von app.py / logic.py verwendet
184
  # ---------------------------------------------------------------------
185
class GaiaAgent:
    """LangChain/LangGraph agent for GAIA level 1 questions."""

    def __init__(self):
        print(" GaiaAgent initialised (LangGraph)")

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten a message ``content`` value into a stripped string.

        Message content may be a plain string or a list of parts (strings or
        dicts carrying a ``"text"`` entry).  Parts without text are skipped;
        ``None`` yields an empty string.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, (list, tuple)):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces).strip()
        return str(content).strip()

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent on a single GAIA question and return the answer string.

        Args:
            task_id: GAIA task identifier (currently not used by the graph).
            question: The question text sent as the first human message.

        Returns:
            The content of the last message in the final state, as stripped text.
        """
        start_state = {"messages": [HumanMessage(content=question)]}
        final_state = agent_executor.invoke(start_state)
        # The last message holds the answer; its content may be a list of
        # parts, so normalise before stripping.
        return self._content_to_text(final_state["messages"][-1].content)