ZeroTimo's picture
Update agent.py
5aa21de verified
raw
history blame
7.16 kB
"""
agent.py – LangGraph agent with
• Gemini 2.0 Flash
• file tools (CSV, Excel, audio, image description, OCR)
• Tavily web search
• error-retry logic
"""
import os, base64, mimetypes, subprocess, json, tempfile
from typing import Any
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
# ----------------------------------------------------------------------
# 1 ── ENV / LLM
# ----------------------------------------------------------------------
# API key taken from the environment; None is passed on if the variable is unset.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Deterministic (temperature 0) Gemini 2.0 Flash client. The same instance is
# used by the multimodal tools and, after bind_tools, by the agent loop.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=GOOGLE_API_KEY,
)
# ----------------------------------------------------------------------
# 2 ── ERROR WRAPPER (guarantees an "ERROR:" string instead of an exception)
# ----------------------------------------------------------------------
def error_guard(fn):
    """Wrap ``fn`` so any exception is returned as an ``"ERROR: ..."`` string.

    Tools return the error text to the model instead of raising, so the agent
    can react to it (see the system prompt).

    The wrapper copies ``fn``'s metadata via ``functools.wraps``: ``@tool`` is
    stacked on top of this decorator in this file and reads the tool name,
    description (docstring) and argument schema (signature/annotations) from
    the function it receives. Without ``wraps`` every tool would be called
    ``wrapper`` with no docstring and a ``(*args, **kwargs)`` signature.

    Args:
        fn: The callable to guard.

    Returns:
        A callable with the same signature that returns ``fn``'s result, or
        ``"ERROR: <message>"`` if ``fn`` raised an ``Exception``. When the
        exception's message is empty, the exception class name is used.
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # str(e) is empty for e.g. ValueError(); fall back to the class
            # name so the model still gets a hint about what went wrong.
            return f"ERROR: {str(e) or type(e).__name__}"

    return wrapper
# ----------------------------------------------------------------------
# 3 ── BASIC TOOLS
# ----------------------------------------------------------------------
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths: add, subtract, multiply, divide (division by zero returns an ERROR string)."""
    # Tolerate "Add" / " divide " etc. from the model.
    op = operation.strip().lower()
    # The old code returned float("inf") here, which is wrong for 0/0 and for
    # negative numerators; report an error in the file's "ERROR:" convention.
    if op == "divide" and b == 0:
        return "ERROR: division by zero"
    # Lazy dispatch: only the requested operation is evaluated.
    ops = {
        "add": lambda: a + b,
        "subtract": lambda: a - b,
        "multiply": lambda: a * b,
        "divide": lambda: a / b,
    }
    compute = ops.get(op)
    if compute is None:
        # Annotated as float, but like every tool here errors come back as str.
        return "ERROR: unknown operation"
    return compute()
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
    """Download attachment for current GAIA task_id; returns local file path."""
    import pathlib
    import uuid
    from email.message import Message

    import requests

    url = f"https://agents-course-unit4-scoring.hf.space/file/{task_id}"
    r = requests.get(url, timeout=15)
    r.raise_for_status()

    # The URL ends in the bare task_id, so its suffix is not a real file
    # extension. Without one, the mimetypes.guess_type() checks in
    # describe_image / gemini_transcribe_audio reject the downloaded file.
    # Prefer the filename from Content-Disposition (email.message handles
    # quoting and RFC 2231 encoding), then fall back to the Content-Type.
    header = Message()
    header["Content-Disposition"] = r.headers.get("Content-Disposition", "")
    filename = header.get_filename()
    # Only the suffix of the server-provided name is used, never the name itself.
    suffix = pathlib.Path(filename).suffix if filename else ""
    if not suffix:
        content_type = r.headers.get("Content-Type", "").split(";")[0].strip()
        suffix = mimetypes.guess_extension(content_type) or ""

    fp = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4().hex}{suffix}"
    fp.write_bytes(r.content)
    return str(fp)
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load CSV & answer query using pandas.eval."""
    import pandas as pd

    df = pd.read_csv(file_path)
    # A blank or whitespace-only query means "show me the data".
    if not query.strip():
        preview = df.head()
        try:
            return preview.to_markdown()
        except ImportError:
            # DataFrame.to_markdown needs the optional 'tabulate' package;
            # fall back to plain text so the preview works without it.
            return preview.to_string()
    # NOTE(review): `query` is produced by the model. pandas documents that
    # eval can run arbitrary code when given untrusted input.
    return str(pd.eval(query, local_dict={"df": df}))
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
    """Load first sheet of Excel & answer query using pandas.eval."""
    import pandas as pd

    # read_excel reads the first sheet by default (sheet_name=0).
    df = pd.read_excel(file_path)
    # A blank or whitespace-only query means "show me the data".
    if not query.strip():
        preview = df.head()
        try:
            return preview.to_markdown()
        except ImportError:
            # DataFrame.to_markdown needs the optional 'tabulate' package;
            # fall back to plain text so the preview works without it.
            return preview.to_string()
    # NOTE(review): `query` is produced by the model. pandas documents that
    # eval can run arbitrary code when given untrusted input.
    return str(pd.eval(query, local_dict={"df": df}))
# ----------------------------------------------------------------------
# 4 ── GEMINI MULTIMODAL-TOOLS
# ----------------------------------------------------------------------
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
    """Send a local image (base64) to Gemini Vision and return description."""
    # The MIME type is inferred from the file extension only.
    mime_type = mimetypes.guess_type(file_path)[0]
    if mime_type is None or not mime_type.startswith("image/"):
        return "ERROR: not an image."

    with open(file_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode()

    # Inline the image as a base64 data URL next to the text prompt.
    data_url = f"data:{mime_type};base64,{encoded}"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": data_url},
        ]
    )
    return llm.invoke([message]).content
@tool
@error_guard
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
    """Transcribe audio via Gemini multimodal."""
    # Reject anything whose extension does not map to an audio/* MIME type.
    mime_type = mimetypes.guess_type(file_path)[0]
    if mime_type is None or not mime_type.startswith("audio/"):
        return "ERROR: not audio."

    with open(file_path, "rb") as audio_file:
        raw_audio = audio_file.read()
    encoded = base64.b64encode(raw_audio).decode()

    # Text instruction followed by the inline base64 audio payload.
    parts = [
        {"type": "text", "text": prompt},
        {"type": "media", "data": encoded, "mime_type": mime_type},
    ]
    answer = llm.invoke([HumanMessage(content=parts)])
    return answer.content
# ----------------------------------------------------------------------
# 5 ── OFFLINE OCR-TOOL (pytesseract)
# ----------------------------------------------------------------------
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from image using pytesseract."""
    from PIL import Image
    import pytesseract

    # Open via a context manager so the file handle is released once OCR is
    # done; Image.open loads lazily and keeps the file open otherwise.
    with Image.open(file_path) as img:
        text = pytesseract.image_to_string(img, lang=lang)
    return text.strip()
