"""
agent.py – LangGraph-Agent mit
• Gemini 2.0 Flash
• Datei-Tools (CSV, Excel, Audio, Bild-Describe, OCR)
• Fehler-Retry-Logik
"""
import os, base64, mimetypes, subprocess, json, tempfile
import functools
from typing import Any
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
# ----------------------------------------------------------------------
# 1 ── ENV / LLM
# ----------------------------------------------------------------------
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
)
# ----------------------------------------------------------------------
# 2 ── ERROR WRAPPER (guarantees an "ERROR:" string instead of an exception)
# ----------------------------------------------------------------------
def error_guard(fn):
    @functools.wraps(fn)  # ➜ carries over __doc__, __name__, …
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return f"ERROR: {e}"
    return wrapper
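# Illustrative behaviour of error_guard (the function below is hypothetical,
# not part of the agent): a raised exception becomes a plain "ERROR: ..."
# string, which the system prompt and safe_llm_invoke further down rely on.
#
#   @error_guard
#   def _divide_by_zero():
#       return 1 / 0
#
#   _divide_by_zero()  # -> "ERROR: division by zero"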
# ----------------------------------------------------------------------
# 3 ── BASIC TOOLS
# ----------------------------------------------------------------------
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
"""Basic maths: add, subtract, multiply, divide."""
ops = {"add": a + b, "subtract": a - b, "multiply": a * b,
"divide": a / b if b else float("inf")}
return ops.get(operation, "ERROR: unknown operation")
@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
"""Download attachment for current GAIA task_id; returns local file path."""
import requests, pathlib, uuid
url = f"https://agents-course-unit4-scoring.hf.space/file/{task_id}"
r = requests.get(url, timeout=15)
r.raise_for_status()
suffix = pathlib.Path(url).suffix or ""
fp = pathlib.Path(tempfile.gettempdir())/f"{uuid.uuid4().hex}{suffix}"
fp.write_bytes(r.content)
return str(fp)
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
"""Load CSV & answer query using pandas.eval."""
import pandas as pd
df = pd.read_csv(file_path)
if not query:
return df.head().to_markdown()
return str(pd.eval(query, local_dict={"df": df}))
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
"""Load first sheet of Excel & answer query using pandas.eval."""
import pandas as pd
df = pd.read_excel(file_path)
if not query:
return df.head().to_markdown()
return str(pd.eval(query, local_dict={"df": df}))
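# Illustrative query (file path and column names are hypothetical). The query
# string is handed to pandas.eval with `df` bound to the loaded frame, so it
# should stay a simple expression rather than arbitrary Python:
#   parse_csv.invoke({"file_path": "/tmp/sales.csv", "query": "df['price'] * df['qty']"})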
# ----------------------------------------------------------------------
# 4 ── GEMINI MULTIMODAL TOOLS
# ----------------------------------------------------------------------
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe the image.") -> str:
"""Send a local image (base64) to Gemini Vision and return description."""
mime, _ = mimetypes.guess_type(file_path)
if not (mime and mime.startswith("image/")):
return "ERROR: not an image."
with open(file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode()
content = [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": f"data:{mime};base64,{b64}"},
]
resp = llm.invoke([HumanMessage(content=content)])
return resp.content
@tool
@error_guard
def gemini_transcribe_audio(file_path: str,
                            prompt: str = "Transcribe the audio.") -> str:
    """Transcribe audio via Gemini multimodal."""
    mime, _ = mimetypes.guess_type(file_path)
    if not (mime and mime.startswith("audio/")):
        return "ERROR: not audio."
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    content = [
        {"type": "text", "text": prompt},
        {"type": "media", "data": b64, "mime_type": mime},
    ]
    resp = llm.invoke([HumanMessage(content=content)])
    return resp.content
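# Illustrative calls (file paths are hypothetical); both tools expect a local
# file, typically one downloaded beforehand with fetch_gaia_file:
#   describe_image.invoke({"file_path": "/tmp/chart.png"})
#   gemini_transcribe_audio.invoke({"file_path": "/tmp/interview.mp3"})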
# ----------------------------------------------------------------------
# 5 ── OFFLINE OCR-TOOL (pytesseract)
# ----------------------------------------------------------------------
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
"""Extract text from image using pytesseract."""
from PIL import Image
import pytesseract
img = Image.open(file_path)
return pytesseract.image_to_string(img, lang=lang).strip()
# ----------------------------------------------------------------------
# 6 ── WEB / WIKI SEARCH
# ----------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
"""Tavily web search – returns markdown list of results."""
search = TavilySearchResults(max_results=max_results)
hits = search.invoke(query)
if not hits:
return "ERROR: no results."
return "\n\n".join(f"{hit['title']} – {hit['url']}" for hit in hits)
# ----------------------------------------------------------------------
# 7 ── SYSTEM-PROMPT
# ----------------------------------------------------------------------
system_prompt = SystemMessage(content=(
    "You are a precise GAIA challenge agent.\n"
    "Always attempt a TOOL call before giving up. "
    "If a tool returns 'ERROR', think of an alternative tool or parameters.\n"
    "Use fetch_gaia_file(task_id) for any attachment."
))
# ----------------------------------------------------------------------
# 8 ── LangGraph Nodes
# ----------------------------------------------------------------------
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    gemini_transcribe_audio,
    ocr_image,
    describe_image,
    web_search,
    simple_calculator,
]
llm_with_tools = llm.bind_tools(tools)
def safe_llm_invoke(msgs):
    for attempt in range(2):
        resp = llm_with_tools.invoke(msgs)
        content = resp.content or ""
        if not content.startswith("ERROR"):
            return resp
        msgs.append(
            SystemMessage(content="Previous tool call returned ERROR. Try another approach.")
        )
    return resp
def assistant(state: MessagesState):
msgs = state["messages"]
if not msgs or msgs[0].type != "system":
msgs = [system_prompt] + msgs
return {"messages": [safe_llm_invoke(msgs)]}
# ----------------------------------------------------------------------
# 9 ── Graph
# ----------------------------------------------------------------------
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
agent_executor = builder.compile()
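# ----------------------------------------------------------------------
# 10 ── Quick local smoke test (sketch)
# ----------------------------------------------------------------------
# Minimal usage sketch, assuming GOOGLE_API_KEY (and TAVILY_API_KEY for
# web_search) are set in the environment; the question is made up.
if __name__ == "__main__":
    result = agent_executor.invoke(
        {"messages": [HumanMessage(content="What is 17 multiplied by 4?")]}
    )
    print(result["messages"][-1].content)  # final assistant reply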