Spaces:
Runtime error
Runtime error
File size: 9,876 Bytes
7cc5531 7d74ca3 eeeeed7 7cc5531 5aa21de 7cc5531 5aa21de 7d74ca3 7cc5531 60684f0 7cc5531 d046ba6 7cc5531 7d74ca3 7cc5531 7d74ca3 7cc5531 e8d1b6b 7d74ca3 e8d1b6b 7cc5531 e8d1b6b 7d74ca3 e8d1b6b 7cc5531 7d74ca3 7cc5531 e8d1b6b 7d74ca3 e8d1b6b 7cc5531 7d74ca3 7cc5531 60684f0 7d74ca3 7cc5531 7d74ca3 7cc5531 c21f73b 7cc5531 f5078a2 7d74ca3 7cc5531 c21f73b 7d74ca3 7cc5531 d046ba6 7d74ca3 7cc5531 d046ba6 7d74ca3 5aa21de 7cc5531 5aa21de 7cc5531 c21f73b 7cc5531 c21f73b 7cc5531 257cce5 c8c37ed 7cc5531 c8c37ed 7cc5531 60684f0 257cce5 7cc5531 60684f0 7cc5531 257cce5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
# agent.py – Gemini 2.0 Flash · LangGraph · Mehrere Tools
# =========================================================
import os, asyncio, base64, mimetypes, tempfile, functools, json
from typing import Dict, Any, List, Optional
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
# ---------------------------------------------------------------------
# Konstanten / API-Keys
# ---------------------------------------------------------------------
# API credentials are read from the environment once, at import time.
# Each name is None when the corresponding environment variable is unset.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
TAVILY_KEY = os.environ.get("TAVILY_API_KEY")
# ---------------------------------------------------------------------
# Fehler-Wrapper – behält Doc-String dank wraps
# ---------------------------------------------------------------------
import functools
def error_guard(fn):
    """Wrap *fn* so that exceptions come back as text instead of propagating.

    On success the wrapper returns whatever *fn* returned. If *fn* raises an
    ``Exception`` subclass, the wrapper returns ``"ERROR: <message>"``.
    ``functools.wraps`` copies the name and docstring (and sets
    ``__wrapped__``), so decorators stacked on top still see the original.
    """

    @functools.wraps(fn)
    def _guarded(*args, **kwargs):
        try:
            outcome = fn(*args, **kwargs)
        except Exception as exc:  # deliberate catch-all: failures are reported as text
            return f"ERROR: {exc}"
        return outcome

    return _guarded
# ---------------------------------------------------------------------
# 1) fetch_gaia_file – Datei vom GAIA-Server holen
# ---------------------------------------------------------------------
GAIA_FILE_ENDPOINT = "https://agents-course-unit4-scoring.hf.space/file"


@tool
@error_guard
def fetch_gaia_file(task_id: str) -> str:
    """Download the attachment for the given GAIA task_id and return local path."""
    # The docstring above is the tool description shown to the model, so the
    # details live in comments:
    #   task_id -> appended (URL-quoted) to GAIA_FILE_ENDPOINT
    #   returns -> path of the saved file in the system temp directory, or an
    #              "ERROR: could not fetch file – ..." string on any failure.
    # The stdlib client is used because the module never imports `requests`.
    import urllib.parse
    import urllib.request

    url = f"{GAIA_FILE_ENDPOINT}/{urllib.parse.quote(task_id, safe='')}"
    try:
        # urlopen raises HTTPError for non-2xx statuses, like raise_for_status().
        with urllib.request.urlopen(url, timeout=30) as response:
            header_name = response.headers.get("x-gaia-filename")
            payload = response.read()
        # Keep only the last path component so a server-supplied name such as
        # "../../x" cannot escape the temp directory; fall back to the task id.
        file_name = os.path.basename(header_name or task_id) or "gaia_attachment"
        tmp_path = os.path.join(tempfile.gettempdir(), file_name)
        with open(tmp_path, "wb") as f:
            f.write(payload)
        return tmp_path
    except Exception as e:
        return f"ERROR: could not fetch file – {e}"
# ---------------------------------------------------------------------
# 2) CSV-Parser
# ---------------------------------------------------------------------
import pandas as pd
def _frame_report(df, label: str, query: str) -> str:
    """Summarise *df*, or render the rows matching a pandas *query*.

    Shared by parse_csv and parse_excel. *label* is the file kind named in the
    summary line ("CSV" or "Excel"). An empty or whitespace-only *query*
    yields the row/column summary. A failing query yields an
    "ERROR in pandas query: ..." string.
    """
    if not query or not query.strip():
        return f"Loaded {label} with {len(df)} rows and {len(df.columns)} cols.\nColumns: {list(df.columns)}"
    try:
        result = df.query(query)
    except Exception as e:
        return f"ERROR in pandas query: {e}"
    try:
        return result.to_markdown()
    except ImportError:
        # DataFrame.to_markdown needs the optional 'tabulate' package; without
        # it, fall back to plain text rather than blaming the model's query.
        return result.to_string()