# ----------------------------------------------------------------------
# 6 ── WEB SEARCH (Tavily)
# ----------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Tavily web search – returns a markdown list of results (title, URL, snippet)."""
    search = TavilySearchResults(max_results=max_results)
    hits = search.invoke(query)
    # NOTE(review): on API failure the Tavily tool may return an error string
    # instead of a list; iterating it would mask the real error.
    if isinstance(hits, str):
        return f"ERROR: {hits}"
    if not hits:
        return "ERROR: no results."

    entries = []
    for hit in hits:
        # Not every langchain-community version includes "title" in results.
        url = hit.get("url", "")
        title = hit.get("title") or url
        entry = f"- {title} – {url}"
        # The content snippet is what lets the model actually answer from
        # the search; titles and URLs alone are rarely enough.
        snippet = (hit.get("content") or "").strip()
        if snippet:
            entry += f"\n  {snippet}"
        entries.append(entry)
    return "\n\n".join(entries)
# ----------------------------------------------------------------------
# 7 ── SYSTEM-PROMPT
# ----------------------------------------------------------------------
# Instructions prepended to every LLM call: prefer tools, recover from tool
# errors, and fetch attachments through fetch_gaia_file.
system_prompt = SystemMessage(
    content="".join(
        (
            "You are a precise GAIA challenge agent.\n",
            "Always attempt a TOOL call before giving up. ",
            "If a tool returns 'ERROR', think of an alternative tool or parameters.\n",
            "Use fetch_gaia_file(task_id) for any attachment.",
        )
    )
)
# ----------------------------------------------------------------------
# 8 ── LangGraph Nodes
# ----------------------------------------------------------------------
# Every tool the model may call. The same list is bound to the LLM (so it can
# emit tool calls) and given to the ToolNode in the graph (so they execute).
tools = [
    # attachments and structured files
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    # audio / image understanding
    gemini_transcribe_audio,
    ocr_image,
    describe_image,
    # search and arithmetic
    web_search,
    simple_calculator,
]

# Tool-aware variant of the shared Gemini client used by the agent loop.
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(msgs, max_attempts: int = 2):
    """Invoke the tool-bound LLM, retrying when its reply starts with "ERROR".

    A reply beginning with "ERROR" usually means the model echoed a failed
    tool result instead of trying something else. In that case a nudge
    (system message) is appended and the model is called again.

    Args:
        msgs: Conversation history for ``llm_with_tools.invoke``. It is NOT
            modified; retries work on a private copy.
        max_attempts: Maximum number of LLM calls (at least one is made).
            Defaults to 2, matching the previous fixed behaviour.

    Returns:
        The last response returned by ``llm_with_tools.invoke``.
    """
    # Copy so callers' lists (possibly graph state) are never mutated.
    history = list(msgs)
    resp = None
    for _ in range(max(1, max_attempts)):
        resp = llm_with_tools.invoke(history)
        content = resp.content
        # Message content may be a list of parts instead of a plain string.
        if isinstance(content, list):
            content = "".join(
                part.get("text", "") if isinstance(part, dict) else str(part)
                for part in content
            )
        if not (content or "").startswith("ERROR"):
            return resp
        history.append(
            SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
        )
    return resp
def assistant(state: MessagesState):
    """LangGraph node: run the LLM over the conversation and return its reply.

    The system prompt is prepended for this call only when the history does
    not already start with a system message. Only the new reply is returned,
    under the "messages" key.
    """
    # Work on a copy: the list is handed to safe_llm_invoke, and the graph
    # state's own list must not be mutated in place.
    history = list(state["messages"])
    if not history or history[0].type != "system":
        history.insert(0, system_prompt)
    return {"messages": [safe_llm_invoke(history)]}
# ----------------------------------------------------------------------
# 9 ── Graph
# ----------------------------------------------------------------------
# Agent graph over a message-list state:
#   START -> "assistant" -> (tools_condition) -> "tools" -> "assistant" -> ...
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
# Executes the tool calls emitted by the assistant using the same `tools` list
# that was bound to the LLM.
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# NOTE(review): langgraph's prebuilt tools_condition presumably routes to the
# node named "tools" when the last reply has tool calls and ends the run
# otherwise, so the node name above must stay exactly "tools" — confirm.
builder.add_conditional_edges("assistant", tools_condition)
# Tool results go back to the assistant for the next reasoning step.
builder.add_edge("tools", "assistant")
# Compiled, runnable graph exposed by this module.
agent_executor = builder.compile()