# NOTE: page-scrape residue, not part of agent.py -- "Spaces: Runtime error".
# agent.py | |
import os | |
import time | |
import functools | |
import pandas as pd | |
from typing import Dict, Any, List | |
import re | |
from langgraph.graph import StateGraph, START, END, MessagesState | |
from langgraph.prebuilt import ToolNode | |
from langchain_core.messages import SystemMessage, HumanMessage | |
from langchain_core.tools import tool | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper | |
try: | |
from langchain_experimental.tools.python.tool import PythonAstREPLTool | |
except ImportError: | |
from langchain.tools.python.tool import PythonAstREPLTool | |
# --------------------------------------------------------------------- | |
# LangSmith optional | |
# --------------------------------------------------------------------- | |
# Turn on LangSmith tracing only when an API key is present in the environment.
if os.getenv("LANGCHAIN_API_KEY"):
    os.environ.update(
        {
            "LANGCHAIN_TRACING_V2": "true",
            "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        }
    )
    # Respect a project name that was configured before this module loaded.
    if "LANGCHAIN_PROJECT" not in os.environ:
        os.environ["LANGCHAIN_PROJECT"] = "gaia-agent"
    print("📱 LangSmith tracing enabled.")
# --------------------------------------------------------------------- | |
# Error wrapper
# --------------------------------------------------------------------- | |
def error_guard(fn):
    """Decorate *fn* so that any exception is returned as an error string.

    The wrapped callable returns whatever *fn* returns; if *fn* raises an
    ``Exception`` the wrapper returns ``"ERROR: <message>"`` instead of
    propagating it.

    ``functools.wraps`` keeps the wrapped function's name, docstring and
    ``__wrapped__`` reference, so introspection (e.g. ``inspect.signature``)
    still sees the original function rather than ``wrapper(*args, **kw)``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kw):
        try:
            return fn(*args, **kw)
        except Exception as e:  # deliberate catch-all: report, do not crash
            return f"ERROR: {e}"

    return wrapper
# --------------------------------------------------------------------- | |
# Custom tools
# --------------------------------------------------------------------- | |
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and summarise it, or return the rows matching a query.

    Args:
        file_path: Path of the CSV file, passed to ``pandas.read_csv``.
        query: Optional ``DataFrame.query`` expression. An empty,
            whitespace-only or ``None`` value means "no filter".

    Returns:
        ``"Rows=<n>, Cols=[...]"`` when no query is given; otherwise the
        matching rows rendered by ``DataFrame.to_markdown(index=False)``.
    """
    df = pd.read_csv(file_path)
    # Normalise first so a blank query does not reach df.query() and raise.
    expr = (query or "").strip()
    if not expr:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    return df.query(expr).to_markdown(index=False)
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str: | |
sheet_arg = int(sheet) if isinstance(sheet, str) and sheet.isdigit() else sheet or 0 | |
df = pd.read_excel(file_path, sheet_name=sheet_arg) | |
if not query: | |
return f"Rows={len(df)}, Cols={list(df.columns)}" | |
return df.query(query).to_markdown(index=False) | |
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web with Tavily and list each hit as ``<title> – <url>``.

    Args:
        query: Search text handed to the Tavily tool.
        max_results: Upper bound on the number of hits requested.

    Returns:
        One line per hit, ``"No results."`` when the tool returns nothing,
        or the tool's own text when it returns a string instead of hits.
    """
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    if isinstance(hits, str):
        # NOTE(review): the Tavily tool appears to return a plain string
        # (an error repr) on failure; pass it through instead of indexing it.
        return hits
    # .get() because some hit dicts may not carry every key (presumably
    # 'title' is absent in some library versions -- confirm).
    return "\n".join(f"{h.get('title', '')} – {h.get('url', '')}" for h in hits)
def wiki_search(query: str, sentences: int = 3) -> str:
    """Look up *query* on Wikipedia and return the start of the top result.

    Args:
        query: Search text handed to ``WikipediaAPIWrapper.run``.
        sentences: How many ``". "``-separated pieces of the result to keep;
            values below 1 are treated as 1 so the answer is never empty.

    Returns:
        The kept pieces joined with newlines, or ``"No article found."``
        when the wrapper returns an empty result.
    """
    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    res = wrapper.run(query)
    if not res:
        return "No article found."
    # A zero/negative count would slice from the wrong end or yield nothing.
    keep = max(1, sentences)
    return "\n".join(res.split(". ")[:keep])
# Python tool.
# NOTE(review): this tool presumably executes model-generated Python in this
# process -- confirm it only runs in a sandboxed environment.
python_repl = PythonAstREPLTool()

# ---------------------------------------------------------------------
# Gemini LLM
# ---------------------------------------------------------------------
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    max_output_tokens=2048,
    temperature=0,
)

# The prompt is kept as separate lines and joined with "\n" below.
_SYSTEM_PROMPT_LINES = (
    "You are a helpful assistant with access to tools.",
    "Use tools when appropriate using tool calls.",
    "If the answer is clear, return it directly without explanation.",
)
SYSTEM_PROMPT = SystemMessage(content="\n".join(_SYSTEM_PROMPT_LINES))

# Everything the agent may call; handed to the graph's tool node.
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]
# --------------------------------------------------------------------- | |
# LangGraph Nodes | |
# --------------------------------------------------------------------- | |
@functools.lru_cache(maxsize=1)
def _tool_bound_llm():
    """Return the Gemini model with TOOLS bound; built once on first use."""
    return gemini_llm.bind_tools(TOOLS)


def planner(state: MessagesState):
    """Ask the model for the next step given the conversation so far.

    The system prompt is prepended to the model input when the history does
    not already contain a system message. The model is invoked with the
    tools bound, otherwise it has no way to emit tool calls and the "tools"
    node would never be reached.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``{"messages": [reply]}`` -- only the new reply. The state's message
        reducer merges it into the history, so the existing messages (and
        the system prompt) are not sent back into the state.
    """
    history = list(state["messages"])
    if not any(m.type == "system" for m in history):
        # Prepend for this call only, so the prompt always comes first.
        history.insert(0, SYSTEM_PROMPT)
    resp = _tool_bound_llm().invoke(history)
    return {"messages": [resp]}
def should_end(state: MessagesState) -> bool:
    """Tell whether the graph should stop after the planner step.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``True`` when the last message has no (or empty) ``tool_calls``, or
        when there are no messages at all; ``False`` when the last message
        requests at least one tool call.
    """
    messages = state["messages"]
    if not messages:
        # Nothing to act on: stop instead of raising IndexError.
        return True
    return not getattr(messages[-1], "tool_calls", None)
# --------------------------------------------------------------------- | |
# Build Graph | |
# --------------------------------------------------------------------- | |
def _route_after_planner(state: MessagesState) -> str:
    """Pick the branch key after the planner: "END" to stop, else "tools"."""
    if should_end(state):
        return "END"
    return "tools"


# planner -> (tools -> planner)* -> END
graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_edge(START, "planner")
graph.add_edge("tools", "planner")
graph.add_conditional_edges(
    "planner",
    _route_after_planner,
    {"END": END, "tools": "tools"},
)
agent_executor = graph.compile()
# --------------------------------------------------------------------- | |
# Public class
# --------------------------------------------------------------------- | |
class GaiaAgent:
    """Callable wrapper that runs one question through the compiled graph."""

    def __init__(self):
        print("✅ GaiaAgent initialised (LangGraph)")

    def __call__(self, task_id: str, question: str) -> str:
        """Answer *question* and return the final message text, stripped.

        Args:
            task_id: Identifier of the task; accepted but not used here.
            question: The user question, sent as a single human message.

        Returns:
            The text of the last message in the final graph state.
        """
        state = {"messages": [HumanMessage(content=question)]}
        final = agent_executor.invoke(state)
        return self._as_text(final["messages"][-1].content).strip()

    @staticmethod
    def _as_text(content) -> str:
        """Flatten message content (a string or a list of parts) to text.

        Strings are returned unchanged. For a list, string items and the
        ``"text"`` value of dict items are concatenated; other items are
        skipped. ``None`` yields an empty string.
        """
        if isinstance(content, str):
            return content
        pieces = []
        for part in content or []:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)