File size: 11,004 Bytes
beb1eb8
e8d1b6b
257cce5
 
3ed16ee
257cce5
 
 
 
e8d1b6b
 
d046ba6
257cce5
60684f0
d046ba6
257cce5
e8d1b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60684f0
d046ba6
257cce5
d046ba6
 
60684f0
d046ba6
257cce5
d046ba6
 
60684f0
d046ba6
257cce5
d046ba6
 
60684f0
d046ba6
257cce5
d046ba6
 
 
 
f5078a2
d046ba6
257cce5
d046ba6
 
f5078a2
d046ba6
257cce5
d046ba6
257cce5
d046ba6
 
 
257cce5
d046ba6
257cce5
d046ba6
 
 
257cce5
 
 
 
 
 
 
e8d1b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257cce5
 
 
 
 
 
 
 
 
60684f0
f5078a2
257cce5
e8d1b6b
 
 
 
 
 
 
 
 
 
257cce5
60684f0
e8d1b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257cce5
 
e8d1b6b
 
 
 
 
60684f0
257cce5
 
 
 
 
 
 
60684f0
257cce5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import os
import pandas as pd
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import SystemMessage, HumanMessage
import requests
import tempfile

# Load environment variables (Google API key for the Gemini LLM below)
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# === Tool definitions ===

# Base URL of the GAIA scoring service; task file attachments are fetched from here.
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"

@tool
def fetch_gaia_file(task_id: str) -> str:
    """
    Download the file attached to a GAIA task and return the local file-path.

    Args:
        task_id: The GAIA task_id (string in the JSON payload).

    Returns:
        Absolute path to the downloaded temp-file, or a string starting
        with "ERROR:" on failure.
    """
    try:
        url = f"{GAIA_BASE_URL}/files/{task_id}"
        response = requests.get(url, timeout=20)
        response.raise_for_status()

        # The server sends the real filename in a header. Fall back to the
        # content-disposition header, then to a task-id based name.
        # (The old code defaulted content-disposition to "download", which
        # made the empty-filename fallback dead code.)
        filename = response.headers.get("x-filename")
        if not filename:
            disposition = response.headers.get("content-disposition", "")
            if "filename=" in disposition:
                filename = disposition.split("filename=")[-1].strip().strip('"')
        if not filename:
            filename = f"{task_id}.bin"

        # basename() guards against path traversal via a malicious header
        # (e.g. filename="../../etc/cron.d/x").
        tmp_path = os.path.join(tempfile.gettempdir(), os.path.basename(filename))
        with open(tmp_path, "wb") as f:
            f.write(response.content)

        return tmp_path
    except Exception as e:
        return f"ERROR: could not download file for task {task_id}: {e}"

@tool
def parse_csv(file_path: str, query: str = "") -> str:
    """
    Load a CSV file from `file_path` and optionally run a simple analysis query.

    Args:
        file_path: absolute path to a CSV file (from fetch_gaia_file)
        query: optional natural-language instruction, e.g.
               "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise string with the answer OR a preview of the dataframe
        if no query given.
    """
    try:
        df = pd.read_csv(file_path)

        # Auto-preview when no query is given.
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"CSV loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"

        # Mini query engine (very simple; enough for sums with one filter).
        query_lc = query.lower()
        if "sum" in query_lc:
            # Find which column should be summed (case-insensitive match).
            for col in df.columns:
                if col.lower() in query_lc:
                    s = df[col]
                    if "where" in query_lc:
                        # Slice the condition from the ORIGINAL query so the
                        # compared value keeps its case ('Drinks', not 'drinks' —
                        # the old lowercased parse never matched capitalized data
                        # and looked up a lowercased, possibly missing, column).
                        cond_part = query[query_lc.index("where") + len("where"):].strip()
                        for op in ("!=", "=="):
                            if op in cond_part:
                                key, val = [x.strip().strip("'\"") for x in cond_part.split(op, 1)]
                                # Resolve the filter column case-insensitively.
                                key = next((c for c in df.columns if c.lower() == key.lower()), key)
                                mask = (df[key] != val) if op == "!=" else (df[key] == val)
                                s = df.loc[mask, col]
                                break
                    return str(round(s.sum(), 2))
        # Fallback for anything the mini engine does not understand.
        return "Query type not supported by parse_csv."
    except Exception as e:
        return f"ERROR parsing CSV: {e}"

