import os
import tempfile

import pandas as pd
import requests
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

# Google API key for the Gemini LLM (must be set in the environment).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# === Tool definitions ===

# Base URL of the GAIA scoring service that hosts task attachments.
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"


@tool
def fetch_gaia_file(task_id: str) -> str:
    """
    Download the file attached to a GAIA task and return the local file-path.

    Args:
        task_id: The GAIA task_id (string in the JSON payload).

    Returns:
        Absolute path to the downloaded temp-file, or an "ERROR: ..." string.
    """
    try:
        url = f"{GAIA_BASE_URL}/files/{task_id}"
        response = requests.get(url, timeout=20)
        response.raise_for_status()

        # The server provides the real filename in a header; fall back to
        # parsing Content-Disposition, then to "<task_id>.bin".
        filename = (
            response.headers.get("x-filename")
            or response.headers.get("content-disposition", "download")
            .split("filename=")[-1]
            .strip('"')
        )
        if not filename:
            filename = f"{task_id}.bin"

        tmp_path = os.path.join(tempfile.gettempdir(), filename)
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        return tmp_path
    except Exception as e:
        # Tools return error strings (not exceptions) so the agent can
        # observe the failure and pick an alternative strategy.
        return f"ERROR: could not download file for task {task_id}: {e}"


def _run_df_query(df: pd.DataFrame, query: str, tool_name: str) -> str:
    """
    Tiny query engine shared by parse_csv / parse_excel.

    Supports only "sum of <column> [where <col> ==/!= <value>]"-style
    instructions. Column detection is case-insensitive, but the filter
    condition is sliced from the *original* query so that column names and
    values keep their real casing (lowercasing them previously broke lookups
    like ``Category != 'Drinks'``).

    Args:
        df: the loaded dataframe.
        query: natural-language instruction, e.g.
            "sum of column Sales where Category != 'Drinks'".
        tool_name: name used in the "not supported" message, to keep the
            exact per-tool wording callers may rely on.

    Returns:
        A concise answer string, or a "not supported" notice.
    """
    query_lc = query.lower()
    if "sum" in query_lc:
        # Find which column should be summed (case-insensitive match).
        for col in df.columns:
            if col.lower() in query_lc:
                s = df[col]
                where_idx = query_lc.find("where")
                if where_idx != -1:
                    # Slice the condition out of the original-case query so
                    # df[key] and the comparison value match real casing.
                    cond_part = query[where_idx + len("where"):].strip()
                    if "!=" in cond_part:
                        key, val = [x.strip().strip("'\"") for x in cond_part.split("!=")]
                        s = df.loc[df[key] != val, col]
                    elif "==" in cond_part:
                        key, val = [x.strip().strip("'\"") for x in cond_part.split("==")]
                        s = df.loc[df[key] == val, col]
                return str(round(s.sum(), 2))
    return f"Query type not supported by {tool_name}."


@tool
def parse_csv(file_path: str, query: str = "") -> str:
    """
    Load a CSV file from `file_path` and optionally run a simple analysis query.

    Args:
        file_path: absolute path to a CSV file (from fetch_gaia_file)
        query: optional natural-language instruction, e.g.
            "sum of column Sales where Category != 'Drinks'"

    Returns:
        A concise string with the answer OR a preview of the dataframe if no
        query is given.
    """
    try:
        df = pd.read_csv(file_path)
        # Auto-preview if no query was supplied.
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return (
                f"CSV loaded. First rows:\n\n{preview}\n\n"
                f"Columns: {', '.join(df.columns)}"
            )
        return _run_df_query(df, query, "parse_csv")
    except Exception as e:
        return f"ERROR parsing CSV: {e}"


@tool
def parse_excel(file_path: str, query: str = "") -> str:
    """
    Identical to parse_csv, but for XLS/XLSX files.

    Args:
        file_path: absolute path to an Excel file (from fetch_gaia_file)
        query: optional natural-language instruction (see parse_csv)

    Returns:
        A concise string with the answer OR a preview of the dataframe if no
        query is given.
    """
    try:
        df = pd.read_excel(file_path)
        if not query:
            preview = df.head(5).to_markdown(index=False)
            return (
                f"Excel loaded. First rows:\n\n{preview}\n\n"
                f"Columns: {', '.join(df.columns)}"
            )
        return _run_df_query(df, query, "parse_excel")
    except Exception as e:
        return f"ERROR parsing Excel: {e}"


@tool
def transcribe_audio(file_path: str, language: str = "en") -> str:
    """
    Transcribe an audio file (MP3/WAV/etc.) using Faster-Whisper.

    Args:
        file_path: absolute path to an audio file (from fetch_gaia_file)
        language: ISO language code, default "en"

    Returns:
        Full transcription as plain text, or "ERROR ..." on failure.
    """
    try:
        # Imported lazily: faster_whisper is heavy and only needed here.
        from faster_whisper import WhisperModel

        # The tiny model (~75 MB) is sufficient for short voice memos.
        model = WhisperModel("tiny", device="cpu", compute_type="int8")
        segments, _ = model.transcribe(file_path, language=language)
        transcript = " ".join(segment.text.strip() for segment in segments).strip()
        if not transcript:
            return "ERROR: transcription empty."
        return transcript
    except Exception as e:
        return f"ERROR: audio transcription failed – {e}"


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies two numbers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Adds two numbers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtracts two numbers."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divides two numbers."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulo(a: int, b: int) -> int:
    """Returns the remainder of dividing two numbers."""
    return a % b


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return the result."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n".join(doc.page_content for doc in search_docs)


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for academic papers about a query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each paper to keep the tool output within context limits.
    return "\n\n".join(doc.page_content[:1000] for doc in search_docs)


