File size: 3,009 Bytes
52d1305
9623335
 
 
 
52d1305
9623335
52d1305
9623335
52d1305
 
9623335
52d1305
9623335
75a272e
73bb16b
 
 
 
 
9623335
bb49a20
52d1305
e4ed116
9623335
e4ed116
db0abac
 
 
e4ed116
9623335
52d1305
 
9623335
 
52d1305
 
 
9623335
 
 
 
 
 
52d1305
 
9623335
52d1305
 
 
9623335
73bb16b
 
 
 
 
9623335
52d1305
 
9623335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73bb16b
9623335
 
 
 
 
 
 
 
 
 
 
 
 
 
73bb16b
9623335
52d1305
 
9623335
 
 
 
52d1305
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
GAIA benchmark agent using the OpenAI Agents SDK.
"""

from __future__ import annotations

import asyncio
import os
from typing import Any, Sequence, Callable, List

from dotenv import load_dotenv
from agents import Agent, Runner, FunctionTool, Tool

# Import all function tools
from tools import (
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
)

# ---------------------------------------------------------------------------
# Load the added system prompt
# ---------------------------------------------------------------------------
def _read_added_prompt(path: str) -> str:
    """Return the UTF-8 text of *path* with leading/trailing whitespace removed."""
    with open(path, encoding="utf-8") as handle:
        return handle.read().strip()


# The prompt file sits next to this module and is read exactly once, at import
# time; if it is missing, importing this module raises FileNotFoundError.
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
ADDED_PROMPT = _read_added_prompt(ADDED_PROMPT_PATH)

# Pull settings from a .env file (python-dotenv); _select_model() reads the
# MODEL_PROVIDER / *_MODEL variables lazily, each time it is called.
load_dotenv()


def _select_model() -> str:
    """Resolve the Agents SDK model identifier from environment variables.

    ``MODEL_PROVIDER`` (default ``"hf"``, compared case-insensitively) picks
    the backend:

    * ``"openai"`` -> ``"openai/" + OPENAI_MODEL`` (default ``gpt-4o-mini``)
    * ``"hf"``     -> ``"litellm/huggingface/" + HF_MODEL``
      (default ``Qwen/Qwen2.5-Coder-32B-Instruct``)

    Raises:
        ValueError: if ``MODEL_PROVIDER`` names any other provider.
    """
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    # provider -> (identifier prefix, env var holding the model id, fallback id)
    routes = {
        "openai": ("openai/", "OPENAI_MODEL", "gpt-4o-mini"),
        "hf": (
            "litellm/huggingface/",
            "HF_MODEL",
            "Qwen/Qwen2.5-Coder-32B-Instruct",
        ),
    }

    try:
        prefix, env_var, fallback = routes[provider]
    except KeyError:
        raise ValueError(
            f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
        ) from None

    return prefix + os.getenv(env_var, fallback)


# Tools given to every agent built by ``_build_agent``. That function copies
# this list and appends any ``extra_tools`` after it, so the order here is the
# order in which the tools are passed to ``Agent(tools=...)``.
DEFAULT_TOOLS: List[FunctionTool] = [
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
]


def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
    """Create the Agents SDK ``Agent`` that backs ``GAIAAgent``.

    Args:
        extra_tools: Optional additional tools; they are appended after
            ``DEFAULT_TOOLS`` (which is copied, never mutated).

    Returns:
        An ``Agent`` named ``"GAIA Agent"``. Its instructions are a fixed
        preamble followed by ``ADDED_PROMPT``, and its model identifier
        comes from ``_select_model()``.
    """
    preamble = (
        "You are a helpful assistant tasked with answering questions using the available tools.\n\n"
    )
    # A fresh list every call, so callers can never mutate DEFAULT_TOOLS.
    tool_list: List[Tool] = [*DEFAULT_TOOLS, *(extra_tools or ())]

    return Agent(
        name="GAIA Agent",
        instructions=preamble + ADDED_PROMPT,
        tools=tool_list,
        model=_select_model(),
    )


class GAIAAgent:
    """Thin synchronous wrapper around an asynchronous Agents SDK agent.

    Calling an instance sends one question through the underlying agent and
    returns its final output as a stripped string. The call works both from
    plain synchronous code and from code that already has an event loop
    running in the current thread (for example, when invoked from inside a
    coroutine).
    """

    def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
        """Build the underlying agent.

        Args:
            extra_tools: Optional tools offered in addition to the defaults.
        """
        self._agent = _build_agent(extra_tools=extra_tools)

    async def _arun(self, question: str) -> str:
        """Run *question* through the agent; return ``str(final_output)`` stripped."""
        result = await Runner.run(self._agent, question)
        return str(result.final_output).strip()

    def __call__(self, question: str, **kwargs: Any) -> str:
        """Answer *question* synchronously.

        Args:
            question: The question text passed to the agent.
            **kwargs: Accepted and ignored.

        Returns:
            The agent's final output as a stripped string.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so we can own one outright.
            return asyncio.run(self._arun(question))

        # A loop is already running in this thread. Calling run_until_complete
        # on it raises "This event loop is already running", and asyncio.run()
        # refuses to start inside a running loop. Instead, run the coroutine on
        # a fresh event loop in a worker thread and block until it finishes.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self._arun(question)).result()


def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
    """Construct a new ``GAIAAgent`` that is ready to be called.

    Args:
        extra_tools: Optional tools, forwarded unchanged to ``GAIAAgent``.

    Returns:
        A freshly built ``GAIAAgent`` instance.
    """
    agent = GAIAAgent(extra_tools=extra_tools)
    return agent


# Names exported by ``from <module> import *``.
__all__ = [
    "GAIAAgent",
    "gaia_agent",
]