@tool
def parse_excel(file_path: str, query: str = "") -> str:
    """
    Identical to parse_csv, but for XLS/XLSX files.

    Args:
        file_path: absolute path to an Excel file (from fetch_gaia_file)
        query: optional natural-language instruction, e.g.
               "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise string with the answer OR a preview of the dataframe
        if no query given.
    """
    try:
        df = pd.read_excel(file_path)

        # Auto-preview when no query is given.
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return f"Excel loaded. First rows:\n\n{preview}\n\nColumns: {', '.join(df.columns)}"

        # Mini query engine (kept in sync with parse_csv).
        query_lc = query.lower()
        if "sum" in query_lc:
            for col in df.columns:
                if col.lower() in query_lc:
                    s = df[col]
                    if "where" in query_lc:
                        # Slice the condition from the ORIGINAL query so the
                        # compared value keeps its case (the old lowercased
                        # parse never matched capitalized data and looked up a
                        # lowercased, possibly missing, column).
                        cond_part = query[query_lc.index("where") + len("where"):].strip()
                        for op in ("!=", "=="):
                            if op in cond_part:
                                key, val = [x.strip().strip("'\"") for x in cond_part.split(op, 1)]
                                # Resolve the filter column case-insensitively.
                                key = next((c for c in df.columns if c.lower() == key.lower()), key)
                                mask = (df[key] != val) if op == "!=" else (df[key] == val)
                                s = df.loc[mask, col]
                                break
                    return str(round(s.sum(), 2))
        return "Query type not supported by parse_excel."
    except Exception as e:
        return f"ERROR parsing Excel: {e}"

@tool
def transcribe_audio(file_path: str, language: str = "en") -> str:
    """
    Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.

    Args:
        file_path: absolute path to an audio file (from fetch_gaia_file)
        language: ISO language code, default "en"

    Returns:
        Full transcription as plain text, or "ERROR …"
    """
    try:
        from faster_whisper import WhisperModel

        # The tiny model (~75 MB) is sufficient for short voice memos.
        whisper = WhisperModel("tiny", device="cpu", compute_type="int8")
        segments, _ = whisper.transcribe(file_path, language=language)

        pieces = [seg.text.strip() for seg in segments]
        text = " ".join(pieces).strip()

        return text if text else "ERROR: transcription empty."
    except Exception as e:
        return f"ERROR: audio transcription failed – {e}"


@tool
def multiply(a: int, b: int) -> int:
    """Return the product of a and b."""
    product = a * b
    return product

@tool
def add(a: int, b: int) -> int:
    """Return the sum of a and b."""
    total = a + b
    return total

@tool
def subtract(a: int, b: int) -> int:
    """Return the difference a minus b."""
    difference = a - b
    return difference

