File size: 5,076 Bytes
55bc352
 
690538b
 
 
400e97a
55bc352
 
690538b
400e97a
55bc352
5aa21de
7cc5531
690538b
5aa21de
400e97a
690538b
 
 
 
 
 
7cc5531
55bc352
7cc5531
400e97a
9b6ad5a
 
 
55bc352
d046ba6
7cc5531
55bc352
7cc5531
7d74ca3
7cc5531
400e97a
7d74ca3
400e97a
7d74ca3
 
 
 
7cc5531
55bc352
7cc5531
e8d1b6b
7d74ca3
e8d1b6b
7d74ca3
 
400e97a
55bc352
7cc5531
e8d1b6b
7d74ca3
400e97a
 
 
7d74ca3
400e97a
55bc352
7cc5531
60684f0
7d74ca3
400e97a
 
 
 
 
 
7d74ca3
d046ba6
7d74ca3
400e97a
 
 
 
7cc5531
55bc352
400e97a
7cc5531
 
55bc352
7cc5531
 
400e97a
7cc5531
 
 
400e97a
 
 
 
55bc352
 
 
400e97a
 
7cc5531
55bc352
 
7cc5531
55bc352
7cc5531
400e97a
55bc352
 
 
 
 
7cc5531
55bc352
 
 
400e97a
55bc352
 
 
400e97a
 
 
 
55bc352
 
 
 
 
 
400e97a
130522b
400e97a
7cc5531
55bc352
7cc5531
400e97a
 
 
60684f0
400e97a
55bc352
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# agent.py

import os
import time
import functools
import pandas as pd
from typing import Dict, Any, List
import re

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper

try:
    from langchain_experimental.tools.python.tool import PythonAstREPLTool
except ImportError:
    from langchain.tools.python.tool import PythonAstREPLTool

