ZeroTimo's picture
Update agent.py
5a388d4 verified
raw
history blame
8.16 kB
"""
agent.py – LangGraph-Agent mit
• Gemini 2.0 Flash
• Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR)
• Fehler-Retry-Logik
"""
import os, base64, mimetypes, subprocess, json, tempfile
import functools
from typing import Any
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
# ----------------------------------------------------------------------
# 1 ── ENV / LLM
# ----------------------------------------------------------------------
# API key is read from the environment; os.getenv yields None when it is unset.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Shared chat model instance: used directly by the multimodal tools
# (describe_image, gemini_transcribe_audio) and, with tools bound, by the agent.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,  # temperature 0 to keep answers as repeatable as possible
    max_output_tokens=2048,
)
# ----------------------------------------------------------------------
# 2 ── ERROR WRAPPER (guarantees an "ERROR:" string instead of an exception)
# ----------------------------------------------------------------------
def error_guard(fn):
    """Wrap *fn* so it returns an ``"ERROR: ..."`` string instead of raising.

    Any ``Exception`` raised by *fn* is caught and converted into a string
    that starts with ``"ERROR: "`` followed by the exception message. If the
    exception carries no message (e.g. a bare ``StopIteration()``), the
    exception class name is used instead, so the result is never an
    uninformative ``"ERROR: "``.

    Args:
        fn: The callable to protect.

    Returns:
        A wrapper that forwards all arguments to *fn* and returns either
        *fn*'s result or the error string.
    """
    @functools.wraps(fn)  # carries over __doc__, __name__, ... from fn
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberate catch-all: failures become text
            detail = str(e) or type(e).__name__
            return f"ERROR: {detail}"
    return wrapper
# ----------------------------------------------------------------------
# 3 ── BASIS-TOOLS
# ----------------------------------------------------------------------
# Four-function calculator tool.
#   operation: one of "add", "subtract", "multiply", "divide"
#   a, b:      the operands
# Returns the numeric result, or an "ERROR: ..." string for an unknown
# operation or a division by zero (so the return value is not always a float,
# despite the annotation).
# The docstring is left terse on purpose: @tool presumably exposes it to the
# model as the tool description - confirm before rewording it.
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths: add, subtract, multiply, divide."""
    # Only the requested operation is evaluated.
    if operation == "add":
        return a + b
    if operation == "subtract":
        return a - b
    if operation == "multiply":
        return a * b
    if operation == "divide":
        if not b:
            # Report the problem instead of returning +inf, which had the
            # wrong sign for negative numerators and was wrong for 0/0.
            return "ERROR: division by zero"
        return a / b
    return "ERROR: unknown operation"
def _suffix_from_headers(headers) -> str:
    """Return a file extension (leading dot included) inferred from HTTP headers.

    Looks at the ``Content-Disposition`` filename first and falls back to the
    ``Content-Type``. Returns ``""`` when neither yields an extension.
    """
    import pathlib
    from email.message import Message

    disposition = headers.get("content-disposition")
    if disposition:
        # Let the stdlib header parser deal with quoting / RFC 2231 encoding.
        parsed = Message()
        parsed["content-disposition"] = disposition
        filename = parsed.get_filename()
        if filename:
            suffix = pathlib.Path(filename).suffix
            if suffix:
                return suffix

    content_type = (headers.get("content-type") or "").split(";", 1)[0].strip()
    if content_type:
        return mimetypes.guess_extension(content_type) or ""
    return ""


# Downloads the attachment of a GAIA task into the system temp directory.
#   task_id: identifier of the task whose attachment should be fetched
# Returns the path of the saved file as a string. HTTP errors are raised by
# raise_for_status() and turned into an "ERROR: ..." string by error_guard.
# The saved file gets an extension derived from the response headers because
# describe_image / gemini_transcribe_audio pick the MIME type from the file
# name; the URL ends in the bare task id, so it carries no usable suffix.
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
    """Download attachment for current GAIA task_id; returns local file path."""
    import requests, pathlib, uuid
    from urllib.parse import quote

    # Quote the id so unexpected characters cannot alter the request path.
    safe_id = quote(task_id, safe="")
    # NOTE(review): the scoring API may expose this endpoint as
    # /files/{task_id} (plural) - confirm against the API docs.
    url = f"https://agents-course-unit4-scoring.hf.space/file/{safe_id}"
    r = requests.get(url, timeout=15)
    r.raise_for_status()
    # Prefer header-derived extension; keep the URL suffix as a last resort.
    suffix = _suffix_from_headers(r.headers) or pathlib.Path(url).suffix
    fp = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4().hex}{suffix}"
    fp.write_bytes(r.content)
    return str(fp)
# Loads a CSV file into a DataFrame and either previews it or evaluates a query.
#   file_path: path of the CSV file
#   query:     expression for pandas.eval in which the frame is available as
#              `df`; empty / whitespace-only means "show the first rows"
# Returns the preview or the stringified query result. Failures surface as an
# "ERROR: ..." string through error_guard.
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load CSV & answer query using pandas.eval."""
    import pandas as pd

    df = pd.read_csv(file_path)
    if not (query or "").strip():
        preview = df.head()
        try:
            return preview.to_markdown()
        except ImportError:
            # to_markdown needs the optional `tabulate` package; fall back to
            # the plain-text rendering rather than failing the preview.
            return preview.to_string()
    # NOTE(review): `query` is model-generated text; pandas.eval is not a
    # security sandbox - confirm this tool only runs in a trusted environment.
    return str(pd.eval(query, local_dict={"df": df}))
