# NOTE(review): removed pasted Hugging Face Spaces status lines
# ("Spaces:" / "Runtime error") that made this file syntactically invalid.
""" | |
agent.py – LangGraph-Agent mit | |
• Gemini 2.0 Flash | |
• Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR) | |
• Fehler-Retry-Logik | |
""" | |
import os, base64, mimetypes, subprocess, json, tempfile | |
import functools | |
from typing import Any | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import tools_condition, ToolNode | |
from langchain_core.tools import tool | |
from langchain_core.messages import SystemMessage, HumanMessage | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
# ---------------------------------------------------------------------- | |
# 1 ── ENV / LLM | |
# ---------------------------------------------------------------------- | |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") | |
llm = ChatGoogleGenerativeAI( | |
model="gemini-2.0-flash", | |
google_api_key=GOOGLE_API_KEY, | |
temperature=0, | |
max_output_tokens=2048, | |
) | |
# ---------------------------------------------------------------------- | |
# 2 ── ERROR-WRAPPER (garantiert "ERROR:"-String statt Exception) | |
# ---------------------------------------------------------------------- | |
def error_guard(fn):
    """Decorator: convert any exception raised by *fn* into an ``"ERROR: ..."``
    string so a failing tool never crashes the agent loop.

    NOTE(review): this decorator is not applied to any tool in this file —
    confirm whether it is wired up elsewhere or should wrap the tools below.
    """
    # functools.wraps copies __name__/__doc__ onto the wrapper (the original
    # comment claimed this but the decorator itself was missing), which matters
    # because LangChain derives tool name/description from those attributes.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberate catch-all at the tool boundary
            return f"ERROR: {e}"
    return wrapper
# ---------------------------------------------------------------------- | |
# 3 ── BASIS-TOOLS | |
# ---------------------------------------------------------------------- | |
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths: add, subtract, multiply, divide.

    Args:
        operation: one of ``"add"``, ``"subtract"``, ``"multiply"``, ``"divide"``.
        a, b: operands.

    Returns:
        The numeric result. Division by zero yields ``float("inf")`` (original
        behaviour, kept); an unknown operation yields an ``"ERROR: ..."`` string.
    """
    # Callables keep evaluation lazy — the original dict computed all four
    # results eagerly on every call; only the requested operation runs now.
    ops = {
        "add": lambda: a + b,
        "subtract": lambda: a - b,
        "multiply": lambda: a * b,
        "divide": lambda: a / b if b else float("inf"),
    }
    op = ops.get(operation)
    return op() if op is not None else "ERROR: unknown operation"
def fetch_gaia_file(task_id: str) -> str:
    """Download attachment for current GAIA task_id; returns local file path.

    Raises:
        requests.HTTPError: if the scoring server answers with an error status.
    """
    import re
    import requests, pathlib, uuid
    url = f"https://agents-course-unit4-scoring.hf.space/file/{task_id}"
    r = requests.get(url, timeout=15)
    r.raise_for_status()
    # BUG FIX: the URL ends in the bare task_id, so Path(url).suffix was
    # always "" and every download lost its extension (breaking the
    # mimetype-based tools downstream).  Prefer the filename advertised in
    # the Content-Disposition header; fall back to the URL suffix.
    suffix = pathlib.Path(url).suffix or ""
    disposition = r.headers.get("content-disposition", "")
    m = re.search(r'filename="?([^";]+)"?', disposition)
    if m:
        suffix = pathlib.Path(m.group(1)).suffix or suffix
    fp = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4().hex}{suffix}"
    fp.write_bytes(r.content)
    return str(fp)
def parse_csv(file_path: str, query: str = "") -> str: | |
"""Load CSV & answer query using pandas.eval.""" | |
import pandas as pd | |
df = pd.read_csv(file_path) | |
if not query: | |
return df.head().to_markdown() | |
return str(pd.eval(query, local_dict={"df": df})) | |
def parse_excel(file_path: str, query: str = "") -> str:
    """Load first sheet of Excel & answer query using pandas.eval.

    With an empty *query*, returns the first rows as a markdown table.
    """
    import pandas as pd
    sheet = pd.read_excel(file_path)  # first sheet by default
    if query:
        evaluated = pd.eval(query, local_dict={"df": sheet})
        return str(evaluated)
    return sheet.head().to_markdown()
# ---------------------------------------------------------------------- | |
# 4 ── GEMINI MULTIMODAL-TOOLS | |
# ---------------------------------------------------------------------- | |
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
    """Send a local image (base64) to Gemini Vision and return description."""
    mime_type, _ = mimetypes.guess_type(file_path)
    # Guard clause: bail out before reading the file if it isn't an image.
    if mime_type is None or not mime_type.startswith("image/"):
        return "ERROR: not an image."
    with open(file_path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode()
    parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": f"data:{mime_type};base64,{encoded}"},
    ]
    return llm.invoke([HumanMessage(content=parts)]).content
def gemini_transcribe_audio(file_path: str,
                            prompt: str = "Transcribe the audio.") -> str:
    """Transcribe audio via Gemini multimodal."""
    mime_type, _ = mimetypes.guess_type(file_path)
    # Guard clause: refuse non-audio files before touching the filesystem.
    if mime_type is None or not mime_type.startswith("audio/"):
        return "ERROR: not audio."
    with open(file_path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode()
    parts = [
        {"type": "text", "text": prompt},
        {"type": "media", "data": encoded, "mime_type": mime_type},
    ]
    return llm.invoke([HumanMessage(content=parts)]).content
# ---------------------------------------------------------------------- | |
# 5 ── OFFLINE OCR-TOOL (pytesseract) | |
# ---------------------------------------------------------------------- | |
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from image using pytesseract."""
    import pytesseract
    from PIL import Image
    picture = Image.open(file_path)
    raw_text = pytesseract.image_to_string(picture, lang=lang)
    return raw_text.strip()
