"""GAIA benchmark agent using OpenAI Agents SDK.
This module exposes:
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent`` – a class that wraps ``openai_agents.Agent``.
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``).
"""
import os
import asyncio  # used by the __main__ smoke test at the bottom of this file
from typing import Any, Sequence, Callable, Union

from dotenv import load_dotenv

# OpenAI Agents SDK imports. The package installs as ``openai-agents`` but is
# imported as ``agents``; LiteLLM support requires the ``openai-agents[litellm]`` extra.
from agents import Agent, Runner, OpenAIChatCompletionsModel
from agents.extensions.models.litellm_model import LitellmModel
from openai import AsyncOpenAI

# Custom tools from tools.py (plain functions wrapped as SDK tools)
from tools import (
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
)
# ---------------------------------------------------------------------------
# Load the additional system prompt from added_prompt.txt (same directory)
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()
# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv()  # Read credentials from .env before any key lookups below


def _select_model() -> Union[OpenAIChatCompletionsModel, LitellmModel]:
    """Return an OpenAI Agents SDK model instance as configured by env variables."""
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    if provider == "hf":
        hf_model_id = os.getenv("HF_MODEL", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
        # LiteLLM routes Hugging Face models via a ``huggingface/`` prefix.
        if not hf_model_id.startswith("huggingface/"):
            hf_model_id = f"huggingface/{hf_model_id}"
        # LiteLLM can also pick the key up from the environment; passing it
        # explicitly keeps the dependency visible here.
        hf_token = os.getenv("HF_API_KEY")
        return LitellmModel(model=hf_model_id, api_key=hf_token)

    if provider == "openai":
        openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        # OpenAIChatCompletionsModel takes an explicit AsyncOpenAI client;
        # AsyncOpenAI itself falls back to OPENAI_API_KEY when api_key is None.
        openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return OpenAIChatCompletionsModel(model=openai_model_id, openai_client=openai_client)

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. Use 'hf' (default) or 'openai'."
    )
# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
# Each entry is expected to already be an SDK-compatible tool (e.g. a function
# wrapped with the SDK's ``function_tool`` decorator inside tools.py).
DEFAULT_TOOLS: Sequence[Callable] = [
    duckduckgo_search,
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
]
class GAIAAgent:
    def __init__(self, tools: Sequence[Callable] | None = None):
        self.model = _select_model()
        self.tools = list(tools) if tools else list(DEFAULT_TOOLS)
        base_system_prompt = (
            "You are a helpful assistant designed to answer questions and complete "
            "tasks. You have access to a variety of tools to help you."
        )
        full_system_prompt = f"{base_system_prompt}\n\n{ADDED_PROMPT}"
        self.agent = Agent(
            name="GAIAAgent",
            model=self.model,
            tools=self.tools,
            instructions=full_system_prompt,
        )

    async def __call__(self, question: str, **kwargs: Any) -> str:
        """Asynchronously process a question and return the final answer.

        ``kwargs`` are accepted for forward compatibility (e.g. a session id)
        but are currently ignored.
        """
        response = await Runner.run(self.agent, question)
        # ``final_output`` is a plain string for agents without an output_type;
        # handle message-like objects with a ``content`` attribute defensively.
        final_answer = response.final_output
        if hasattr(final_answer, "content"):
            final_answer_text = str(final_answer.content)
        else:
            final_answer_text = str(final_answer)
        return final_answer_text.strip()
# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Callable] | None = None) -> GAIAAgent:
"""
Factory function to create a GAIAAgent instance with default and optional extra tools.
"""
toolset = list(DEFAULT_TOOLS)
if extra_tools:
toolset.extend(extra_tools)
return GAIAAgent(tools=toolset)
__all__ = ["GAIAAgent", "gaia_agent"]
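# Minimal smoke test -- a sketch for local debugging only (assumes a populated
# .env and a present added_prompt.txt; app.py remains the real entry point).
if __name__ == "__main__":
    async def _demo() -> None:
        agent = gaia_agent()
        answer = await agent("What is the capital of France?")
        print(answer)

    asyncio.run(_demo())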