# Loads an Excel workbook (pandas' default sheet, i.e. the first one) into a
# DataFrame and either previews it or evaluates a query.
#   file_path: path of the Excel file
#   query:     expression for pandas.eval in which the frame is available as
#              `df`; empty / whitespace-only means "show the first rows"
# Returns the preview or the stringified query result. Failures surface as an
# "ERROR: ..." string through error_guard.
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
    """Load first sheet of Excel & answer query using pandas.eval."""
    import pandas as pd

    df = pd.read_excel(file_path)
    if not (query or "").strip():
        preview = df.head()
        try:
            return preview.to_markdown()
        except ImportError:
            # to_markdown needs the optional `tabulate` package; fall back to
            # the plain-text rendering rather than failing the preview.
            return preview.to_string()
    # NOTE(review): `query` is model-generated text; pandas.eval is not a
    # security sandbox - confirm this tool only runs in a trusted environment.
    return str(pd.eval(query, local_dict={"df": df}))
# ----------------------------------------------------------------------
# 4 ── GEMINI MULTIMODAL-TOOLS
# ----------------------------------------------------------------------
# Asks the shared Gemini model to describe a local image.
#   file_path: path of the image; its MIME type is guessed from the file name
#   prompt:    instruction sent alongside the image
# Returns the model's reply content, or "ERROR: not an image." when the guessed
# MIME type is missing or not image/*.
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
    """Send a local image (base64) to Gemini Vision and return description."""
    guessed_type = mimetypes.guess_type(file_path)[0]
    if not guessed_type or not guessed_type.startswith("image/"):
        return "ERROR: not an image."

    with open(file_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode()

    # The image travels inline as a base64 data URL next to the text prompt.
    data_url = f"data:{guessed_type};base64,{encoded}"
    message = HumanMessage(content=[
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": data_url},
    ])
    return llm.invoke([message]).content
# Asks the shared Gemini model to transcribe a local audio file.
#   file_path: path of the audio; its MIME type is guessed from the file name
#   prompt:    instruction sent alongside the audio
# Returns the model's reply content, or "ERROR: not audio." when the guessed
# MIME type is missing or not audio/*.
@tool
@error_guard
def gemini_transcribe_audio(file_path: str,
                            prompt: str = "Transcribe the audio.") -> str:
    """Transcribe audio via Gemini multimodal."""
    guessed_type = mimetypes.guess_type(file_path)[0]
    if not guessed_type or not guessed_type.startswith("audio/"):
        return "ERROR: not audio."

    with open(file_path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read()).decode()

    # The audio is attached inline as a base64 "media" part with its MIME type.
    text_part = {"type": "text", "text": prompt}
    audio_part = {"type": "media", "data": encoded, "mime_type": guessed_type}
    message = HumanMessage(content=[text_part, audio_part])
    return llm.invoke([message]).content
# ----------------------------------------------------------------------
# 5 ── OFFLINE OCR-TOOL (pytesseract)
# ----------------------------------------------------------------------
# Runs offline OCR on an image file with pytesseract.
#   file_path: path of the image to read
#   lang:      language code passed through to pytesseract (default "eng")
# Returns the recognised text with surrounding whitespace stripped. Failures
# surface as an "ERROR: ..." string through error_guard.
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from image using pytesseract."""
    from PIL import Image
    import pytesseract

    # Context manager closes the underlying file handle once OCR is done.
    with Image.open(file_path) as img:
        return pytesseract.image_to_string(img, lang=lang).strip()
# ----------------------------------------------------------------------
# 6 ── WEB / WIKI SEARCH
# ----------------------------------------------------------------------
# Runs a Tavily web search and formats the hits as a markdown bullet list.
#   query:       search string
#   max_results: upper bound on the number of hits requested (default 5)
# Returns one "- [title](url)" line per hit, or an "ERROR: ..." string when
# there are no results or the search fails.
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Tavily web search – returns markdown list of results."""
    search = TavilySearchResults(max_results=max_results)
    hits = search.invoke(query)
    if not hits:
        return "ERROR: no results."
    if isinstance(hits, str):
        # NOTE(review): the Tavily tool appears to report some failures as a
        # plain string instead of a list of hits - confirm. Pass it on as an
        # error rather than indexing into the characters of the string.
        return hits if hits.startswith("ERROR") else f"ERROR: {hits}"

    lines = []
    for hit in hits:
        url = hit.get("url", "")
        # Fall back to the URL when a hit has no title, so one incomplete
        # hit does not fail the whole search.
        title = hit.get("title") or url
        lines.append(f"- [{title}]({url})")
    return "\n".join(lines)
# ----------------------------------------------------------------------
# 7 ── SYSTEM-PROMPT
# ----------------------------------------------------------------------
# System prompt describing the available tools, the workflow rules and the
# answer format. assistant() prepends it when the conversation does not
# already start with a system message.
# The string opens with a plain triple quote: a fourth quote character here
# would become a stray '"' at the very start of the prompt text.
system_prompt = SystemMessage(content=(
"""
You are GAIA-Assist, an accurate, tool-using agent.
TOOLS YOU CAN CALL
------------------
• fetch_gaia_file(task_id) – download the current task’s attachment
• parse_csv(file_path, query="")
• parse_excel(file_path, query="")
• gemini_transcribe_audio(file_path[, prompt])
• describe_image(file_path[, prompt])
• ocr_image(file_path[, lang="eng"])
• web_search(query [, max_results=5])
• simple_calculator(operation, a, b)
WORKFLOW RULES
--------------
1. **If** the question mentions an attachment, first call
fetch_gaia_file(task_id).
– After it returns a path, choose exactly one specialised parser.
2. **Otherwise**, think whether a web_search or calculator is needed.
3. **NEVER** call the same tool twice in a row with the same input.
ANSWER FORMAT
-------------
*If a tool is needed*
Thought: Do I need to use a tool? **Yes**
Action: <tool name>
Action Input: <JSON-encoded arguments>
*If no tool is needed*
Thought: Do I need to use a tool? **No**
Final Answer: <your concise answer here>
Once you have written **Final Answer:** you are done – do **not** call any further tool.
"""
))
# ----------------------------------------------------------------------
# 8 ── LangGraph Nodes
# ----------------------------------------------------------------------
# Every tool the agent may call. The same list is bound to the LLM here and
# handed to ToolNode in the graph section, so both stay in sync.
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    gemini_transcribe_audio,
    ocr_image,
    describe_image,
    web_search,
    simple_calculator,
]
# Chat model with the tool list bound to it; used by safe_llm_invoke().
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(msgs):
    """Invoke the tool-bound LLM, retrying once if the reply starts with "ERROR".

    The model is called at most twice. When a reply's text content begins
    with ``"ERROR"``, a hint message is appended to a private copy of the
    history before the next attempt.

    Args:
        msgs: Sequence of chat messages to send. It is never modified.

    Returns:
        The first reply that does not start with ``"ERROR"``, otherwise the
        reply from the final attempt.
    """
    # Work on a copy: the caller may pass the graph state's own message list,
    # and retry hints must not leak into it.
    history = list(msgs)
    resp = None
    for _ in range(2):
        resp = llm_with_tools.invoke(history)
        content = resp.content
        # Message content is not guaranteed to be a plain string (it can be a
        # list of parts); only a string can carry the "ERROR" marker.
        if not (isinstance(content, str) and content.startswith("ERROR")):
            return resp
        history.append(
            SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
        )
    return resp
def assistant(state: MessagesState):
    """Graph node: query the LLM with the conversation held in *state*.

    The module-level system prompt is put in front of the history unless the
    first message is already a system message. The single reply is returned
    under the ``"messages"`` key as the node's state update.
    """
    history = state["messages"]
    starts_with_system = bool(history) and history[0].type == "system"
    if not starts_with_system:
        history = [system_prompt] + history
    reply = safe_llm_invoke(history)
    return {"messages": [reply]}
# ----------------------------------------------------------------------
# 9 ── Graph
# ----------------------------------------------------------------------
# Two-node loop over MessagesState: the assistant node queries the LLM and the
# tools node executes the tools from the `tools` list.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
# Execution always starts at the assistant node.
builder.add_edge(START, "assistant")
# tools_condition picks the next step after each assistant reply
# (presumably "tools" when the reply contains tool calls, otherwise the end
# of the graph - confirm against the LangGraph docs).
builder.add_conditional_edges("assistant", tools_condition)
# After the tools node has run, control returns to the assistant.
builder.add_edge("tools", "assistant")
# Compiled graph exported as the module's entry point.
agent_executor = builder.compile()