# Web file-viewer header captured together with the source (not part of the program):
#   ZeroTimo's picture
#   Update agent.py
#   130522b verified
#   raw
#   history blame
#   7.52 kB
import os
import re
import time
import functools
from typing import Dict, Any, List
import pandas as pd
# LangGraph
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# LangChain Core
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
# Google Gemini
from langchain_google_genai import ChatGoogleGenerativeAI
# Tools
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
# Python REPL Tool
# Prefer the langchain_experimental location of the REPL tool; fall back to the
# older langchain.tools path when that package is not installed.
try:
    from langchain_experimental.tools.python.tool import PythonAstREPLTool
except ImportError:
    from langchain.tools.python.tool import PythonAstREPLTool
# ---------------------------------------------------------------------
# 0) Optional LangSmith tracing (set ENV: LANGCHAIN_API_KEY)
# ---------------------------------------------------------------------
# Tracing is opt-in: it is switched on only when LANGCHAIN_API_KEY is set to a
# non-empty value. The tracing flag and endpoint are always overwritten, while
# a project name that is already present in the environment is kept.
if os.getenv("LANGCHAIN_API_KEY"):
    os.environ.update(
        {
            "LANGCHAIN_TRACING_V2": "true",
            "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        }
    )
    os.environ.setdefault("LANGCHAIN_PROJECT", "gaia-agent")
    print("📡 LangSmith tracing enabled.")
# ---------------------------------------------------------------------
# 1) Helpers: error-guard decorator + backoff wrapper
# ---------------------------------------------------------------------
def error_guard(fn):
    """Decorate ``fn`` so that exceptions are reported as text instead of raised.

    The returned wrapper forwards all positional and keyword arguments to
    ``fn``. If ``fn`` raises any ``Exception``, the wrapper returns the string
    ``"ERROR: <message>"`` instead, so a failing tool does not abort the agent.
    ``functools.wraps`` keeps the wrapped function's name and docstring.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        try:
            outcome = fn(*args, **kwargs)
        except Exception as exc:
            return f"ERROR: {exc}"
        return outcome

    return guarded
def with_backoff(fn, tries: int = 4, delay: int = 4):
    """Call ``fn`` synchronously, retrying with exponential backoff on rate limits.

    Args:
        fn: Zero-argument callable to execute (e.g. a lambda around an LLM call).
        tries: Maximum number of attempts. Values below 1 are treated as 1 so
            that ``fn`` is always called at least once instead of the function
            silently returning ``None``.
        delay: Seconds to sleep before the first retry; doubled after each
            further rate-limited attempt.

    Returns:
        Whatever ``fn`` returns on the first successful attempt.

    Raises:
        Exception: The exception raised by ``fn`` is re-raised unchanged when it
            is not a rate-limit error or when the last attempt has failed.
    """
    attempts = max(1, tries)
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as exc:
            message = str(exc)
            # Only errors that look like rate limiting are worth waiting for.
            rate_limited = "429" in message or "RateLimit" in message
            if not rate_limited or attempt == attempts - 1:
                raise
            time.sleep(delay)
            delay *= 2
# ---------------------------------------------------------------------
# 2) Custom tools (CSV / Excel)
# ---------------------------------------------------------------------
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and (optional) run a pandas query."""
    # NOTE(review): the docstring above is presumably what @tool exposes to the
    # model as the tool description, so it is kept verbatim.
    #
    # Without a query: a one-line summary (row count and column names).
    # With a query: the matching rows rendered as a markdown table, or an
    # "ERROR query: ..." string when the query or the rendering fails.
    # Failures while reading the file are turned into text by error_guard.
    frame = pd.read_csv(file_path)
    if query:
        try:
            matches = frame.query(query)
            return matches.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    return f"Rows={len(frame)}, Cols={list(frame.columns)}"
@tool
@error_guard
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
    """Load an Excel sheet (name or index) and (optional) run a pandas query."""
    # NOTE(review): the docstring above is presumably what @tool exposes to the
    # model as the tool description, so it is kept verbatim.
    #
    # Sheet selection: a string made only of digits is used as a numeric index,
    # any other non-empty value is passed through as given, and a missing or
    # empty value falls back to the first sheet (index 0).
    if isinstance(sheet, str) and sheet.isdigit():
        sheet_arg = int(sheet)
    else:
        sheet_arg = sheet or 0
    frame = pd.read_excel(file_path, sheet_name=sheet_arg)
    if query:
        try:
            matches = frame.query(query)
            return matches.to_markdown(index=False)
        except Exception as exc:
            return f"ERROR query: {exc}"
    # No query: one-line summary (row count and column names).
    return f"Rows={len(frame)}, Cols={list(frame.columns)}"
# ---------------------------------------------------------------------
# 3) External search tools (Tavily, Wikipedia)
# ---------------------------------------------------------------------
@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and return markdown list of results."""
    # NOTE(review): the docstring above is presumably what @tool exposes to the
    # model as the tool description, so it is kept verbatim.
    #
    # Returns one markdown bullet per hit ("- [title](url)"), "No results."
    # when the search yields nothing, or an "ERROR: ..." string (via
    # error_guard) when the search itself raises.
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    if isinstance(hits, str):
        # NOTE(review): the Tavily tool appears to report its own failures as a
        # plain string; pass that through rather than failing on string
        # indexing below. Confirm against the installed langchain_community.
        return hits
    lines = []
    for hit in hits:
        url = hit.get("url", "")
        # Fall back to the URL when a hit carries no title, so one incomplete
        # hit does not turn the whole result list into an error string.
        title = hit.get("title") or url
        # Title and URL were previously concatenated with no separator, which
        # made them impossible to tell apart.
        lines.append(f"- [{title}]({url})")
    return "\n".join(lines)
@tool
@error_guard
def wiki_search(query: str, sentences: int = 3) -> str:
    """Quick Wikipedia summary."""
    # NOTE(review): the docstring above is presumably what @tool exposes to the
    # model as the tool description, so it is kept verbatim.
    #
    # Looks up at most one article (content capped at 4000 characters), splits
    # the text on ". " and returns the first `sentences` pieces, one per line.
    wikipedia = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    text = wikipedia.run(query)
    if not text:
        return "No article found."
    pieces = text.split(". ")
    return "\n".join(pieces[:sentences])
# ---------------------------------------------------------------------
# 4) Python REPL tool (ready-made, from LangChain)
# ---------------------------------------------------------------------
# Shared REPL tool instance; it is handed to both the LLM tool binding and the
# graph's tool node further down in this module.
# NOTE(review): this tool presumably executes model-written Python inside this
# process — confirm the agent only runs in a sandboxed environment.
python_repl = PythonAstREPLTool()
# ---------------------------------------------------------------------
# 5) LLM – Gemini Flash, bound to the tools
# ---------------------------------------------------------------------
# Base chat model: deterministic (temperature 0), output capped at 2048 tokens,
# API key read from GOOGLE_API_KEY.
_gemini_chat_model = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)

# Model with the agent's tools attached; this is the object the planner invokes.
# NOTE(review): `return_named_tools` is forwarded to bind_tools as-is — confirm
# the installed langchain_google_genai version accepts this keyword.
gemini_llm = _gemini_chat_model.bind_tools(
    [web_search, wiki_search, parse_csv, parse_excel, python_repl],
    return_named_tools=True,
)
# ---------------------------------------------------------------------
# 6) System prompt (ReAct, no prefixes in the final output!)
# ---------------------------------------------------------------------
# System message that the planner puts in front of the conversation whenever
# the history does not already start with a system message. It asks for
# step-by-step reasoning, optional tool use, and a bare final answer that
# follows the formatting rules listed at the end of the prompt.
# NOTE(review): the prompt asks for tool calls as a JSON text reply, while the
# model is also bound to the tools via bind_tools — presumably only native tool
# calls are executed by the tool node; confirm the JSON instruction is wanted.
SYSTEM_PROMPT = SystemMessage(
content=(
"You are a helpful assistant with access to Python tools.\n"
"• Think step by step.\n"
"• Call a tool when needed – reply in this JSON format:\n"
" {\"tool\": \"<tool_name>\", \"tool_input\": { ... }}\n"
"• When you have the answer, reply with the answer **only** "
"– no prefix, no explanations.\n"
"Answer format rules:\n"
" • Single number → no separators / units unless required.\n"
" • Single string → no articles/abbrev.\n"
" • List → comma + single space separated, keep required order.\n"
)
)
# ---------------------------------------------------------------------
# 7) LangGraph – Planner + Tools + Router
# ---------------------------------------------------------------------
def planner(state: MessagesState):
    """LLM planner node: ask the model for the next step.

    The conversation is sent to the tool-bound Gemini model, with
    ``SYSTEM_PROMPT`` prepended when the history is empty or does not start
    with a system message. The call goes through ``with_backoff`` so that
    rate-limit errors are retried.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update containing the model's reply under ``"messages"`` and a
        ``"should_end"`` flag that is true when the reply requests no tool
        calls, i.e. the reply is the final answer.
        NOTE(review): ``should_end`` is not a declared key of ``MessagesState``;
        confirm the graph actually carries it through to the router.
    """
    msgs = list(state["messages"])
    # Guard the empty history: indexing msgs[0] would raise IndexError.
    if not msgs or msgs[0].type != "system":
        msgs = [SYSTEM_PROMPT] + msgs
    resp = with_backoff(lambda: gemini_llm.invoke(msgs))
    # A reply without tool calls is final, however long it is. The former
    # "contains no newline" heuristic flagged multi-line answers as unfinished,
    # which sent a message with nothing to execute on to the tool node.
    finished = not getattr(resp, "tool_calls", None)
    return {"messages": [resp], "should_end": finished}
def route(state):
    """Pick the node that follows the planner.

    The decision is taken from the newest message itself: if it carries tool
    calls the graph continues with ``"tools"``, otherwise the run is finished
    and ``"END"`` is returned. The previous implementation read
    ``state["should_end"]``, a key that the ``MessagesState`` schema does not
    declare, so the lookup could fail with ``KeyError``.

    Args:
        state: Graph state mapping; only the ``"messages"`` list is read.

    Returns:
        ``"tools"`` or ``"END"`` — the keys of the conditional-edge mapping.
    """
    messages = state.get("messages") or []
    last_message = messages[-1] if messages else None
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return "END"
# Tool node: runs the tools requested by the planner's latest message.
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]

graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))

graph.add_edge(START, "planner")
graph.add_conditional_edges("planner", route, {"tools": "tools", "END": END})
# Hand tool results back to the planner so the model can read them and either
# call another tool or produce the final answer. Without this edge nothing
# follows the tool node, so the last message of a run would be raw tool output
# rather than the model's answer.
graph.add_edge("tools", "planner")

# compile → LangGraph executor
agent_executor = graph.compile()
# ---------------------------------------------------------------------
# 8) Public class – used by app.py / logic.py
# ---------------------------------------------------------------------
class GaiaAgent:
    """LangChain/LangGraph agent for GAIA Level 1 questions."""

    def __init__(self):
        print("✅ GaiaAgent initialised (LangGraph)")

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content (a string or a list of parts) into text."""
        if isinstance(content, str):
            return content
        if content is None:
            return ""
        if isinstance(content, (list, tuple)):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    # Text parts carry their payload under "text"; parts
                    # without text (e.g. images) contribute nothing.
                    pieces.append(part["text"])
            return "".join(pieces)
        return str(content)

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent on a single GAIA question and return the answer string.

        Args:
            task_id: Identifier of the GAIA task; accepted but not used here.
            question: Question text, sent to the graph as the human message.

        Returns:
            The content of the last message in the final graph state, with
            surrounding whitespace removed.
        """
        start_state = {"messages": [HumanMessage(content=question)]}
        final_state = agent_executor.invoke(start_state)
        # The last message holds the answer. Its content may be a list of
        # parts rather than a plain string, which `.strip()` cannot handle.
        answer = self._content_to_text(final_state["messages"][-1].content)
        return answer.strip()