# ---------------------------------------------------------------------
# LangSmith optional
# ---------------------------------------------------------------------
# Turn on LangSmith tracing whenever an API key is present.
if os.getenv("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    # Only provide a default endpoint: assigning it unconditionally would
    # overwrite an endpoint the user configured (e.g. a regional or
    # self-hosted LangSmith instance).
    os.environ.setdefault("LANGCHAIN_ENDPOINT", "https://api.smith.langchain.com")
    os.environ.setdefault("LANGCHAIN_PROJECT", "gaia-agent")
    print("📱 LangSmith tracing enabled.")

# ---------------------------------------------------------------------
# Error wrapper
# ---------------------------------------------------------------------
def error_guard(fn):
    """Wrap *fn* so that any ``Exception`` it raises is returned as a string.

    The tools below run inside the graph's tool node and their output is fed
    back to the model, so reporting a failure as an ``"ERROR: <message>"``
    result lets the model see the problem instead of aborting the run.
    ``functools.wraps`` keeps the wrapped function's name, docstring and
    signature visible to the ``@tool`` decorator applied on top.
    """

    @functools.wraps(fn)
    def guarded(*args, **kwargs):
        try:
            outcome = fn(*args, **kwargs)
        except Exception as exc:  # deliberate catch-all at the tool boundary
            return f"ERROR: {exc}"
        return outcome

    return guarded

# ---------------------------------------------------------------------
# Custom tools
# ---------------------------------------------------------------------
@tool
@error_guard
def parse_csv(file_path: str, query: str = "") -> str:
    """Load a CSV file and summarise it, or filter it with a pandas query.

    Without ``query`` the result is a one-line summary of the row count and
    column names. With ``query`` (a ``DataFrame.query`` expression such as
    ``"age > 30 and city == 'Berlin'"``) the matching rows are returned as a
    table.

    Args:
        file_path: Path to the CSV file.
        query: Optional ``DataFrame.query`` filter expression.
    """
    df = pd.read_csv(file_path)
    if not query:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    matched = df.query(query)
    try:
        return matched.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown requires the optional ``tabulate`` package;
        # fall back to pandas' built-in plain-text table when it is missing.
        return matched.to_string(index=False)

@tool
@error_guard
def parse_excel(file_path: str, sheet: str | int | None = None, query: str = "") -> str:
    """Load one Excel sheet and summarise it, or filter it with a pandas query.

    Without ``query`` the result is a one-line summary of the row count and
    column names; with ``query`` the matching rows are returned as a table.

    Args:
        file_path: Path to the workbook.
        sheet: Sheet index or sheet name. A digit-only string such as ``"2"``
            is treated as a 0-based index; ``None`` or ``""`` selects the
            first sheet.
        query: Optional ``DataFrame.query`` filter expression.
    """
    # Models often pass numbers as strings, so "2" is turned into index 2.
    # Any other falsy value (None, "", 0) falls back to the first sheet.
    if isinstance(sheet, str) and sheet.isdigit():
        sheet_arg = int(sheet)
    else:
        sheet_arg = sheet or 0
    df = pd.read_excel(file_path, sheet_name=sheet_arg)
    if not query:
        return f"Rows={len(df)}, Cols={list(df.columns)}"
    matched = df.query(query)
    try:
        return matched.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown requires the optional ``tabulate`` package;
        # fall back to pandas' built-in plain-text table when it is missing.
        return matched.to_string(index=False)

@tool
@error_guard
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web via Tavily and list the hits.

    Each hit becomes a ``title — url`` line, followed by an indented snippet
    of the page content when one is returned.

    Args:
        query: Free-text search query.
        max_results: Maximum number of hits to request.
    """
    # NOTE(review): the underlying Tavily wrapper may expect `tavily_api_key`
    # rather than `api_key`, and also reads TAVILY_API_KEY from the
    # environment itself — confirm which one actually takes effect.
    api_key = os.getenv("TAVILY_API_KEY")
    hits = TavilySearchResults(max_results=max_results, api_key=api_key).invoke(query)
    if not hits:
        return "No results."
    # NOTE(review): the Tavily tool may hand back a plain string (e.g. an
    # error message) instead of a list of dicts; pass it through unchanged.
    if isinstance(hits, str):
        return hits
    lines = []
    for hit in hits:
        # "title" may be absent depending on the client version; don't fail.
        title = hit.get("title", "")
        url = hit.get("url", "")
        lines.append(f"{title} — {url}" if title else url)
        snippet = hit.get("content")
        if snippet:
            lines.append(f"    {snippet}")
    return "\n".join(lines)

@tool
@error_guard
def wiki_search(query: str, sentences: int = 3) -> str:
    """Look up *query* on Wikipedia and return the opening of the top match.

    Only the best result is fetched (content capped at 4000 characters). The
    text is then cut to its first ``sentences`` sentences, splitting naively
    on ``". "``.

    Args:
        query: Search terms for Wikipedia.
        sentences: Number of leading sentences to keep (at least one).
    """
    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=4000)
    res = wrapper.run(query)
    if not res:
        return "No article found."
    keep = max(1, sentences)
    parts = res.split(". ")
    excerpt = ". ".join(parts[:keep])
    # split() removed the full stop of the last kept sentence; restore it
    # when the text was actually truncated.
    if len(parts) > keep:
        excerpt += "."
    return excerpt

# Python REPL tool.
# NOTE(security): this lets the model run arbitrary Python through
# PythonAstREPLTool, and the model's context includes text fetched by
# web_search / wiki_search. Treat it as executing untrusted code and sandbox
# it before exposing the agent to untrusted users or inputs.
python_repl = PythonAstREPLTool()

# ---------------------------------------------------------------------
# Gemini LLM
# ---------------------------------------------------------------------
# temperature=0 for (near-)deterministic answers; each reply is capped at
# 2048 output tokens.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    temperature=0,
    max_output_tokens=2048,
)

# Instructions the planner prepends when the conversation has no system
# message yet.
_SYSTEM_TEXT = (
    "You are a helpful assistant with access to tools.\n"
    "Use tools when appropriate using tool calls.\n"
    "If the answer is clear, return it directly without explanation."
)
SYSTEM_PROMPT = SystemMessage(content=_SYSTEM_TEXT)

# Every tool the graph's ToolNode can execute.
TOOLS = [web_search, wiki_search, parse_csv, parse_excel, python_repl]

# ---------------------------------------------------------------------
# LangGraph Nodes
# ---------------------------------------------------------------------
def planner(state: MessagesState):
    """Ask Gemini for the next step given the conversation so far.

    The model is bound to ``TOOLS`` so it can answer with tool calls, which
    the graph routes to the tool node. The system prompt is prepended for
    this call only and is not written back to state. ``MessagesState`` merges
    node output with its ``add_messages`` reducer, so the node returns just
    the new reply. Returning the whole history would make the reducer append
    the freshly prepended system message after the user's question.

    Returns:
        ``{"messages": [reply]}``, where ``reply`` is the model's message and
        may carry ``tool_calls``.
    """
    history = state["messages"]
    if not any(m.type == "system" for m in history):
        history = [SYSTEM_PROMPT, *history]
    # Without bind_tools the model is never told about the tools, never emits
    # tool calls, and the "tools" node is unreachable.
    reply = gemini_llm.bind_tools(TOOLS).invoke(history)
    return {"messages": [reply]}

def should_end(state: MessagesState) -> bool:
    """Report whether the graph can stop after the latest message.

    The run is finished once the newest message requests no tools, i.e. its
    ``tool_calls`` attribute is missing, ``None`` or empty.
    """
    newest = state["messages"][-1]
    pending_calls = getattr(newest, "tool_calls", None)
    if pending_calls:
        return False
    return True

# ---------------------------------------------------------------------
# Build Graph
# ---------------------------------------------------------------------
def _route_after_planner(state: MessagesState) -> str:
    """Pick the planner's successor: the tool node, or the end of the run."""
    if should_end(state):
        return "END"
    return "tools"


# planner -> (tools -> planner)* -> END
graph = StateGraph(MessagesState)
graph.add_node("planner", planner)
graph.add_node("tools", ToolNode(TOOLS))

graph.add_edge(START, "planner")
graph.add_conditional_edges(
    "planner",
    _route_after_planner,
    {"tools": "tools", "END": END},
)
graph.add_edge("tools", "planner")

agent_executor = graph.compile()

# ---------------------------------------------------------------------
# Public class
# ---------------------------------------------------------------------
class GaiaAgent:
    """Callable entry point that answers one question with the compiled graph."""

    def __init__(self):
        print("✅ GaiaAgent initialised (LangGraph)")

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten a chat message's ``content`` into stripped plain text.

        LangChain message content is either a string or a list of parts.
        List items may be plain strings or dicts carrying a ``"text"`` key.
        Dict parts whose ``"type"`` is not ``"text"`` are skipped.
        """
        if isinstance(content, str):
            return content.strip()
        pieces = []
        for part in content or []:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(part.get("text", ""))
        return "".join(pieces).strip()

    def __call__(self, task_id: str, question: str) -> str:
        """Run the agent graph on *question* and return the final answer text.

        ``task_id`` is accepted for interface compatibility but not used here.
        """
        state = {"messages": [HumanMessage(content=question)]}
        final = agent_executor.invoke(state)
        # The last message is the model's final reply. Its content may be a
        # list of parts rather than a str, so normalise before stripping.
        return self._content_to_text(final["messages"][-1].content)