# Provenance: Hugging Face Space file "agent.py" by ZeroTimo, commit eeeeed7 (verified), 8.93 kB.
# agent.py – Gemini 2.0 Flash · LangGraph · Mehrere Tools
# =========================================================
# Standard library
import asyncio
import base64
import functools
import json
import mimetypes
import os
import tempfile
from typing import Any, Dict, List, Optional

# Third-party
import requests  # was used below but never imported — NameError fix

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
# ---------------------------------------------------------------------
# Konstanten / API-Keys
# ---------------------------------------------------------------------
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
# ---------------------------------------------------------------------
# Fehler-Wrapper – behält Doc-String dank wraps
# ---------------------------------------------------------------------
import functools
def error_guard(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
return f"ERROR: {e}"
return wrapper
# ---------------------------------------------------------------------
# 1) fetch_gaia_file – Datei vom GAIA-Server holen
# ---------------------------------------------------------------------
GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
"""Download the attachment for the given GAIA task_id and return local path."""
url = f"{GAIA_FILE_ENDPOINT}/{task_id}"
try:
response = requests.get(url, timeout=30)
response.raise_for_status()
file_name = response.headers.get("x-gaia-filename", f"{task_id}")
tmp_path = tempfile.gettempdir() + "/" + file_name
with open(tmp_path, "wb") as f:
f.write(response.content)
return tmp_path
except Exception as e:
return f"ERROR: could not fetch file – {e}"
# ---------------------------------------------------------------------
# 2) CSV-Parser
# ---------------------------------------------------------------------
import pandas as pd
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
"""Load a CSV file and answer a quick pandas query (optional)."""
df = pd.read_csv(file_path)
if not query:
return f"Loaded CSV with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
try:
result = df.query(query)
return result.to_markdown()
except Exception as e:
return f"ERROR in pandas query: {e}"
# ---------------------------------------------------------------------
# 3) Excel-Parser
# ---------------------------------------------------------------------
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
"""Load an Excel file (first sheet) and answer a pandas query (optional)."""
df = pd.read_excel(file_path)
if not query:
return f"Loaded Excel with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
try:
result = df.query(query)
return result.to_markdown()
except Exception as e:
return f"ERROR in pandas query: {e}"
# ---------------------------------------------------------------------
# 4) Gemini-Audio-Transkription
# ---------------------------------------------------------------------
@tool
@error_guard
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
"""Use Gemini to transcribe an audio file."""
with open(file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode()
mime = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
message = HumanMessage(
content=[
{"type": "text", "text": prompt},
{"type": "media", "data": b64, "mime_type": mime},
]
)
resp = asyncio.run(gemini_llm.invoke([message]))
return resp.content if hasattr(resp, "content") else str(resp)
# ---------------------------------------------------------------------
# 5) Bild-Beschreibung
# ---------------------------------------------------------------------
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
"""Gemini vision – Bild beschreiben."""
from PIL import Image
img = Image.open(file_path)
message = HumanMessage(
content=[
{"type": "text", "text": prompt},
img, # langchain übernimmt Encoding
]
)
resp = asyncio.run(gemini_llm.invoke([message]))
return resp.content
# ---------------------------------------------------------------------
# 6) OCR-Tool
# ---------------------------------------------------------------------
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
"""Extract text from an image via pytesseract."""
try:
import pytesseract
from PIL import Image
text = pytesseract.image_to_string(Image.open(file_path), lang=lang)
return text.strip() or "No text found."
except Exception as e:
return f"ERROR: {e}"
# ---------------------------------------------------------------------
# 7) Tavily-Web-Suche
# ---------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
"""Search the web via Tavily and return a markdown list of results."""
hits = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY).invoke(query)
if not hits:
return "No results."
return "\n\n".join(f"{h['title']}{h['url']}" for h in hits)
# ---------------------------------------------------------------------
# 8) Kleiner Rechner
# ---------------------------------------------------------------------
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
"""Basic maths (add, subtract, multiply, divide)."""
ops = {
"add": a + b,
"subtract": a - b,
"multiply": a * b,
"divide": a / b if b else float("inf"),
}
return ops.get(operation, f"ERROR: unknown op '{operation}'")
# ---------------------------------------------------------------------
# LLM + Semaphore-Throttle (Gemini 2.0 Flash)
# ---------------------------------------------------------------------
gemini_llm = ChatGoogleGenerativeAI(
model="gemini-2.0-flash",
google_api_key=GOOGLE_API_KEY,
temperature=0,
max_output_tokens=2048,
).bind_tools([
fetch_gaia_file, parse_csv, parse_excel,
gemini_transcribe_audio, describe_image, ocr_image,
web_search, simple_calculator,
])
LLM_SEMA = asyncio.Semaphore(3) # 3 gleichz. Anfragen ≈ < 15/min
async def safe_invoke(msgs: List[Any]):
async with LLM_SEMA:
return gemini_llm.invoke(msgs)
# ---------------------------------------------------------------------
# System-Prompt
# ---------------------------------------------------------------------
system_prompt = SystemMessage(content="""
You are GAIA-Assist, a precise, tool-using agent.
If a question mentions an attachment:
1. Call fetch_gaia_file(task_id)
2. Use exactly one specialised parser tool on the returned path.
Otherwise decide between web_search or simple_calculator.
Format for a tool call:
Thought: Do I need to use a tool? Yes
Action: <tool name>
Action Input: <JSON arguments>
Format for final answer:
Thought: Do I need to use a tool? No
Final Answer: <your answer>
Stop once you output "Final Answer:".
""")
# ---------------------------------------------------------------------
# LangGraph – Assistant-Node
# ---------------------------------------------------------------------
def assistant(state: MessagesState):
msgs = state["messages"]
if msgs[0].type != "system":
msgs = [system_prompt] + msgs
resp = asyncio.run(safe_invoke(msgs))
finished = resp.content.lower().lstrip().startswith("final answer") or not resp.tool_calls
return {"messages": [resp], "should_end": finished}
def route(state):
return "END" if state["should_end"] else "tools"
# ---------------------------------------------------------------------
# Tools-Liste & Graph
# ---------------------------------------------------------------------
tools = [
fetch_gaia_file, parse_csv, parse_excel,
gemini_transcribe_audio, describe_image, ocr_image,
web_search, simple_calculator,
]
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
# Compile
agent_executor = builder.compile()