Tesvia's picture
Upload 3 files
6e06cc8 verified
raw
history blame
4.91 kB
"""
GAIA benchmark agent using the OpenAI Agents SDK.
"""
from __future__ import annotations
import asyncio
import os
from typing import Any, Sequence, Callable, List
from datetime import datetime
from agents import RunHooks # for lifecycle hooks
from dotenv import load_dotenv
from agents import Agent, Runner, FunctionTool, Tool
# Import all function tools
from tools import (
python_run,
load_spreadsheet,
youtube_transcript,
transcribe_audio,
image_ocr,
duckduckgo_search,
)
# ---------------------------------------------------------------------------
# Supplementary system prompt: the text of "added_prompt.txt" (located in the
# same directory as this module), with surrounding whitespace removed. It is
# appended to the agent's base instructions.
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(
    os.path.dirname(__file__),
    "added_prompt.txt",
)
with open(ADDED_PROMPT_PATH, mode="r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()

# python-dotenv: populate the process environment (by default from a .env file).
# NOTE(review): this runs after `from tools import ...`; if `tools` reads
# environment variables at import time it will not see values from .env.
load_dotenv()
def _select_model() -> str:
"""Return a model identifier appropriate for the Agents SDK based on environment settings."""
provider = os.getenv("MODEL_PROVIDER", "hf").lower()
if provider == "openai":
model_name = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
return f"openai/{model_name}"
if provider == "hf":
hf_model_id = os.getenv("HF_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
return f"litellm/huggingface/{hf_model_id}"
raise ValueError(
f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
)
# Tools every agent built by this module receives. Callers of `_build_agent`
# may append more; these always come first, in this order.
DEFAULT_TOOLS: List[FunctionTool] = [
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
]
def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
    """Create the Agents SDK ``Agent`` wrapped by ``GAIAAgent``.

    Args:
        extra_tools: Optional additional tools, appended after ``DEFAULT_TOOLS``.

    Returns:
        An ``Agent`` named "GAIA Agent". Its instructions are a fixed preamble
        followed by ``ADDED_PROMPT``, and its model comes from ``_select_model()``.
    """
    preamble = (
        "You are a helpful assistant tasked with answering questions using the available tools.\n\n"
    )
    # Defaults first, then any caller-supplied extras (a falsy value adds nothing).
    combined_tools: List[Tool] = [*DEFAULT_TOOLS, *(extra_tools or ())]
    return Agent(
        name="GAIA Agent",
        instructions=preamble + ADDED_PROMPT,
        tools=combined_tools,
        model=_select_model(),
    )
class LoggingHooks(RunHooks):
    """RunHooks to log question start, model used, and each tool-call step.

    Reads ``question_number`` and ``question_text`` via ``context.context.get``.
    ``GAIAAgent.__call__`` supplies them as a dict. If the run context has no
    ``get`` method (e.g. it is ``None``), the values are logged as ``None``
    instead of raising.
    """

    def __init__(self):
        super().__init__()
        # Running count of tool invocations seen by this hooks instance.
        self.step_counter = 0
        # Tool name -> step numbers of calls that have started but not yet ended
        # (oldest first). This lets on_tool_end report the step that belongs to
        # the finishing call even when several tool calls overlap in time.
        # Same-named overlapping calls are matched first-in-first-out, which is
        # a best-effort pairing.
        self._pending_steps = {}

    @staticmethod
    def _ctx_value(context, key):
        """Return ``context.context.get(key)``, or None if that context has no ``get``."""
        getter = getattr(getattr(context, "context", None), "get", None)
        return getter(key) if callable(getter) else None

    @staticmethod
    def _log(qnum, message):
        """Print *message* prefixed with the current timestamp and question number."""
        ts = datetime.now().isoformat()
        print(f"[{ts}] [Question {qnum}] {message}")

    async def on_agent_start(self, context, agent):
        qnum = self._ctx_value(context, "question_number")
        qtext = self._ctx_value(context, "question_text")
        self._log(qnum, f"Starting agent (model={agent.model}) for question: '{qtext}'")

    async def on_tool_start(self, context, agent, tool):
        self.step_counter += 1
        self._pending_steps.setdefault(tool.name, []).append(self.step_counter)
        qnum = self._ctx_value(context, "question_number")
        self._log(qnum, f"Step {self.step_counter}: Invoking tool '{tool.name}'")

    async def on_tool_end(self, context, agent, tool, result):
        pending = self._pending_steps.get(tool.name)
        # Fall back to the latest step if no matching start was recorded.
        step = pending.pop(0) if pending else self.step_counter
        qnum = self._ctx_value(context, "question_number")
        self._log(qnum, f"Step {step}: Tool '{tool.name}' completed")
class GAIAAgent:
    """Thin synchronous wrapper around an asynchronous Agents SDK agent.

    Calling an instance runs the wrapped agent on a question to completion and
    returns ``str(final_output).strip()``. The call works whether or not an
    asyncio event loop is already running in the calling thread.
    """

    def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
        """Build the wrapped agent, optionally with tools beyond the defaults."""
        self._agent = _build_agent(extra_tools=extra_tools)

    async def _arun(self, question: str, context_data=None, hooks=None) -> str:
        """Run the agent on *question* and return its final output as a stripped string.

        ``context_data`` and ``hooks`` are forwarded to ``Runner.run``
        independently of each other. ``None`` for either one means the SDK
        default. Previously both were dropped unless both were provided.
        """
        result = await Runner.run(
            self._agent,
            question,
            context=context_data,
            hooks=hooks,
        )
        return str(result.final_output).strip()

    @staticmethod
    def _run_coroutine_sync(coro):
        """Drive *coro* to completion from synchronous code and return its result.

        With no event loop running in this thread, ``asyncio.run`` is used
        directly. If a loop is already running (e.g. in Jupyter, or from inside
        an async framework), calling ``run_until_complete`` on it raises
        ``RuntimeError``. In that case the coroutine runs on a fresh loop in a
        worker thread, and the calling thread blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def __call__(self, question: str, question_number: int | None = None, **_kwargs) -> str:
        """Answer *question* synchronously.

        Args:
            question: The question text passed to the agent.
            question_number: If given, per-step logging is enabled via
                ``LoggingHooks``, with this number and the question text as
                the run context.
            **_kwargs: Accepted and ignored.

        Returns:
            The agent's final output, converted to ``str`` and stripped.
        """
        context_data = None
        hooks = None
        if question_number is not None:
            context_data = {
                "question_number": question_number,
                "question_text": question,
            }
            hooks = LoggingHooks()
        return self._run_coroutine_sync(self._arun(question, context_data, hooks))
def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
    """Create and return a new ``GAIAAgent``.

    Args:
        extra_tools: Optional tools forwarded unchanged to the ``GAIAAgent``
            constructor.
    """
    agent = GAIAAgent(extra_tools=extra_tools)
    return agent


# Names exported by `from <module> import *`.
__all__ = ["GAIAAgent", "gaia_agent"]