# Source: Hugging Face Space "agents-course-unit4" — agent.py (commit e8d1b6b).
# (Web-page scrape residue "raw / history blame / 11 kB" removed: it was not
# Python and made the file unparseable.)
import os
import pandas as pd
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import SystemMessage, HumanMessage
import requests
import tempfile
# Load environment variables (Google API key for the Gemini LLM below).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# === Tool definitions ===
# Base URL of the GAIA scoring service; task attachments are downloaded
# from {GAIA_BASE_URL}/files/{task_id} by fetch_gaia_file below.
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
@tool
def fetch_gaia_file(task_id: str) -> str:
    """
    Download the file attached to a GAIA task and return the local file path.

    Args:
        task_id: The GAIA task_id (string in the JSON payload).

    Returns:
        Absolute path to the downloaded temp file, or an "ERROR: ..." string
        on failure (the agent's error-handling protocol keys off this prefix).
    """
    try:
        url = f"{GAIA_BASE_URL}/files/{task_id}"
        response = requests.get(url, timeout=20)
        response.raise_for_status()
        # The server reports the real filename in an "x-filename" header;
        # fall back to content-disposition, then to a generic per-task name.
        # (The previous code defaulted content-disposition to "download",
        # which made the final fallback unreachable.)
        filename = response.headers.get("x-filename")
        if not filename:
            disposition = response.headers.get("content-disposition", "")
            filename = disposition.split("filename=")[-1].strip('"').strip()
        if not filename:
            filename = f"{task_id}.bin"
        # basename() guards against path traversal via a malicious header
        # value such as "../../etc/passwd".
        tmp_path = os.path.join(tempfile.gettempdir(), os.path.basename(filename))
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        return tmp_path
    except Exception as e:
        return f"ERROR: could not download file for task {task_id}: {e}"
@tool
def parse_csv(file_path: str, query: str = "") -> str:
    """
    Load a CSV file from `file_path` and optionally run a simple analysis query.

    Args:
        file_path: absolute path to a CSV file (from fetch_gaia_file)
        query: optional natural-language instruction, e.g.
               "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise string with the answer OR a preview of the dataframe
        if no query given.
    """
    try:
        df = pd.read_csv(file_path)
        # No query: return a short preview so the LLM can decide what to ask.
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"CSV loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
        # Mini query engine: "sum of <column> [where <col> ==/!= <value>]".
        query_lc = query.lower()
        if "sum" in query_lc:
            # Map lower-cased names back to real column labels so matching is
            # case-insensitive but indexing uses the true label. (Previously
            # the lowered name was used to index df, raising KeyError for
            # mixed-case columns like "Sales".)
            columns_by_lower = {c.lower(): c for c in df.columns}
            for col_lc, col in columns_by_lower.items():
                if col_lc in query_lc:
                    s = df[col]
                    where_idx = query_lc.find("where")
                    if where_idx != -1:
                        # Parse the filter from the ORIGINAL query: values such
                        # as 'Drinks' are case-sensitive and must not be lowered.
                        cond_part = query[where_idx + len("where"):].strip()
                        op = "!=" if "!=" in cond_part else "==" if "==" in cond_part else None
                        if op:
                            key, val = [x.strip().strip("'\"") for x in cond_part.split(op, 1)]
                            key = columns_by_lower.get(key.lower(), key)
                            mask = df[key] != val if op == "!=" else df[key] == val
                            s = df.loc[mask, col]
                    return str(round(s.sum(), 2))
        # Anything other than a sum is out of scope for this helper.
        return "Query type not supported by parse_csv."
    except Exception as e:
        return f"ERROR parsing CSV: {e}"
@tool
def parse_excel(file_path: str, query: str = "") -> str:
    """
    Identical to parse_csv, but for XLS/XLSX files.

    Args:
        file_path: absolute path to an Excel file (from fetch_gaia_file)
        query: optional instruction, e.g. "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise answer string, or a preview of the dataframe if no query given.
    """
    try:
        df = pd.read_excel(file_path)
        # No query: return a short preview so the LLM can decide what to ask.
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"Excel loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"
        query_lc = query.lower()
        if "sum" in query_lc:
            # Case-insensitive column matching that still indexes with the
            # real label (the previous lowered key raised KeyError for
            # mixed-case columns).
            columns_by_lower = {c.lower(): c for c in df.columns}
            for col_lc, col in columns_by_lower.items():
                if col_lc in query_lc:
                    s = df[col]
                    where_idx = query_lc.find("where")
                    if where_idx != -1:
                        # Filter parsed from the ORIGINAL query so that values
                        # keep their case (e.g. 'Drinks').
                        cond_part = query[where_idx + len("where"):].strip()
                        op = "!=" if "!=" in cond_part else "==" if "==" in cond_part else None
                        if op:
                            key, val = [x.strip().strip("'\"") for x in cond_part.split(op, 1)]
                            key = columns_by_lower.get(key.lower(), key)
                            mask = df[key] != val if op == "!=" else df[key] == val
                            s = df.loc[mask, col]
                    return str(round(s.sum(), 2))
        return "Query type not supported by parse_excel."
    except Exception as e:
        return f"ERROR parsing Excel: {e}"
@tool
def transcribe_audio(file_path: str, language: str = "en") -> str:
    """
    Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.

    Args:
        file_path: absolute path to an audio file (from fetch_gaia_file)
        language: ISO language code, default "en"

    Returns:
        Full transcription as plain text, or "ERROR ..."
    """
    try:
        from faster_whisper import WhisperModel

        # The tiny model (~75 MB) is enough for short voice memos.
        whisper = WhisperModel("tiny", device="cpu", compute_type="int8")
        segments, _ = whisper.transcribe(file_path, language=language)
        pieces = [seg.text.strip() for seg in segments]
        transcript = " ".join(pieces).strip()
        return transcript if transcript else "ERROR: transcription empty."
    except Exception as e:
        return f"ERROR: audio transcription failed – {e}"
