"""GAIA benchmark agent using OpenAI Agents SDK. | |
This module exposes: | |
* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance. | |
* ``GAIAAgent`` – a class that wraps ``openai_agents.Agent``. | |
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER`` | |
environment variable (``hf`` or ``openai``). | |
""" | |
import os
import asyncio  # used by the __main__ smoke test at the bottom of this file
from typing import Any, Callable, Sequence

from dotenv import load_dotenv

# OpenAI Agents SDK imports (the package installs as ``openai-agents`` but is
# imported as ``agents``)
from agents import Agent, OpenAIChatCompletionsModel, Runner
from agents.extensions.models.litellm_model import LitellmModel
from openai import AsyncOpenAI  # client object required by OpenAIChatCompletionsModel

# Custom tools from tools.py (plain functions exposed as SDK tools)
from tools import (
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search,
)
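
# The tool imports above are assumed to be plain Python functions wrapped with
# the SDK's ``function_tool`` decorator inside tools.py, roughly like this
# hypothetical sketch:
#
#     from agents import function_tool
#
#     @function_tool
#     def duckduckgo_search(query: str) -> str:
#         """Search DuckDuckGo and return the top results as plain text."""
#         ...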

# ---------------------------------------------------------------------------
# Load the additional system prompt from added_prompt.txt (same directory)
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()

# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv()  # Read credentials from .env before any model is constructed


def _select_model() -> OpenAIChatCompletionsModel | LitellmModel:
    """Return a model instance for the backend configured via env variables."""
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    if provider == "hf":
        hf_model_id = os.getenv("HF_MODEL", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
        # LiteLLM routes Hugging Face models via a "huggingface/" prefix.
        if not hf_model_id.startswith("huggingface/"):
            hf_model_id = f"huggingface/{hf_model_id}"
        # LiteLLM can also read the key from HUGGINGFACE_API_KEY; passing it
        # explicitly keeps the configuration in one place. os.getenv returns
        # None when the variable is unset, which LitellmModel accepts.
        return LitellmModel(model=hf_model_id, api_key=os.getenv("HF_API_KEY"))

    if provider == "openai":
        openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        # OpenAIChatCompletionsModel takes an AsyncOpenAI client rather than a
        # raw key; the client falls back to OPENAI_API_KEY when api_key is None.
        client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return OpenAIChatCompletionsModel(model=openai_model_id, openai_client=client)

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. Use 'hf' (default) or 'openai'."
    )
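
# Example .env for the two supported backends (a sketch; the keys are
# placeholders and the model IDs are just the defaults used above):
#
#     MODEL_PROVIDER=openai
#     OPENAI_MODEL=gpt-3.5-turbo
#     OPENAI_API_KEY=sk-placeholder
#
#     # -- or, for the Hugging Face / LiteLLM backend --
#     MODEL_PROVIDER=hf
#     HF_MODEL=NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
#     HF_API_KEY=hf_placeholder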

# ---------------------------------------------------------------------------
# Core agent implementation
# ---------------------------------------------------------------------------
DEFAULT_TOOLS: Sequence[Callable] = [
    duckduckgo_search,
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
]


class GAIAAgent:
    """Thin wrapper around ``agents.Agent`` preconfigured for GAIA tasks."""

    def __init__(self, tools: Sequence[Callable] | None = None):
        self.model = _select_model()
        self.tools = list(tools) if tools else list(DEFAULT_TOOLS)
        base_system_prompt = (
            "You are a helpful assistant designed to answer questions and "
            "complete tasks. You have access to a variety of tools to help you."
        )
        full_system_prompt = f"{base_system_prompt}\n\n{ADDED_PROMPT}"
        self.agent = Agent(
            name="GAIAAgent",
            instructions=full_system_prompt,
            model=self.model,
            tools=self.tools,
        )

    async def __call__(self, question: str, **kwargs: Any) -> str:
        """Asynchronously run the agent on ``question`` and return the answer.

        ``kwargs`` is accepted for forward compatibility (e.g. a session id)
        but is currently ignored.
        """
        response = await Runner.run(self.agent, question)
        # ``final_output`` is typically a plain string; fall back to a
        # ``content`` attribute in case a message-like object is returned.
        final_answer = response.final_output
        if hasattr(final_answer, "content"):
            final_answer_text = str(final_answer.content)
        else:
            final_answer_text = str(final_answer)
        return final_answer_text.strip()
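
# Note: callers without a running event loop can use the SDK's blocking entry
# point instead of awaiting ``__call__``. A minimal sketch:
#
#     agent = GAIAAgent()
#     result = Runner.run_sync(agent.agent, "What is 2 + 2?")
#     print(str(result.final_output).strip())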

# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Callable] | None = None) -> GAIAAgent:
    """Create a GAIAAgent with the default toolset plus any ``extra_tools``."""
    toolset = list(DEFAULT_TOOLS)
    if extra_tools:
        toolset.extend(extra_tools)
    return GAIAAgent(tools=toolset)


__all__ = ["GAIAAgent", "gaia_agent"]
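

# Minimal smoke test (a sketch; requires valid credentials for the configured
# backend in the environment or in .env):
if __name__ == "__main__":
    async def _demo() -> None:
        agent = gaia_agent()
        answer = await agent("What is the capital of France?")
        print(answer)

    asyncio.run(_demo())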