# ---------------------------------------------------------------------- | |
# 6 ── WEB / WIKI SEARCH | |
# ---------------------------------------------------------------------- | |
def web_search(query: str, max_results: int = 5) -> str:
    """Tavily web search – returns markdown list of results."""
    results = TavilySearchResults(max_results=max_results).invoke(query)
    if not results:
        return "ERROR: no results."
    lines = [f"{item['title']} – {item['url']}" for item in results]
    return "\n\n".join(lines)
# ---------------------------------------------------------------------- | |
# 7 ── SYSTEM-PROMPT | |
# ---------------------------------------------------------------------- | |
# Static system prompt prepended by the assistant node.
# BUG FIX: the original opened the string with '""""' (triple quote plus a
# stray '"'), so the prompt's first character was a literal double quote.
system_prompt = SystemMessage(content=(
    """
You are GAIA-Assist, an accurate, tool-using agent.
TOOLS YOU CAN CALL
------------------
• fetch_gaia_file(task_id) – download the current task’s attachment
• parse_csv(file_path, query="")
• parse_excel(file_path, query="")
• gemini_transcribe_audio(file_path[, prompt])
• describe_image(file_path[, prompt])
• ocr_image(file_path[, lang="eng"])
• web_search(query [, max_results=5])
• simple_calculator(operation, a, b)
WORKFLOW RULES
--------------
1. **If** the question mentions an attachment, first call
   fetch_gaia_file(task_id).
   – After it returns a path, choose exactly one specialised parser.
2. **Otherwise**, think whether a web_search or calculator is needed.
3. **NEVER** call the same tool twice in a row with the same input.
ANSWER FORMAT
-------------
*If a tool is needed*
Thought: Do I need to use a tool? **Yes**
Action: <tool name>
Action Input: <JSON-encoded arguments>
*If no tool is needed*
Thought: Do I need to use a tool? **No**
Final Answer: <your concise answer here>
Once you have written **Final Answer:** you are done – do **not** call any further tool.
"""
))
# ---------------------------------------------------------------------- | |
# 8 ── LangGraph Nodes | |
# ---------------------------------------------------------------------- | |
# Tools exposed to the model.  bind_tools derives each tool's calling schema
# from its signature and docstring, so those must stay accurate.
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    gemini_transcribe_audio,
    ocr_image,
    describe_image,
    web_search,
    simple_calculator,
]
# Gemini client wired to emit structured tool calls for the list above.
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(msgs):
    """Invoke the tool-bound LLM, retrying once if the reply starts with "ERROR".

    On an "ERROR" reply a corrective system message is appended to a private
    copy of the history and the call is repeated once; the last response is
    returned either way.

    Args:
        msgs: conversation history (list of LangChain messages).

    Returns:
        The model's response message.
    """
    # BUG FIX: the original appended the corrective SystemMessage to the
    # caller's list in place, leaking retry prompts into graph state.
    history = list(msgs)
    resp = None
    for _attempt in range(2):
        resp = llm_with_tools.invoke(history)
        if not (resp.content or "").startswith("ERROR"):
            return resp
        history.append(
            SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
        )
    return resp
def assistant(state: MessagesState):
    """LangGraph node: run the tool-bound LLM over the state's messages,
    prepending the system prompt when it is not already first."""
    history = state["messages"]
    missing_prompt = not history or history[0].type != "system"
    if missing_prompt:
        history = [system_prompt, *history]
    return {"messages": [safe_llm_invoke(history)]}
# ---------------------------------------------------------------------- | |
# 9 ── Graph | |
# ---------------------------------------------------------------------- | |
# Assemble the agent graph: an assistant ↔ tools loop that repeats until the
# model stops emitting tool calls.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the last message contains tool
# calls, otherwise to END.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")  # feed tool results back to the model
agent_executor = builder.compile()  # public entry point used by the app