"""GAIA benchmark agent using *smolagents*.
This module exposes:
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``.
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``) exactly like *example.py*.
"""
import os
from typing import Any, Sequence
from dotenv import load_dotenv
# SmolAgents Tools
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
Tool
)
# Custom Tools from tools.py
from tools import (
PythonRunTool,
ExcelLoaderTool,
YouTubeTranscriptTool,
AudioTranscriptionTool,
SimpleOCRTool,
)
# ---------------------------------------------------------------------------
# Load the additional system prompt from added_prompt.txt (located in the
# same directory as this module). Read once at import time and stripped of
# surrounding whitespace; appended to the agent's system prompt in
# GAIAAgent.__init__.
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()
# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv()  # Make sure we read credentials from .env when running locally
def _select_model():
"""Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""
provider = os.getenv("MODEL_PROVIDER", "hf").lower()
if provider == "hf":
from smolagents import InferenceClientModel
hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
hf_token = os.getenv("HF_API_KEY")
return InferenceClientModel(
model_id=hf_model_id,
token=hf_token
)
if provider == "openai":
from smolagents import OpenAIServerModel
openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
openai_token = os.getenv("OPENAI_API_KEY")
return OpenAIServerModel(
model_id=openai_model_id,
api_key=openai_token
)
raise ValueError(
f"Unsupported MODEL_PROVIDER: {provider!r}. "
"Use 'hf' (default) or 'openai'."
)
# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
# Tools every GAIAAgent gets by default: web search plus the custom GAIA
# helpers imported from tools.py. NOTE: these are module-level instances,
# so all agents built without an explicit toolset share the same objects.
DEFAULT_TOOLS = [
    DuckDuckGoSearchTool(),
    PythonRunTool(),
    ExcelLoaderTool(),
    YouTubeTranscriptTool(),
    AudioTranscriptionTool(),
    SimpleOCRTool(),
]
class GAIAAgent(CodeAgent):
    """``smolagents.CodeAgent`` pre-configured for the GAIA benchmark.

    Wires in the default toolset, the backend selected via
    ``MODEL_PROVIDER``, and the extra instructions from ``added_prompt.txt``.
    """

    def __init__(self, tools=None):
        """Create the agent.

        Parameters
        ----------
        tools:
            Optional list of tools. When omitted (or empty/falsy) the
            module-level ``DEFAULT_TOOLS`` are used instead.
        """
        super().__init__(
            tools=tools or DEFAULT_TOOLS,
            model=_select_model(),
        )
        # Append the additional GAIA instructions to the stock system prompt.
        self.prompt_templates["system_prompt"] += f"\n\n{ADDED_PROMPT}"

    def __call__(self, question: str, **kwargs: Any) -> str:
        """Run the agent on *question* and return the final answer as text.

        ``run`` may return a primitive, an iterable of steps, or a single
        answer object; all three shapes are normalised to a stripped string.
        """
        result = self.run(question, **kwargs)
        # If the result is a primitive, just return it.
        if isinstance(result, (int, float, str)):
            return str(result).strip()
        last_step = None
        try:
            for step in result:
                last_step = step
        except TypeError:
            # BUG FIX: ``run`` can return a single non-iterable object
            # (e.g. an answer wrapper). The original code crashed here
            # with TypeError; treat the object itself as the final step.
            last_step = result
        # Defensive: handle int/float/str directly.
        if isinstance(last_step, (int, float, str)):
            return str(last_step).strip()
        answer = getattr(last_step, "answer", None)
        if answer is not None:
            return str(answer).strip()
        return str(last_step).strip()
# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
    """Factory expected by app.py: build a :class:`GAIAAgent`.

    The agent always receives every default tool; *extra_tools*, when
    given, are appended after them.
    """
    combined = [*DEFAULT_TOOLS, *(extra_tools or ())]
    return GAIAAgent(tools=combined)
__all__ = ["GAIAAgent", "gaia_agent"]