@tool
def web_search(query: str) -> str:
    """Perform a DuckDuckGo web search."""
    wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
    results = wrapper.run(query)
    return results


# === System prompt ===
# NOTE: the original source assigned this twice in a nested, unbalanced
# expression (a SyntaxError) and advertised a non-existent
# "simple_calculator" tool; both are fixed here.
system_prompt = SystemMessage(
    content=(
        "You are a focused, factual AI agent competing on the GAIA evaluation.\n"
        "\n"
        "GENERAL RULES\n"
        "-------------\n"
        "1. Always try to answer every question.\n"
        "2. If you are NOT 100 % certain, prefer using a TOOL.\n"
        "3. Never invent facts.\n"
        "\n"
        "TOOLS\n"
        "-----\n"
        "- fetch_gaia_file(task_id): downloads any attachment for the current task.\n"
        "- parse_csv(file_path, query): analyse CSV files.\n"
        "- parse_excel(file_path, query): analyse Excel files.\n"
        "- transcribe_audio(file_path): transcribe MP3 / WAV audio.\n"
        "- wiki_search(query): query English Wikipedia.\n"
        "- arxiv_search(query): query arXiv.\n"
        "- web_search(query): DuckDuckGo web search.\n"
        "- multiply/add/subtract/divide/modulo(a, b): basic maths.\n"
        "\n"
        "WHEN TO USE WHICH TOOL\n"
        "----------------------\n"
        "・If the prompt or GAIA metadata mentions an *attached* file, FIRST call "
        "fetch_gaia_file with the given task_id. Then:\n"
        "   • CSV → parse_csv\n"
        "   • XLS/XLSX → parse_excel\n"
        "   • MP3/WAV → transcribe_audio (language auto-detect is OK)\n"
        "   • Image → (currently unsupported) answer that image processing is unavailable\n"
        "・If you need factual data (dates, numbers, names) → wiki_search or web_search.\n"
        "・If you need a scientific paper → arxiv_search.\n"
        "・If a numeric operation is required → use the maths tools.\n"
        "\n"
        "ERROR HANDLING\n"
        "--------------\n"
        "If a tool call returns a string that starts with \"ERROR:\", IMMEDIATELY think of "
        "an alternative strategy: retry with a different tool or modified parameters. "
        "Do not repeat the same failing call twice.\n"
        "\n"
        "OUTPUT FORMAT\n"
        "-------------\n"
        "Follow the exact format asked in the question (e.g. single word, CSV, comma-list). "
        "Do not add extra commentary.\n"
    )
)

# === LLM ===
# ChatGoogleGenerativeAI has no `system_message` constructor parameter; the
# system prompt is instead prepended to the message list in the assistant node.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0,
    max_output_tokens=2048,
)

# === Bind tools to the LLM ===
# `simple_calculator` was listed here but never defined (NameError at import);
# the individual maths tools are registered instead.
tools = [
    fetch_gaia_file,
    parse_csv,
    parse_excel,
    transcribe_audio,
    wiki_search,
    arxiv_search,
    web_search,
    multiply,
    add,
    subtract,
    divide,
    modulo,
]
llm_with_tools = llm.bind_tools(tools)


def safe_llm_invoke(messages):
    """
    Invoke the LLM once; if the result contains "ERROR:", retry exactly one
    more time with an added correction message so the LLM knows the previous
    tool call failed.

    Args:
        messages: list of LangChain messages (not mutated — a copy is made).

    Returns:
        The last LLM response, even if it still contains an error.
    """
    # Copy so the retry hint is never appended to the caller's graph state.
    messages = list(messages)
    max_attempts = 2
    result = None
    for _ in range(max_attempts):
        result = llm_with_tools.invoke(messages)
        content = result.content if hasattr(result, "content") else ""
        if "ERROR:" not in content:
            return result
        # Failure: add a system correction and try once more.
        messages.append(
            SystemMessage(
                content="Previous tool call returned an ERROR. "
                "Try a different tool or revise the input."
            )
        )
    # Still failing after max_attempts — return the last result anyway.
    return result


# === LangGraph nodes ===
def assistant(state: MessagesState):
    """
    Assistant node with built-in retry on tool errors. The system prompt is
    prepended here on every turn (the LLM constructor cannot carry it).
    """
    result_msg = safe_llm_invoke([system_prompt] + state["messages"])
    return {"messages": [result_msg]}


# === Build the LangGraph ===
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# === Agent executor ===
agent_executor = builder.compile()