@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and answer a quick pandas query (optional)."""
    # Read errors propagate to error_guard, which turns them into "ERROR: ...".
    return _frame_report(pd.read_csv(file_path), "CSV", query)


# ---------------------------------------------------------------------
# 3) Excel parser
# ---------------------------------------------------------------------
@tool
@error_guard
def parse_excel(file_path: str, query: str = "") -> str:
    """Load an Excel file (first sheet) and answer a pandas query (optional)."""
    # pd.read_excel reads the first sheet by default (sheet_name=0).
    return _frame_report(pd.read_excel(file_path), "Excel", query)
# ---------------------------------------------------------------------
# 4) Gemini-Audio-Transkription
# ---------------------------------------------------------------------
@tool
@error_guard
def gemini_transcribe_audio(file_path: str, prompt: str = "Transcribe the audio.") -> str:
    """Use Gemini to transcribe an audio file."""
    # The whole file is sent inline as base64 together with the text prompt.
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    # Fall back to MP3 when the extension does not reveal the MIME type.
    mime = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "media", "data": b64, "mime_type": mime},
        ]
    )
    # safe_invoke is a coroutine; asyncio.run gives it its own event loop.
    # NOTE(review): safe_invoke uses the tool-bound model, so the reply may be
    # a tool call with empty text — consider an unbound model for this call.
    resp = asyncio.run(safe_invoke([message]))
    content = getattr(resp, "content", resp)
    if isinstance(content, list):
        # AIMessage.content may be a list of str/dict parts; keep only text.
        return "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    return content if isinstance(content, str) else str(content)
# ---------------------------------------------------------------------
# 5) Bild-Beschreibung
# ---------------------------------------------------------------------
@tool
@error_guard
def describe_image(file_path: str, prompt: str = "Describe this image.") -> str:
    """Gemini vision – Bild beschreiben."""
    # The docstring is the tool description sent to the model ("Gemini
    # vision – describe an image"), so it is kept verbatim.
    # HumanMessage content items must be strings or dicts; a PIL image object
    # is rejected. The raw bytes are sent as a base64 data URL instead.
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    # NOTE(review): PNG is assumed when the extension gives no MIME type.
    mime = mimetypes.guess_type(file_path)[0] or "image/png"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": f"data:{mime};base64,{b64}"},
        ]
    )
    resp = asyncio.run(safe_invoke([message]))
    content = resp.content
    if isinstance(content, list):
        # AIMessage.content may be a list of str/dict parts; keep only text.
        return "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    return content
# ---------------------------------------------------------------------
# 6) OCR-Tool
# ---------------------------------------------------------------------
@tool
@error_guard
def ocr_image(file_path: str, lang: str = "eng") -> str:
    """Extract text from an image via pytesseract."""
    # The imports are function-local, so a missing pytesseract/Pillow only
    # affects this tool. Any exception (ImportError included) is turned into
    # "ERROR: <message>" by error_guard.
    import pytesseract
    from PIL import Image

    picture = Image.open(file_path)
    raw_text = pytesseract.image_to_string(picture, lang=lang)
    cleaned = raw_text.strip()
    if cleaned:
        return cleaned
    return "No text found."
# ---------------------------------------------------------------------
# 7) Tavily-Web-Suche
# ---------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return a markdown list of results."""
    # NOTE(review): confirm TavilySearchResults honours `api_key`; the Tavily
    # wrapper presumably also reads TAVILY_API_KEY from the environment.
    hits = TavilySearchResults(max_results=max_results, api_key=TAVILY_KEY).invoke(query)
    # The tool can hand back a plain string (e.g. an error message) instead of
    # a list; pass it through rather than iterating over its characters.
    if isinstance(hits, str):
        return hits or "No results."
    if not hits:
        return "No results."
    entries = []
    for hit in hits:
        url = hit.get("url", "")
        title = hit.get("title")  # not every Tavily result carries a title
        entry = f"{title} – {url}" if title else url
        # Include the snippet so the model can read the result, not just the link.
        snippet = (hit.get("content") or "").strip()
        if snippet:
            entry += f"\n{snippet}"
        entries.append(entry)
    return "\n\n".join(entries)
# ---------------------------------------------------------------------
# 8) Kleiner Rechner
# ---------------------------------------------------------------------
@tool
@error_guard
def simple_calculator(operation: str, a: float, b: float) -> float:
    """Basic maths (add, subtract, multiply, divide)."""
    # Returns the numeric result. Invalid input yields an "ERROR: ..." string
    # instead, despite the float annotation. The operation name is matched
    # case-insensitively, ignoring surrounding whitespace.
    import operator

    ops = {
        "add": operator.add,
        "subtract": operator.sub,
        "multiply": operator.mul,
        "divide": operator.truediv,
    }
    key = operation.strip().lower()
    if key not in ops:
        return f"ERROR: unknown op '{operation}'"
    if key == "divide" and b == 0:
        # A fixed +inf would be wrong for negative (and zero) dividends;
        # report the failure as text like the other error paths.
        return "ERROR: division by zero"
    return ops[key](a, b)
