"""GAIA benchmark agent using *smolagents*. | |
This module exposes: | |
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance. | |
* ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``. | |
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER`` | |
environment variable (``hf`` or ``openai``) exactly like *example.py*. | |
""" | |
import os
from typing import Any, Sequence

from dotenv import load_dotenv

# SmolAgents tools
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    Tool,
)

# Custom tools from tools.py
from tools import (
    PythonRunTool,
    ExcelLoaderTool,
    YouTubeTranscriptTool,
    AudioTranscriptionTool,
    SimpleOCRTool,
)
# ---------------------------------------------------------------------------
# Load the additional system prompt from added_prompt.txt (same directory)
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()

# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv()  # Make sure we read credentials from .env when running locally


def _select_model():
    """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env var."""
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    if provider == "hf":
        from smolagents import InferenceClientModel

        hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
        hf_token = os.getenv("HF_API_KEY")
        return InferenceClientModel(
            model_id=hf_model_id,
            token=hf_token,
        )

    if provider == "openai":
        from smolagents import OpenAIServerModel

        openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        openai_token = os.getenv("OPENAI_API_KEY")
        return OpenAIServerModel(
            model_id=openai_model_id,
            api_key=openai_token,
        )

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. "
        "Use 'hf' (default) or 'openai'."
    )
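
# Example .env for local runs (illustrative placeholder values only; the
# variable names mirror those read by ``_select_model`` above):
#   MODEL_PROVIDER=hf                        # or "openai"
#   HF_MODEL=HuggingFaceH4/zephyr-7b-beta    # Hugging Face Inference model id
#   HF_API_KEY=hf_xxx
#   OPENAI_MODEL=gpt-3.5-turbo               # used when MODEL_PROVIDER=openai
#   OPENAI_API_KEY=sk-xxx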

# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
DEFAULT_TOOLS = [
    DuckDuckGoSearchTool(),
    PythonRunTool(),
    ExcelLoaderTool(),
    YouTubeTranscriptTool(),
    AudioTranscriptionTool(),
    SimpleOCRTool(),
]
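
# The default toolset can be extended through the ``gaia_agent()`` factory
# defined below, e.g. (sketch only; assumes ``VisitWebpageTool`` is importable
# from the installed smolagents version):
#   from smolagents import VisitWebpageTool
#   agent = gaia_agent(extra_tools=[VisitWebpageTool()])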


class GAIAAgent(CodeAgent):
    """``smolagents.CodeAgent`` subclass wired up with the GAIA toolset and prompt."""

    def __init__(
        self,
        tools=None,
    ):
        super().__init__(
            tools=tools or DEFAULT_TOOLS,
            model=_select_model(),
        )
        # Append the additional prompt to the existing system prompt
        self.prompt_templates["system_prompt"] += f"\n\n{ADDED_PROMPT}"

    # Convenience so the object itself can be *called* directly
    def __call__(self, question: str, **kwargs: Any) -> str:
        steps = self.run(question, **kwargs)
        # If the run already returned a primitive final answer, just return it
        if isinstance(steps, (int, float, str)):
            return str(steps).strip()
        # Otherwise iterate the step stream and keep only the last step
        last_step = None
        for step in steps:
            last_step = step
        # Defensive: handle int/float/str directly
        if isinstance(last_step, (int, float, str)):
            return str(last_step).strip()
        # Prefer an explicit ``answer`` attribute when the step object has one
        answer = getattr(last_step, "answer", None)
        if answer is not None:
            return str(answer).strip()
        return str(last_step).strip()


# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
    # Compose the toolset: always include all default tools, plus any extras
    toolset = list(DEFAULT_TOOLS)
    if extra_tools:
        toolset.extend(extra_tools)
    return GAIAAgent(tools=toolset)


__all__ = ["GAIAAgent", "gaia_agent"]
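

# ---------------------------------------------------------------------------
# Optional local smoke test. A minimal sketch only: the Space's app.py is the
# real entry point, and this assumes credentials for the selected
# MODEL_PROVIDER are available in the environment / .env.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = gaia_agent()
    print(agent("What is the capital of France?"))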