@tool
def multiply(a: int, b: int) -> int:
    """Multiplies two numbers."""
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Adds two numbers."""
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtracts two numbers."""
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divides two numbers."""
    # Guard clause: surface a clear error instead of ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    quotient = a / b
    return quotient
@tool
def modulo(a: int, b: int) -> int:
    """Returns the remainder of dividing two numbers."""
    remainder = a % b
    return remainder
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return the result."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    docs = loader.load()
    # Concatenate full article bodies, separated by blank lines.
    return "\n\n".join(d.page_content for d in docs)
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for academic papers about a query."""
    loader = ArxivLoader(query=query, load_max_docs=3)
    docs = loader.load()
    # Truncate each paper to its first 1000 characters to keep context small.
    snippets = [d.page_content[:1000] for d in docs]
    return "\n\n".join(snippets)
@tool
def web_search(query: str) -> str:
    """Perform a DuckDuckGo web search."""
    searcher = DuckDuckGoSearchAPIWrapper(max_results=5)
    return searcher.run(query)
# === System prompt ===
# NOTE: this replaces a duplicated `system_prompt = SystemMessage(content=(`
# assignment with unbalanced closing parentheses, which was a SyntaxError.
# The prompt's tool list is also corrected: it advertised a nonexistent
# `simple_calculator` tool instead of the maths tools actually defined above.
system_prompt = SystemMessage(
    content=(
        "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
        "\n"
        "GENERAL RULES\n"
        "-------------\n"
        "1. Always try to answer every question.\n"
        "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
        "3. Never invent facts.\n"
        "\n"
        "TOOLS\n"
        "-----\n"
        "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
        "- parse_csv(file_path, query): analyse CSV files.\n"
        "- parse_excel(file_path, query): analyse Excel files.\n"
        "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
        "- wiki_search(query): query English Wikipedia.\n"
        "- arxiv_search(query): query arXiv.\n"
        "- web_search(query): DuckDuckGo web search.\n"
        "- multiply/add/subtract/divide/modulo(a, b): basic maths.\n"
        "\n"
        "WHEN TO USE WHICH TOOL\n"
        "----------------------\n"
        "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
        "fetch_gaia_file with the given task_id. Then:\n"
        "  • CSV → parse_csv\n"
        "  • XLS/XLSX → parse_excel\n"
        "  • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
        "  • Image → (currently unsupported) answer that image processing is unavailable\n"
        "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
        "・If you need a scientific paper → arxiv_search.\n"
        "・If a numeric operation is required → use the maths tools.\n"
        "\n"
        "ERROR HANDLING\n"
        "--------------\n"
        "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
        "an alternative strategy: retry with a different tool or modified parameters. "
        "Do not repeat the same failing call twice.\n"
        "\n"
        "OUTPUT FORMAT\n"
        "-------------\n"
        "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
        "Do not add extra commentary.\n"
    )
)
# === LLM definition ===
# NOTE(review): ChatGoogleGenerativeAI accepts no `system_message` constructor
# argument, so the previous `system_message=system_prompt` kwarg could not
# work; the system prompt must be prepended to the message list at call time
# instead of being attached to the LLM object.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
)
# === Bind tools to the LLM ===
# NOTE: the previous list referenced an undefined `simple_calculator`,
# which raised NameError at import time. The individual maths tools
# defined above are registered instead.
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    transcribe_audio,
    wiki_search,
    arxiv_search,
    web_search,
    multiply,
    add,
    subtract,
    divide,
    modulo,
]
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(messages):
    """
    Invoke the LLM once; if the reply contains an ERROR marker, retry exactly
    one more time — with a corrective system message appended so the model
    knows the previous tool call failed.

    Args:
        messages: list of LangChain messages; NOT mutated.

    Returns:
        The model's response message (possibly still containing an error
        after the final attempt).
    """
    # Work on a copy: appending to the caller's list would leak the
    # corrective SystemMessage into the persistent graph state.
    attempt_messages = list(messages)
    max_attempts = 2
    result = None
    for _ in range(max_attempts):
        result = llm_with_tools.invoke(attempt_messages)
        # content may be a plain string or a list of content parts;
        # normalise via str() so the substring check works either way.
        content = getattr(result, "content", "")
        if "ERROR:" not in str(content):
            return result
        # Error detected: add a corrective instruction and try once more.
        attempt_messages.append(
            SystemMessage(
                content="Previous tool call returned an ERROR. "
                        "Try a different tool or revise the input."
            )
        )
    # Still failing after max_attempts — return the last result as-is.
    return result
# === LangGraph nodes ===
def assistant(state: MessagesState):
    """
    Assistant node with built-in retry on tool errors.

    Prepends the system prompt to the conversation so the model actually
    sees its instructions on every call (the prompt is not — and cannot
    be — attached to the ChatGoogleGenerativeAI object itself).
    """
    msgs = state["messages"]
    # Avoid duplicating the prompt when it is already the first message.
    if not msgs or not isinstance(msgs[0], SystemMessage):
        msgs = [system_prompt] + list(msgs)
    result_msg = safe_llm_invoke(msgs)
    return {"messages": [result_msg]}
# === Build the LangGraph state machine ===
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
# Control flow: START -> assistant; the assistant either finishes or routes
# to the tool node (tools_condition), and every tool result loops back.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
# === Compiled agent executor ===
agent_executor = builder.compile()