@tool
def divide(a: int, b: int) -> float:
    """Return the quotient a / b; raises ValueError on division by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    quotient = a / b
    return quotient

@tool
def modulo(a: int, b: int) -> int:
    """Return the remainder of a divided by b."""
    remainder = a % b
    return remainder

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return the result."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    pages = [doc.page_content for doc in docs]
    return "\n\n".join(pages)

@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for academic papers about a query."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Keep only the first 1000 characters of each paper to stay concise.
    excerpts = [doc.page_content[:1000] for doc in docs]
    return "\n\n".join(excerpts)

@tool
def web_search(query: str) -> str:
    """Perform a DuckDuckGo web search."""
    search = DuckDuckGoSearchAPIWrapper(max_results=5)
    return search.run(query)

# === System prompt ===
# NOTE: the original had a duplicated, nested assignment
# (`system_prompt = SystemMessage(content=( system_prompt = SystemMessage(...`)
# which was a syntax error; this is the single intended assignment.
system_prompt = SystemMessage(
    content=(
        "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
        "\n"
        "GENERAL RULES\n"
        "-------------\n"
        "1. Always try to answer every question.\n"
        "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
        "3. Never invent facts.\n"
        "\n"
        "TOOLS\n"
        "-----\n"
        "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
        "- parse_csv(file_path, query): analyse CSV files.\n"
        "- parse_excel(file_path, query): analyse Excel files.\n"
        "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
        "- wiki_search(query): query English Wikipedia.\n"
        "- arxiv_search(query): query arXiv.\n"
        # The prompt previously advertised `simple_calculator`, a tool that is
        # not defined anywhere in this module; list the real maths tools.
        "- web_search(query): DuckDuckGo web search.\n"
        "- multiply/add/subtract/divide/modulo(a, b): basic maths.\n"
        "\n"
        "WHEN TO USE WHICH TOOL\n"
        "----------------------\n"
        "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
        "fetch_gaia_file with the given task_id. Then:\n"
        "   • CSV  → parse_csv\n"
        "   • XLS/XLSX → parse_excel\n"
        "   • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
        "   • Image → (currently unsupported) answer that image processing is unavailable\n"
        "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
        "・If you need a scientific paper → arxiv_search.\n"
        "・If a numeric operation is required → the maths tools (multiply, add, "
        "subtract, divide, modulo).\n"
        "\n"
        "ERROR HANDLING\n"
        "--------------\n"
        "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
        "an alternative strategy: retry with a different tool or modified parameters. "
        "Do not repeat the same failing call twice.\n"
        "\n"
        "OUTPUT FORMAT\n"
        "-------------\n"
        "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
        "Do not add extra commentary.\n"
    )
)

# === LLM definition ===
# Gemini 2.0 Flash with deterministic output (temperature 0).
# NOTE(review): `system_message` may not be an accepted kwarg of
# ChatGoogleGenerativeAI in all langchain-google-genai versions — confirm;
# otherwise the system prompt must be prepended to the message list instead.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
    system_message=system_prompt,
)

# === Bind tools to the LLM ===
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    transcribe_audio,
    wiki_search,
    arxiv_search,
    web_search,
    # The list previously referenced `simple_calculator`, which is not defined
    # anywhere in this module and raised NameError at import time; register the
    # maths tools that actually exist instead.
    multiply,
    add,
    subtract,
    divide,
    modulo,
]
llm_with_tools = llm.bind_tools(tools)

def safe_llm_invoke(messages):
    """
    Invoke the LLM; if the reply contains "ERROR:", retry exactly once more
    with an appended corrective system message so the model knows the previous
    tool call failed.

    Args:
        messages: list of LangChain messages. Not mutated — the retry works on
                  a copy, so the graph state stays untouched (the old version
                  appended the correction to the caller's list in place).

    Returns:
        The last LLM result, even if it still contains "ERROR:".
    """
    max_attempts = 2
    attempt_messages = list(messages)  # copy: never mutate the caller's state
    result = None
    for _ in range(max_attempts):
        result = llm_with_tools.invoke(attempt_messages)
        content = result.content if hasattr(result, "content") else ""
        if "ERROR:" not in content:
            return result
        # Failure: add a corrective system message and try again.
        attempt_messages.append(
            SystemMessage(
                content="Previous tool call returned an ERROR. "
                        "Try a different tool or revise the input."
            )
        )
    # Still failing after max_attempts -> return the last result anyway.
    return result


# === LangGraph nodes ===
def assistant(state: MessagesState):
    """Assistant node with built-in retry on tool errors."""
    reply = safe_llm_invoke(state["messages"])
    return {"messages": [reply]}

# === Build the LangGraph ===
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)    # LLM node (with error retry)
builder.add_node("tools", ToolNode(tools))  # executes requested tool calls
builder.add_edge(START, "assistant")
# Route to "tools" when the last message contains tool calls, otherwise END.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")      # feed tool results back to the LLM

# === Compile into the runnable agent executor ===
agent_executor = builder.compile()