File size: 3,914 Bytes
52d1305
9623335
 
 
 
52d1305
9623335
52d1305
790cac2
 
 
52d1305
 
9623335
52d1305
9623335
75a272e
73bb16b
 
 
 
 
9623335
bb49a20
52d1305
790cac2
 
 
 
 
 
e4ed116
9623335
e4ed116
db0abac
 
 
e4ed116
9623335
52d1305
 
9623335
 
52d1305
 
 
9623335
 
 
 
 
 
52d1305
 
9623335
52d1305
 
 
9623335
73bb16b
 
 
 
 
9623335
52d1305
 
9623335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73bb16b
9623335
 
 
 
790cac2
 
9623335
790cac2
 
 
 
 
 
 
 
 
 
 
 
 
 
9623335
 
790cac2
9623335
 
 
790cac2
73bb16b
790cac2
52d1305
 
9623335
 
 
 
52d1305
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
GAIA benchmark agent using the OpenAI Agents SDK.
"""

from __future__ import annotations

import asyncio
import os
import time
import datetime
from typing import Any, Sequence, Callable, List, Optional

from dotenv import load_dotenv
from agents import Agent, Runner, FunctionTool, Tool

# Import all function tools
from tools import (
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
)

# ---------------------------------------------------------------------------
# Logging Utility
# ---------------------------------------------------------------------------
def log(msg):
    """Print *msg* to stdout prefixed with the local time as ``[YYYY-MM-DD HH:MM:SS]``."""
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] {msg}")

# ---------------------------------------------------------------------------
# Load the added system prompt
# ---------------------------------------------------------------------------
# The extra prompt text lives in "added_prompt.txt" in the same directory as
# this module. It is read once, at import time, so a missing or unreadable
# file makes importing this module fail.
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    # Leading/trailing whitespace is dropped; the text is later appended to
    # the agent instructions in _build_agent().
    ADDED_PROMPT = f.read().strip()

# Runs at import time, before any os.getenv() lookup in _select_model() is
# executed. NOTE(review): presumably loads variables from a local .env file
# (python-dotenv default behaviour) - confirm the expected file location.
load_dotenv()


def _select_model() -> str:
    """Return a model identifier for the Agents SDK based on environment settings.

    Environment variables read:
        MODEL_PROVIDER: ``"openai"`` or ``"hf"`` (case-insensitive, default ``"hf"``).
        OPENAI_MODEL: model name used when the provider is ``"openai"``
            (default ``"gpt-4o-mini"``).
        HF_MODEL: model id used when the provider is ``"hf"``
            (default ``"Qwen/Qwen2.5-Coder-32B-Instruct"``).

    Surrounding whitespace is ignored, and a variable that is set but blank is
    treated as unset so the default applies.

    Returns:
        ``"openai/<model>"`` for the OpenAI provider, or
        ``"litellm/huggingface/<model>"`` for the Hugging Face provider.

    Raises:
        ValueError: If ``MODEL_PROVIDER`` is neither ``"openai"`` nor ``"hf"``.
    """

    def _env(name: str, default: str) -> str:
        # A blank entry such as ``HF_MODEL=`` in a .env template would
        # otherwise produce an unusable identifier like "openai/".
        value = os.getenv(name, "").strip()
        return value or default

    provider = _env("MODEL_PROVIDER", "hf").lower()

    if provider == "openai":
        return f"openai/{_env('OPENAI_MODEL', 'gpt-4o-mini')}"

    if provider == "hf":
        return f"litellm/huggingface/{_env('HF_MODEL', 'Qwen/Qwen2.5-Coder-32B-Instruct')}"

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
    )


# Tools given to every agent built by _build_agent(); callers can append more
# through its `extra_tools` argument. All entries come from the local `tools`
# module imported at the top of this file.
DEFAULT_TOOLS: List[FunctionTool] = [
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
]


def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
    """Create the Agents SDK ``Agent`` used to answer questions.

    Args:
        extra_tools: Optional tools appended after ``DEFAULT_TOOLS``. ``None``
            or an empty sequence leaves the default tool list unchanged.

    Returns:
        An ``Agent`` named "GAIA Agent" whose instructions are a fixed preamble
        followed by ``ADDED_PROMPT`` and whose model comes from
        ``_select_model()``.
    """
    # Always hand the Agent a fresh list so DEFAULT_TOOLS is never shared.
    toolset: Sequence[Tool] = [*DEFAULT_TOOLS, *(extra_tools or ())]

    preamble = "You are a helpful assistant tasked with answering questions using the available tools.\n\n"
    prompt = preamble + ADDED_PROMPT

    return Agent(
        name="GAIA Agent",
        instructions=prompt,
        tools=toolset,
        model=_select_model(),
    )


class GAIAAgent:
    """Thin synchronous wrapper around an asynchronous Agents SDK agent.

    Instances are callable: ``agent(question)`` runs the wrapped agent to
    completion and returns its final output as a stripped string. Calling works
    both from plain synchronous code and from a thread that already has a
    running event loop.
    """

    def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
        """Build the wrapped agent.

        Args:
            extra_tools: Optional additional tools, forwarded to
                ``_build_agent``.
        """
        self._agent = _build_agent(extra_tools=extra_tools)
        # Store the model id for logging
        self.model_id = _select_model()

    async def _arun(self, question: str, q_index: Optional[int] = None) -> str:
        """Run the agent on *question* and return its final output as text.

        Args:
            question: The question passed to ``Runner.run``.
            q_index: Optional zero-based question index; logged as a one-based
                number, or ``"?"`` when not given.

        Returns:
            ``str(result.final_output)`` with surrounding whitespace removed.

        Raises:
            Exception: Whatever ``Runner.run`` raises is logged and re-raised
                unchanged.
        """
        q_num = q_index + 1 if q_index is not None else "?"
        log(f"Answering question {q_num}:")
        log(f"    Question: {question!r}")
        log(f"    Model: {self.model_id}")

        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by wall-clock adjustments while the agent is running.
        t0 = time.perf_counter()
        try:
            result = await Runner.run(self._agent, question)
        except Exception as e:
            log(f"    Error during answer: {e}")
            raise
        duration = time.perf_counter() - t0
        log(f"    Total duration: {duration:.2f} seconds.")
        return str(result.final_output).strip()

    def __call__(self, question: str, q_index: Optional[int] = None, **kwargs: Any) -> str:
        """Answer *question* synchronously.

        Args:
            question: The question to answer.
            q_index: Optional zero-based index of the question, used for logging.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The agent's final answer as a stripped string.

        Note:
            When called from a thread whose event loop is already running, the
            work is done on a temporary worker thread and the calling thread
            blocks until the answer is ready.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so asyncio.run can own one.
            return asyncio.run(self._arun(question, q_index=q_index))

        # A loop is already running here. Calling run_until_complete() on it
        # would raise "This event loop is already running", so execute the
        # coroutine on a separate thread with its own fresh event loop instead.
        from concurrent.futures import ThreadPoolExecutor

        def _run_in_fresh_loop() -> str:
            return asyncio.run(self._arun(question, q_index=q_index))

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(_run_in_fresh_loop).result()


def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
    """Create and return a new ``GAIAAgent``.

    Args:
        extra_tools: Optional additional tools, passed through unchanged to the
            ``GAIAAgent`` constructor.
    """
    agent = GAIAAgent(extra_tools=extra_tools)
    return agent


# Public API: the only names exported by a star-import of this module.
__all__ = ["GAIAAgent", "gaia_agent"]