""" | |
GAIA benchmark agent using the OpenAI Agents SDK. | |
""" | |
from __future__ import annotations | |
import asyncio | |
import os | |
from typing import Any, Sequence, Callable, List | |
from datetime import datetime | |
from agents import RunHooks # for lifecycle hooks | |
from dotenv import load_dotenv | |
from agents import Agent, Runner, FunctionTool, Tool | |
# Import all function tools | |
from tools import ( | |
python_run, | |
load_spreadsheet, | |
youtube_transcript, | |
transcribe_audio, | |
image_ocr, | |
duckduckgo_search, | |
) | |
# ---------------------------------------------------------------------------
# Load the added system prompt
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()

load_dotenv()

def _select_model() -> str:
    """Return a model identifier appropriate for the Agents SDK based on environment settings."""
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()
    if provider == "openai":
        model_name = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
        return f"openai/{model_name}"
    if provider == "hf":
        hf_model_id = os.getenv("HF_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
        return f"litellm/huggingface/{hf_model_id}"
    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
    )
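
# Illustrative environment configuration for _select_model (example values, not
# part of this repo; only MODEL_PROVIDER, OPENAI_MODEL, and HF_MODEL are read):
#   MODEL_PROVIDER=openai  OPENAI_MODEL=gpt-4o-mini
#       -> "openai/gpt-4o-mini"
#   MODEL_PROVIDER=hf      HF_MODEL=Qwen/Qwen2.5-Coder-32B-Instruct
#       -> "litellm/huggingface/Qwen/Qwen2.5-Coder-32B-Instruct"
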
DEFAULT_TOOLS: List[FunctionTool] = [
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
]

def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
    """Construct the underlying Agents SDK `Agent` instance."""
    instructions = (
        "You are a helpful assistant tasked with answering questions using the available tools.\n\n"
        + ADDED_PROMPT
    )
    tools: Sequence[Tool] = list(DEFAULT_TOOLS)
    if extra_tools:
        tools = list(tools) + list(extra_tools)
    return Agent(
        name="GAIA Agent",
        instructions=instructions,
        tools=tools,
        model=_select_model(),
    )

class LoggingHooks(RunHooks):
    """RunHooks that log question start, the model used, and each tool-call step."""

    def __init__(self):
        self.step_counter = 0

    async def on_agent_start(self, context, agent):
        qnum = context.context.get("question_number")
        qtext = context.context.get("question_text")
        model = agent.model
        ts = datetime.now().isoformat()
        print(f"[{ts}] [Question {qnum}] Starting agent (model={model}) for question: '{qtext}'")

    async def on_tool_start(self, context, agent, tool):
        self.step_counter += 1
        qnum = context.context.get("question_number")
        ts = datetime.now().isoformat()
        print(f"[{ts}] [Question {qnum}] Step {self.step_counter}: Invoking tool '{tool.name}'")

    async def on_tool_end(self, context, agent, tool, result):
        qnum = context.context.get("question_number")
        ts = datetime.now().isoformat()
        print(f"[{ts}] [Question {qnum}] Step {self.step_counter}: Tool '{tool.name}' completed")
class GAIAAgent:
    """Thin synchronous wrapper around an asynchronous Agents SDK agent."""

    def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
        self._agent = _build_agent(extra_tools=extra_tools)

    async def _arun(self, question: str, context_data=None, hooks=None) -> str:
        # Pass context and hooks to Runner.run only when provided
        if context_data is not None and hooks is not None:
            result = await Runner.run(
                self._agent,
                question,
                context=context_data,
                hooks=hooks,
            )
        else:
            result = await Runner.run(self._agent, question)
        return str(result.final_output).strip()
    def __call__(self, question: str, question_number: int | None = None, **_kwargs) -> str:
        # Prepare logging context if a question_number is given
        context_data = None
        hooks = None
        if question_number is not None:
            context_data = {
                "question_number": question_number,
                "question_text": question,
            }
            hooks = LoggingHooks()
        coro = self._arun(question, context_data, hooks)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: asyncio.run creates one, runs the coroutine, and closes it.
            return asyncio.run(coro)
        # A loop is already running (e.g. inside a notebook or async server), so
        # run_until_complete would raise; execute the coroutine in a worker
        # thread with its own event loop instead.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
    """Factory returning a ready-to-use GAIAAgent instance."""
    return GAIAAgent(extra_tools=extra_tools)


__all__ = ["GAIAAgent", "gaia_agent"]