# ---------------------------------------------------------------------
# LLM + Semaphore-Throttle (Gemini 2.0 Flash)
# ---------------------------------------------------------------------
import threading

# Gemini chat model with every tool schema bound, so replies may be tool calls.
# `return_named_tools` is not a documented bind_tools option; it was only
# forwarded as an extra invocation kwarg, so it is dropped.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
).bind_tools([
    fetch_gaia_file, parse_csv, parse_excel,
    gemini_transcribe_audio, describe_image, ocr_image,
    web_search, simple_calculator,
])

# Process-wide cap on concurrent Gemini requests (at most 2). safe_invoke is
# driven through a fresh asyncio.run() loop per call, possibly from different
# threads. An asyncio.Semaphore is neither thread-safe nor shareable across
# loops, so a threading semaphore is used instead.
LLM_SEMA = threading.BoundedSemaphore(2)


async def safe_invoke(msgs, tries: int = 4):
    """Call Gemini under the concurrency cap, retrying on rate limits.

    Args:
        msgs: Messages passed unchanged to ``gemini_llm.ainvoke``.
        tries: Total number of attempts; at least one attempt is always made.

    Returns:
        The response of ``gemini_llm.ainvoke``.

    Raises:
        The error from ``ainvoke`` if it is not a rate limit ("429" or
        "RateLimit" in its message) or if the final attempt fails.
    """
    delay = 4
    attempts = max(1, tries)
    for attempt in range(attempts):
        # Poll rather than block: the event loop stays responsive, and a task
        # cancelled while waiting never ends up owning a permit.
        while not LLM_SEMA.acquire(blocking=False):
            await asyncio.sleep(0.05)
        try:
            return await gemini_llm.ainvoke(msgs)
        except Exception as e:
            # Only rate-limit errors are retried; everything else propagates.
            rate_limited = "429" in str(e) or "RateLimit" in str(e)
            if not rate_limited or attempt == attempts - 1:
                raise
        finally:
            LLM_SEMA.release()
        # Back off without holding a permit so other callers can proceed.
        await asyncio.sleep(delay)
        delay *= 2  # 4 s, 8 s, 16 s …
# ---------------------------------------------------------------------
# System-Prompt
# ---------------------------------------------------------------------
# System message the assistant node prepends when the conversation does not
# already start with a system message. It asks the model to end with a
# "FINAL ANSWER: ..." line; the assistant node treats replies starting with
# "final answer" as finished.
# NOTE(review): the prompt says both "Report your thoughts" and "Your answer
# should only start with 'FINAL ANSWER: '" — these conflict; confirm intent.
system_prompt = SystemMessage(content="""
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
""")
# ---------------------------------------------------------------------
# LangGraph – Assistant-Node
# ---------------------------------------------------------------------
def _should_finish(message) -> bool:
    """Return True when *message* should end the agent loop.

    The loop ends when the reply requests no tool calls, or when its text
    starts with "final answer" (case-insensitive, leading whitespace ignored).
    """
    content = message.content
    if isinstance(content, list):
        # AIMessage.content may be a list of str/dict parts; use the text.
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    said_final = str(content).lstrip().lower().startswith("final answer")
    return said_final or not getattr(message, "tool_calls", None)


def assistant(state: MessagesState):
    """Run Gemini on the conversation and append its reply to the state.

    The system prompt is prepended for the model call when the history does
    not already start with a system message. The prompt is not written back
    into the state.
    """
    msgs = state["messages"]
    if not msgs or msgs[0].type != "system":
        msgs = [system_prompt] + msgs
    # safe_invoke is a coroutine; run it on a private event loop.
    # NOTE(review): asyncio.run raises if this thread already runs an event
    # loop (e.g. graph driven via ainvoke) — confirm callers use invoke().
    resp = asyncio.run(safe_invoke(msgs))
    # Only "messages" is returned: a "should_end" key is not part of the
    # MessagesState schema, so route() could never read it back.
    return {"messages": [resp]}


def route(state):
    """Choose the next node: "END" if the last reply is final, else "tools"."""
    return "END" if _should_finish(state["messages"][-1]) else "tools"
# ---------------------------------------------------------------------
# Tools-Liste & Graph
# ---------------------------------------------------------------------
# Tools executed by the ToolNode; keep in sync with the tools bound to gemini_llm.
tools = [
    fetch_gaia_file, parse_csv, parse_excel,
    gemini_transcribe_audio, describe_image, ocr_image,
    web_search, simple_calculator,
]

builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# assistant -> tools while the model requests tool calls, otherwise finish.
builder.add_conditional_edges("assistant", route, {"tools": "tools", "END": END})
# Feed tool results back to the model; without this edge "tools" is a dead
# end and the model never sees the results of its tool calls.
builder.add_edge("tools", "assistant")

# Compile
agent_executor = builder.compile()