File size: 4,103 Bytes
52d1305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb49a20
75a272e
bb49a20
 
 
 
 
 
52d1305
e4ed116
 
 
 
 
 
 
 
 
52d1305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4ed116
 
eea77dd
e4ed116
52d1305
 
 
 
 
e4ed116
 
eea77dd
e4ed116
52d1305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eea77dd
 
 
 
 
 
52d1305
 
 
 
e4ed116
 
 
52d1305
 
 
e4ed116
 
 
52d1305
 
4a751ef
 
52d1305
 
 
 
 
 
bb49a20
 
52d1305
bb49a20
 
52d1305
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""agent.py – GAIA benchmark agent using *smolagents*.

This module exposes:

* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent``  – subclass of ``smolagents.CodeAgent``.

The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``) exactly like *example.py*.
"""

import os
from typing import Any, Sequence

from dotenv import load_dotenv

# SmolAgents Tools
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    Tool,
)

# Custom Tools from tools.py
from tools import (
    PythonRunTool,
    ExcelLoaderTool,
    YouTubeTranscriptTool,
    AudioTranscriptionTool,
    SimpleOCRTool,
)

 
# ---------------------------------------------------------------------------
# Load the system prompt from system_prompt.txt (located in the same directory)
# ---------------------------------------------------------------------------
SYSTEM_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "system_prompt.txt")
with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()


# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------

# Populate os.environ from a local .env file (no-op when none exists) so the
# os.getenv lookups in _select_model() see credentials during local runs.
load_dotenv()  # Make sure we read credentials from .env when running locally

def _select_model():
    """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""

    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    if provider == "hf":
        from smolagents import InferenceClientModel
        hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
        hf_token = os.getenv("HF_API_KEY")
        return InferenceClientModel(
            model_id=hf_model_id,
            token=hf_token
        )

    if provider == "openai":
        from smolagents import OpenAIServerModel
        openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        openai_token = os.getenv("OPENAI_API_KEY")
        return OpenAIServerModel(
            model_id=openai_model_id,
            api_key=openai_token
        )

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. "
        "Use 'hf' (default) or 'openai'."
    )

# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------

# Toolset handed to every GAIAAgent unless the caller supplies its own.
# Instantiated once at import time, so all agents built from this module
# share these tool instances.
DEFAULT_TOOLS = [
    DuckDuckGoSearchTool(),
    PythonRunTool(),
    ExcelLoaderTool(),
    YouTubeTranscriptTool(),
    AudioTranscriptionTool(),
    SimpleOCRTool(),
]

class GAIAAgent(CodeAgent):
    """``CodeAgent`` pre-configured for the GAIA benchmark.

    The model backend is picked at construction time by ``_select_model()``
    (driven by the ``MODEL_PROVIDER`` env var) and the system prompt defaults
    to the contents of ``system_prompt.txt``.
    """

    def __init__(self, tools=None, *, system_prompt: str = SYSTEM_PROMPT):
        """Create the agent.

        Parameters
        ----------
        tools:
            Tools to equip the agent with.  ``None`` (the default) selects
            ``DEFAULT_TOOLS``.
        system_prompt:
            System prompt text passed through to ``CodeAgent``.
        """
        # Bug fix: the previous ``tools or DEFAULT_TOOLS`` treated an
        # explicitly passed empty list as "no tools", silently substituting
        # the defaults.  Test identity against None instead, and hand the
        # superclass a copy so the shared DEFAULT_TOOLS module constant can
        # never be mutated through an agent instance.
        selected = DEFAULT_TOOLS if tools is None else tools
        super().__init__(
            tools=list(selected),
            model=_select_model(),
            system_prompt=system_prompt,
        )

    # Convenience so the object itself can be *called* directly
    def __call__(self, question: str, **kwargs: Any) -> str:
        """Run the agent on *question* and return the final answer as text."""
        steps = self.run(question, **kwargs)
        # ``run`` may return the final answer directly as a primitive.
        if isinstance(steps, (int, float, str)):
            return str(steps).strip()
        # Otherwise treat the result as an iterable of steps and keep the
        # last one seen.
        last_step = None
        for step in steps:
            last_step = step
        # Defensive: handle int/float/str directly
        if isinstance(last_step, (int, float, str)):
            return str(last_step).strip()
        # Prefer an explicit ``answer`` attribute when the step exposes one;
        # fall back to the step's own string form.
        answer = getattr(last_step, "answer", None)
        if answer is not None:
            return str(answer).strip()
        return str(last_step).strip()

# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------

def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
    """Factory returning a ready-to-use :class:`GAIAAgent`.

    The agent always carries every default tool; *extra_tools*, when
    provided, are appended after the defaults.
    """
    return GAIAAgent(tools=[*DEFAULT_TOOLS, *(extra_tools or ())])

__all__ = ["GAIAAgent", "gaia_agent"]