# Tesvia's picture
# Upload agent.py
# 75a272e verified
# raw
# history blame
# 3.27 kB
"""agent.py – GAIA benchmark agent using *smolagents*.
This module exposes:
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``.
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``) exactly like *example.py*.
"""
import os
from typing import Any, Sequence
from dotenv import load_dotenv
# SmolAgents Tools
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
Tool,
)
# Custom Tools from tools.py
from tools import (
PythonRunTool,
ExcelLoaderTool,
YouTubeTranscriptTool,
AudioTranscriptionTool,
SimpleOCRTool,
)
# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
# Read credentials from a local .env file when running locally. This runs at
# import time, before any of the os.getenv() lookups further down the module.
load_dotenv()
def _select_model():
    """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env.

    Environment variables read:

    * ``MODEL_PROVIDER`` – ``hf`` (default) or ``openai``; matched
      case-insensitively, ignoring surrounding whitespace.
    * ``HF_MODEL`` / ``HF_API_KEY`` – model id and token for the ``hf`` backend.
    * ``OPENAI_MODEL`` / ``OPENAI_API_KEY`` – model id and key for ``openai``.

    Returns:
        An ``InferenceClientModel`` or ``OpenAIServerModel`` instance.

    Raises:
        ValueError: If ``MODEL_PROVIDER`` names an unsupported backend.
    """
    # An empty or whitespace-only value (e.g. ``MODEL_PROVIDER=`` left blank in
    # a .env template) is treated as "not set" and falls back to the default
    # instead of being rejected as an unknown provider.
    provider = os.getenv("MODEL_PROVIDER", "").strip().lower() or "hf"

    if provider == "hf":
        # Imported inside the branch so only the selected backend is imported.
        from smolagents import InferenceClientModel

        hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
        hf_token = os.getenv("HF_API_KEY")
        return InferenceClientModel(model_id=hf_model_id, token=hf_token)

    if provider == "openai":
        from smolagents import OpenAIServerModel

        openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        openai_token = os.getenv("OPENAI_API_KEY")
        return OpenAIServerModel(model_id=openai_model_id, api_key=openai_token)

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. "
        "Use 'hf' (default) or 'openai'."
    )
# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
# One instance of every built-in tool, created once at import time and used
# whenever an agent is built without an explicit toolset.
DEFAULT_TOOLS = [
    tool_cls()
    for tool_cls in (
        DuckDuckGoSearchTool,
        PythonRunTool,
        ExcelLoaderTool,
        YouTubeTranscriptTool,
        AudioTranscriptionTool,
        SimpleOCRTool,
    )
]
class GAIAAgent(CodeAgent):
    """``CodeAgent`` wired with the GAIA toolset and the env-selected model."""

    def __init__(self, tools=None):
        """Build the agent.

        Args:
            tools: Tools to give the agent. ``None`` (or any empty collection)
                falls back to ``DEFAULT_TOOLS``.
        """
        super().__init__(tools=tools or DEFAULT_TOOLS, model=_select_model())

    # Convenience so the object itself can be *called* directly
    def __call__(self, question: str, **kwargs: Any) -> str:
        """Run the agent on ``question`` and return the final answer as text.

        Args:
            question: The task / question forwarded to ``self.run``.
            **kwargs: Passed through unchanged to ``self.run``.

        Returns:
            The final answer converted with ``str()``.
        """
        from collections.abc import Iterator

        result = self.run(question, **kwargs)

        # ``run`` only hands back a stream of steps when streaming is requested;
        # otherwise it returns the answer itself. Iterating unconditionally
        # would walk a string answer character by character (keeping only the
        # last one) and raise on non-iterable answers such as numbers.
        if isinstance(result, Iterator):
            last_step = None
            for last_step in result:
                pass
            result = last_step

        # Unwrap a final-answer/result object if that is what we ended up with.
        # NOTE(review): the attribute carrying the answer appears to differ
        # between smolagents releases - confirm against the installed version.
        for attr in ("output", "final_answer", "answer"):
            value = getattr(result, attr, None)
            if value is not None:
                return str(value)
        return str(result)
# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
    """Create a ``GAIAAgent`` with every default tool plus any ``extra_tools``.

    Args:
        extra_tools: Optional additional tools, appended after the defaults.
            ``None`` or an empty sequence adds nothing.

    Returns:
        A new ``GAIAAgent`` built from a fresh list, so ``DEFAULT_TOOLS``
        itself is never modified.
    """
    additions = extra_tools if extra_tools else ()
    return GAIAAgent(tools=[*DEFAULT_TOOLS, *additions])
# Public API of this module: the agent class and its factory function.
__all__ = ["GAIAAgent", "gaia_agent"]