Spaces:
Runtime error
Runtime error
""" | |
agent.py β LangGraph-Agent mit | |
β’ Gemini 2.0 Flash | |
β’ Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR) | |
β’ Fehler-Retry-Logik | |
""" | |
import os, base64, mimetypes, subprocess, json, tempfile | |
from typing import Any | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import tools_condition, ToolNode | |
from langchain_core.tools import tool | |
from langchain_core.messages import SystemMessage, HumanMessage | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
# ----------------------------------------------------------------------
# 1 ── ENV / LLM
# ----------------------------------------------------------------------
# API key read from the environment; None when the variable is unset
# (presumably the client then falls back to its own credential lookup).
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Shared Gemini chat model used by the agent and the multimodal tools.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,              # most repeatable output
    max_output_tokens=2048,
    google_api_key=GOOGLE_API_KEY,
)
# ----------------------------------------------------------------------
# 2 ── ERROR WRAPPER (guarantees an "ERROR: ..." string instead of an exception)
# ----------------------------------------------------------------------
def error_guard(fn):
    """Wrap *fn* so any exception comes back as an ``"ERROR: <message>"`` string.

    A guarded tool never raises into the agent loop; the model receives the
    error text instead and, per the system prompt, can try another approach.
    ``functools.wraps`` preserves the wrapped function's ``__name__``,
    ``__doc__`` and (via ``__wrapped__``) its signature, which LangChain
    reads when converting a function into a tool schema.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberate catch-all: tools must not raise
            return f"ERROR: {e}"

    return wrapper
# ----------------------------------------------------------------------
# 3 ── BASIC TOOLS
# ----------------------------------------------------------------------
def simple_calculator(operation: str, a: float, b: float) -> Any:
    """Basic maths on two numbers: add, subtract, multiply, divide.

    operation may also be given as "+", "-", "*" or "/" (case and surrounding
    whitespace are ignored). Returns the numeric result, or an "ERROR: ..."
    string for an unknown operation or a division by zero.
    """
    import operator

    ops = {
        "add": operator.add, "+": operator.add,
        "subtract": operator.sub, "-": operator.sub,
        "multiply": operator.mul, "*": operator.mul,
        "divide": operator.truediv, "/": operator.truediv,
    }
    op = ops.get(operation.strip().lower())
    if op is None:
        return "ERROR: unknown operation"
    try:
        # Only the requested operation is evaluated.
        return op(a, b)
    except ZeroDivisionError:
        # An infinite result would hide the sign and the 0/0 case; report
        # the problem in the same "ERROR:" convention the other tools use.
        return "ERROR: division by zero"
def _attachment_suffix(content_disposition, content_type) -> str:
    """Pick a file extension (with leading dot) for a download, or "" if unknown.

    Prefers the extension of the filename in a Content-Disposition header and
    falls back to guessing one from the Content-Type's MIME type.
    """
    from email.message import Message

    if content_disposition:
        header = Message()
        header["Content-Disposition"] = content_disposition
        filename = header.get_filename()
        if filename:
            ext = os.path.splitext(filename)[1]
            if ext:
                return ext
    if content_type:
        mime = content_type.split(";", 1)[0].strip().lower()
        return mimetypes.guess_extension(mime) or ""
    return ""


def fetch_gaia_file(task_id: str) -> str:
    """Download attachment for current GAIA task_id; returns local file path.

    The file is written to the system temp directory under a random name.
    Its extension comes from the server's Content-Disposition filename (or,
    failing that, its Content-Type), so MIME-based tools such as
    describe_image and gemini_transcribe_audio can recognise the file.
    Raises requests.HTTPError for an HTTP 4xx/5xx response.
    """
    import pathlib
    import uuid

    import requests

    url = f"https://agents-course-unit4-scoring.hf.space/file/{task_id}"
    r = requests.get(url, timeout=15)
    r.raise_for_status()
    # The URL path ends in the bare task_id, which normally carries no file
    # extension, so derive one from the response headers instead.
    suffix = _attachment_suffix(
        r.headers.get("Content-Disposition"), r.headers.get("Content-Type")
    )
    fp = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4().hex}{suffix}"
    fp.write_bytes(r.content)
    return str(fp)
def _summarize_or_eval(df, query: str) -> str:
    """Return a preview of *df* for a blank *query*, else ``pandas.eval`` of it.

    The table is exposed to the expression under the name ``df``.
    """
    import pandas as pd

    query = (query or "").strip()
    if not query:
        preview = df.head()
        try:
            return preview.to_markdown()
        except ImportError:
            # to_markdown needs the optional "tabulate" package.
            return preview.to_string()
    return str(pd.eval(query, local_dict={"df": df}))


def parse_csv(file_path: str, query: str = "") -> str:
    """Load CSV & answer query using pandas.eval.

    Without a query, returns the first rows. Otherwise evaluates the query
    with pandas.eval, where the table is named df (only expressions
    pandas.eval supports, e.g. arithmetic/comparisons on df.column).
    """
    import pandas as pd

    return _summarize_or_eval(pd.read_csv(file_path), query)


def parse_excel(file_path: str, query: str = "") -> str:
    """Load first sheet of Excel & answer query using pandas.eval.

    Without a query, returns the first rows. Otherwise evaluates the query
    with pandas.eval, where the table is named df (only expressions
    pandas.eval supports, e.g. arithmetic/comparisons on df.column).
    """
    import pandas as pd

    return _summarize_or_eval(pd.read_excel(file_path), query)
# ----------------------------------------------------------------------
# 4 ── GEMINI MULTIMODAL TOOLS
# ----------------------------------------------------------------------
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
    """Ask Gemini to describe a local image file.

    The image is inlined as a base64 ``data:`` URL next to *prompt*. A file
    name that does not map to an ``image/*`` MIME type is rejected with
    "ERROR: not an image." before the file is read.
    """
    mime_type = mimetypes.guess_type(file_path)[0]
    if mime_type is None or not mime_type.startswith("image/"):
        return "ERROR: not an image."

    with open(file_path, "rb") as handle:
        raw = handle.read()
    data_url = "data:{};base64,{}".format(mime_type, base64.b64encode(raw).decode())

    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": data_url},
        ]
    )
    return llm.invoke([message]).content
def gemini_transcribe_audio(file_path: str,
                            prompt: str = "Transcribe the audio.") -> str:
    """Transcribe a local audio file with Gemini.

    The audio bytes are sent inline (base64, as a "media" content part)
    together with *prompt*. A file name that does not map to an ``audio/*``
    MIME type yields "ERROR: not audio." without reading the file.
    """
    guessed = mimetypes.guess_type(file_path)[0]
    if guessed is None or not guessed.startswith("audio/"):
        return "ERROR: not audio."

    with open(file_path, "rb") as stream:
        encoded = base64.b64encode(stream.read()).decode()

    parts = [{"type": "text", "text": prompt}]
    parts.append({"type": "media", "data": encoded, "mime_type": guessed})
    answer = llm.invoke([HumanMessage(content=parts)])
    return answer.content
# ----------------------------------------------------------------------
# 5 ── OFFLINE OCR TOOL (pytesseract)
# ----------------------------------------------------------------------
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from an image file with Tesseract (via pytesseract).

    lang is the Tesseract language code (e.g. "eng", or "eng+deu" for
    several); the matching language data must be installed. Returns the
    recognised text with surrounding whitespace stripped.
    """
    from PIL import Image
    import pytesseract

    # Use a context manager so the underlying file handle is released even
    # if OCR fails, instead of lingering until garbage collection.
    with Image.open(file_path) as img:
        text = pytesseract.image_to_string(img, lang=lang)
    return text.strip()
# ----------------------------------------------------------------------
# 6 ── WEB SEARCH (Tavily)
# ----------------------------------------------------------------------
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web with Tavily and return a markdown list of results.

    Each entry is "- title (url)" followed by the page snippet Tavily
    returned, so the answer text is available without fetching the page.
    Returns an "ERROR: ..." string when the search yields nothing usable.
    """
    search = TavilySearchResults(max_results=max_results)
    hits = search.invoke(query)
    # Guard against a non-list result (e.g. an error message string), which
    # would otherwise crash when indexed like a result dict below.
    if isinstance(hits, str):
        return f"ERROR: {hits}" if hits else "ERROR: no results."

    entries = []
    for hit in hits or []:
        if not isinstance(hit, dict):
            continue
        url = hit.get("url", "")
        # Not every result carries a title; fall back to the URL.
        title = hit.get("title") or url
        snippet = str(hit.get("content") or "").strip()
        entry = f"- {title} ({url})"
        if snippet:
            entry += f"\n  {snippet}"
        entries.append(entry)

    if not entries:
        return "ERROR: no results."
    return "\n\n".join(entries)
# ----------------------------------------------------------------------
# 7 ── SYSTEM PROMPT
# ----------------------------------------------------------------------
# Instructions prepended to every model call by the assistant node: prefer
# tool calls, react to "ERROR" tool output, and fetch attachments by task_id.
system_prompt = SystemMessage(
    content="\n".join([
        "You are a precise GAIA challenge agent.",
        "Always attempt a TOOL call before giving up. "
        "If a tool returns 'ERROR', think of an alternative tool or parameters.",
        "Use fetch_gaia_file(task_id) for any attachment.",
    ])
)
# ----------------------------------------------------------------------
# 8 ── LangGraph nodes
# ----------------------------------------------------------------------
# Every callable the model may request. The same list is bound to the LLM
# (so it sees the tool schemas) and handed to ToolNode in the graph below.
tools = [
    fetch_gaia_file, parse_csv, parse_excel,              # attachments / tables
    gemini_transcribe_audio, ocr_image, describe_image,   # audio / images
    web_search, simple_calculator,                        # search / maths
]

llm_with_tools = llm.bind_tools(tools)
def _message_text(message) -> str:
    """Return a message's textual content ("" if it has none).

    Content may be a plain string or a list of parts (strings, or dicts with
    a "text" key); the text of the parts is concatenated.
    """
    content = getattr(message, "content", "") or ""
    if isinstance(content, str):
        return content
    pieces = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict):
            pieces.append(str(part.get("text", "")))
    return "".join(pieces)


def safe_llm_invoke(msgs, max_attempts: int = 2):
    """Invoke the tool-bound LLM on *msgs*, steering it away from errors.

    If the tool results at the end of *msgs* include one whose text starts
    with "ERROR" (any case, so the tools' own "ERROR:" strings and,
    presumably, ToolNode's "Error: ..." messages both count), a hint
    SystemMessage is added before calling the model. If the model's own
    reply starts with "ERROR", the hint is added and the call retried, for
    at most *max_attempts* calls; the last reply is returned either way.
    *msgs* itself is never modified.
    """
    history = list(msgs)  # work on a copy; msgs may be the graph's state list
    hint = SystemMessage(
        content="Previous tool call returned ERROR. Try another approach."
    )

    # Collect the tool results that follow the most recent non-tool message.
    trailing_tool_msgs = []
    for message in reversed(history):
        if getattr(message, "type", None) != "tool":
            break
        trailing_tool_msgs.append(message)
    if any(_message_text(m).upper().startswith("ERROR") for m in trailing_tool_msgs):
        history.append(hint)

    resp = None
    for _ in range(max(1, max_attempts)):
        resp = llm_with_tools.invoke(history)
        if not _message_text(resp).startswith("ERROR"):
            return resp
        history.append(hint)
    return resp
def assistant(state: MessagesState):
    """LangGraph node: call the LLM on the conversation, system prompt first.

    The shared ``system_prompt`` is prepended unless the history already
    starts with a system message. Returns a partial state update holding the
    single new model message, which MessagesState's reducer adds to the
    history.
    """
    history = state["messages"]
    if history and history[0].type == "system":
        # Pass a copy: safe_llm_invoke may append hint messages to its input,
        # which must not leak into the graph's own state list.
        prompt_msgs = list(history)
    else:
        prompt_msgs = [system_prompt, *history]
    return {"messages": [safe_llm_invoke(prompt_msgs)]}
# ----------------------------------------------------------------------
# 9 ── Graph
# ----------------------------------------------------------------------
builder = StateGraph(MessagesState)

# Nodes: the model step and the executor for the tools bound to it.
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Edges: start at the assistant; tool output always returns to it, and
# tools_condition decides after each model step whether to run tools.
builder.add_edge(START, "assistant")
builder.add_edge("tools", "assistant")
builder.add_conditional_edges("assistant", tools_condition)

# Compiled, runnable agent graph.
agent_executor = builder.compile()