"""GAIA benchmark agent using *smolagents*.
This module exposes:
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``.
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``) exactly like *example.py*.
"""
import os
from typing import Any, Sequence
from dotenv import load_dotenv
# SmolAgents Tools
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
Tool
)
# Custom Tools from tools.py
from tools import (
PythonRunTool,
ExcelLoaderTool,
YouTubeTranscriptTool,
AudioTranscriptionTool,
SimpleOCRTool,
)
# ---------------------------------------------------------------------------
# Load the added system prompt from added_prompt.txt (located in the same directory)
# ---------------------------------------------------------------------------
# Resolve the prompt file relative to this module so it is found regardless
# of the process's current working directory.
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    # Appended to the agent's system prompt in GAIAAgent.__init__ below.
    ADDED_PROMPT = f.read().strip()
# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv() # Make sure we read credentials from .env when running locally
def _select_model():
    """Build the LLM backend selected by the ``MODEL_PROVIDER`` env var.

    Supported values are ``"hf"`` (the default, served through
    ``InferenceClientModel``) and ``"openai"`` (served through
    ``OpenAIServerModel``).  Anything else raises ``ValueError``.
    """
    backend = os.getenv("MODEL_PROVIDER", "hf").lower()

    if backend == "hf":
        # Imported lazily so the OpenAI-only path never needs the HF client.
        from smolagents import InferenceClientModel

        return InferenceClientModel(
            model_id=os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta"),
            token=os.getenv("HF_API_KEY"),
        )

    if backend == "openai":
        # Lazy import for the same reason as above.
        from smolagents import OpenAIServerModel

        return OpenAIServerModel(
            model_id=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {backend!r}. "
        "Use 'hf' (default) or 'openai'."
    )
# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
# Shared toolset used by every GAIAAgent unless the caller supplies its own.
# Instantiated once at import time; combines the built-in web search tool
# with the custom tools defined in tools.py.
DEFAULT_TOOLS = [
    DuckDuckGoSearchTool(),      # web search (smolagents built-in)
    PythonRunTool(),             # custom tool from tools.py
    ExcelLoaderTool(),           # custom tool from tools.py
    YouTubeTranscriptTool(),     # custom tool from tools.py
    AudioTranscriptionTool(),    # custom tool from tools.py
    SimpleOCRTool(),             # custom tool from tools.py
]
class GAIAAgent(CodeAgent):
    """``CodeAgent`` specialised for the GAIA benchmark.

    Uses ``DEFAULT_TOOLS`` unless an explicit toolset is supplied, and
    appends the contents of ``added_prompt.txt`` to the stock system prompt.
    """

    def __init__(self, tools=None):
        # ``tools=None`` means "use the defaults".  An explicitly passed list
        # (even an empty one) is respected as-is; the previous
        # ``tools or DEFAULT_TOOLS`` silently replaced an empty list with the
        # defaults.  A copy of DEFAULT_TOOLS is passed so the module-level
        # list is never aliased/mutated by the base class.
        super().__init__(
            tools=list(DEFAULT_TOOLS) if tools is None else tools,
            model=_select_model(),
        )
        # Append the additional prompt to the existing system prompt
        self.prompt_templates["system_prompt"] += f"\n\n{ADDED_PROMPT}"

    # Convenience so the object itself can be *called* directly
    def __call__(self, question: str, **kwargs: Any) -> str:
        """Run the agent on *question* and return the final answer as text.

        ``run`` may return a primitive, ``None``, or an iterable of steps;
        all cases are normalised to a stripped string.
        """
        steps = self.run(question, **kwargs)
        # If steps is a primitive, just return it
        if isinstance(steps, (int, float, str)):
            return str(steps).strip()
        if steps is None:
            # Previously ``for step in None`` would raise TypeError.
            return ""
        last_step = None
        for step in steps:
            last_step = step
        if last_step is None:
            # Empty iterable: previously this returned the string "None".
            return ""
        # Defensive: handle int/float/str directly
        if isinstance(last_step, (int, float, str)):
            return str(last_step).strip()
        # Prefer an explicit ``answer`` attribute when the step exposes one.
        answer = getattr(last_step, "answer", None)
        if answer is not None:
            return str(answer).strip()
        return str(last_step).strip()
# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
    """Factory expected by app.py: build a ready-to-use :class:`GAIAAgent`.

    The agent always receives every default tool; *extra_tools*, when
    given, are appended after them.
    """
    combined = [*DEFAULT_TOOLS]
    if extra_tools:
        combined += list(extra_tools)
    return GAIAAgent(tools=combined)
__all__ = ["GAIAAgent", "gaia_agent"]