# Hugging Face file-viewer header (page residue, not part of the module):
# Tesvia's picture | Upload 3 files | 790cac2 verified | raw | history blame | 3.91 kB
"""
GAIA benchmark agent using the OpenAI Agents SDK.
"""
from __future__ import annotations
import asyncio
import os
import time
import datetime
from typing import Any, Sequence, Callable, List, Optional
from dotenv import load_dotenv
from agents import Agent, Runner, FunctionTool, Tool
# Import all function tools
from tools import (
python_run,
load_spreadsheet,
youtube_transcript,
transcribe_audio,
image_ocr,
duckduckgo_search,
)
# ---------------------------------------------------------------------------
# Logging Utility
# ---------------------------------------------------------------------------
def log(msg):
    """Print *msg* to stdout, prefixed with the local time in brackets.

    The prefix has the form ``[YYYY-MM-DD HH:MM:SS]`` and is followed by a
    single space and the message.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    line = "[" + stamp + "] " + f"{msg}"
    print(line)
# ---------------------------------------------------------------------------
# Load the added system prompt
# ---------------------------------------------------------------------------
# The extra system prompt lives in "added_prompt.txt" next to this module.
_MODULE_DIR = os.path.dirname(__file__)
ADDED_PROMPT_PATH = os.path.join(_MODULE_DIR, "added_prompt.txt")

# Read it once at import time; surrounding whitespace is stripped.
with open(ADDED_PROMPT_PATH, mode="r", encoding="utf-8") as f:
    _raw_prompt = f.read()
ADDED_PROMPT = _raw_prompt.strip()

# Called after the prompt is read and before any function in this module
# consults os.getenv.
load_dotenv()
def _select_model() -> str:
    """Pick the model identifier string from environment variables.

    ``MODEL_PROVIDER`` (default ``"hf"``, compared case-insensitively)
    chooses the branch:

    * ``"openai"`` -> ``"openai/<OPENAI_MODEL>"``, with ``OPENAI_MODEL``
      defaulting to ``"gpt-4o-mini"``.
    * ``"hf"`` -> ``"litellm/huggingface/<HF_MODEL>"``, with ``HF_MODEL``
      defaulting to ``"Qwen/Qwen2.5-Coder-32B-Instruct"``.

    Raises:
        ValueError: if ``MODEL_PROVIDER`` is any other value.
    """
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    # Guard clause: reject unknown providers up front.
    if provider not in ("openai", "hf"):
        raise ValueError(
            f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
        )

    if provider == "hf":
        return "litellm/huggingface/" + os.getenv(
            "HF_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct"
        )

    return "openai/" + os.getenv("OPENAI_MODEL", "gpt-4o-mini")
# Tools given to every agent, in this order. `_build_agent` copies the list
# before appending caller-supplied extras, so this object is not mutated there.
DEFAULT_TOOLS: List[FunctionTool] = [
    python_run, load_spreadsheet,
    youtube_transcript, transcribe_audio,
    image_ocr, duckduckgo_search,
]
def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
    """Assemble the Agents SDK `Agent` named "GAIA Agent".

    Args:
        extra_tools: Optional tools appended after a copy of `DEFAULT_TOOLS`.
            ``None`` or an empty sequence adds nothing.

    Returns:
        An `Agent` whose instructions are a fixed lead-in sentence followed by
        `ADDED_PROMPT`, and whose model is whatever `_select_model()` returns.
    """
    # Always a fresh list, so DEFAULT_TOOLS itself is never shared or mutated.
    toolbox: Sequence[Tool] = [*DEFAULT_TOOLS, *(extra_tools or ())]

    lead_in = "You are a helpful assistant tasked with answering questions using the available tools.\n\n"

    return Agent(
        name="GAIA Agent",
        instructions=lead_in + ADDED_PROMPT,
        tools=toolbox,
        model=_select_model(),
    )
class GAIAAgent:
    """Thin synchronous wrapper around an asynchronous Agents SDK agent.

    Instances are callable: ``agent(question)`` runs the wrapped agent to
    completion and returns its final output as a stripped string.
    """

    def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
        """Build the wrapped agent.

        Args:
            extra_tools: Optional tools forwarded to `_build_agent`.
        """
        self._agent = _build_agent(extra_tools=extra_tools)
        # Store the model id for logging
        self.model_id = _select_model()

    async def _arun(self, question: str, q_index: Optional[int] = None) -> str:
        """Run the agent on *question* and return its final output as text.

        Args:
            question: The prompt handed to `Runner.run`.
            q_index: Optional zero-based question index; it is logged one-based,
                or as "?" when omitted.

        Returns:
            ``str(result.final_output)`` with surrounding whitespace stripped.

        Raises:
            Exception: whatever `Runner.run` raises is logged and re-raised.
        """
        q_num = q_index + 1 if q_index is not None else "?"
        log(f"Answering question {q_num}:")
        log(f" Question: {question!r}")
        log(f" Model: {self.model_id}")

        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-run.
        started = time.perf_counter()
        try:
            result = await Runner.run(self._agent, question)
        except Exception as e:
            log(f" Error during answer: {e}")
            raise
        duration = time.perf_counter() - started
        log(f" Total duration: {duration:.2f} seconds.")
        return str(result.final_output).strip()

    def __call__(self, question: str, q_index: Optional[int] = None, **kwargs: Any) -> str:
        """Answer *question* synchronously.

        Args:
            question: The prompt to answer.
            q_index: Optional zero-based question index, used only for logging.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The agent's final output as a stripped string.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: the simple path.
            return asyncio.run(self._arun(question, q_index=q_index))

        # A loop is already running in this thread. Calling
        # loop.run_until_complete() on it always raises
        # "This event loop is already running", so run the coroutine on a
        # worker thread with its own loop. This call still blocks the
        # caller (and therefore the outer loop) until the answer is ready.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(
                lambda: asyncio.run(self._arun(question, q_index=q_index))
            )
            return future.result()
def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
    """Create and return a new `GAIAAgent`.

    Args:
        extra_tools: Optional tools forwarded unchanged to the `GAIAAgent`
            constructor.
    """
    agent = GAIAAgent(extra_tools=extra_tools)
    return agent
# Names exported by `from <module> import *`: the wrapper class and its factory.
__all__ = [
    "GAIAAgent",
    "gaia_agent",
]