Spaces:
Sleeping
Sleeping
File size: 5,991 Bytes
73bb16b 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b 75a272e 73bb16b bb49a20 52d1305 e4ed116 db0abac e4ed116 db0abac e4ed116 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b 52d1305 73bb16b e4ed116 52d1305 73bb16b 52d1305 73bb16b db0abac 73bb16b db0abac 73bb16b eea77dd 73bb16b 52d1305 73bb16b bb49a20 52d1305 bb49a20 52d1305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""GAIA benchmark agent using OpenAI Agents SDK.
This module exposes:
* ``gaia_agent()`` - factory returning a ready-to-use agent instance.
* ``GAIAAgent`` – a class that wraps ``openai_agents.Agent``.
The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``).
"""
import os
import asyncio # Added for potential direct asyncio.run if needed, and for async def
from typing import Any, Sequence, Callable, Union # Added Callable and Union
from dotenv import load_dotenv
# OpenAI Agents SDK imports
from openai_agents import Agent, Runner
from openai_agents.models.openai_chat_completions import OpenAIChatCompletionsModel
from openai_agents.extensions.models.litellm_model import LitellmModel
# FunctionToolType could be imported if it's a public type, for now using Callable
# from openai_agents import FunctionToolType # Example if such type exists
# Custom Tools from tools.py (now functions)
from tools import (
python_run,
load_spreadsheet,
youtube_transcript,
transcribe_audio,
image_ocr,
duckduckgo_search, # Added the new tool
)
# ---------------------------------------------------------------------------
# Load the extra system-prompt text from added_prompt.txt (located in the same
# directory as this module).  The file is read once at import time, so a
# missing file raises FileNotFoundError immediately rather than on first use.
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()
# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------
load_dotenv()  # Make sure we read credentials from .env before any os.getenv
def _select_model() -> Union[OpenAIChatCompletionsModel, LitellmModel]:
    """Return the model backend selected via the ``MODEL_PROVIDER`` env var.

    ``hf`` (the default) yields a LiteLLM-wrapped HuggingFace model taken
    from ``HF_MODEL``; ``openai`` yields a native chat-completions model
    taken from ``OPENAI_MODEL``.

    Returns:
        A configured ``LitellmModel`` or ``OpenAIChatCompletionsModel``.

    Raises:
        ValueError: if ``MODEL_PROVIDER`` is neither ``hf`` nor ``openai``.
    """
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()
    if provider == "hf":
        hf_model_id = os.getenv("HF_MODEL", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
        # LiteLLM routes HuggingFace-hosted models via the "huggingface/" prefix.
        if not hf_model_id.startswith("huggingface/"):
            hf_model_id = f"huggingface/{hf_model_id}"
        # An empty HF_API_KEY is treated like an unset one: pass None so
        # LiteLLM can fall back to its own env-var lookup (e.g. HUGGINGFACE_API_KEY).
        hf_token = os.getenv("HF_API_KEY") or None
        return LitellmModel(model=hf_model_id, api_key=hf_token)
    if provider == "openai":
        return OpenAIChatCompletionsModel(
            model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
            # Explicitly passed, though the client library normally also
            # reads OPENAI_API_KEY from the environment on its own.
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. "
        "Use 'hf' (default) or 'openai'."
    )
# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------
# Default tool functions (all defined in tools.py) handed to the Agent when
# the caller does not supply an explicit toolset.
DEFAULT_TOOLS: Sequence[Callable] = [
    duckduckgo_search,
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
]
class GAIAAgent:
    """Wrapper around ``openai_agents.Agent`` configured for GAIA tasks.

    The instance is awaitable-callable: ``await agent(question)`` runs the
    underlying agent and returns the final answer as a stripped string.
    """

    def __init__(
        self,
        tools: Sequence[Callable] | None = None
    ):
        """Build the underlying Agent with the selected model and toolset.

        Args:
            tools: Tool functions for the agent.  ``None`` selects a copy of
                ``DEFAULT_TOOLS``; an explicitly passed sequence (including
                an empty one) is used as given.
        """
        self.model = _select_model()
        # Bug fix: the previous ``tools or DEFAULT_TOOLS`` silently replaced
        # an explicitly-passed empty toolset with the defaults, and aliased
        # the shared DEFAULT_TOOLS list.  Copy into a fresh list instead.
        self.tools = list(tools) if tools is not None else list(DEFAULT_TOOLS)
        base_system_prompt = (
            "You are a helpful assistant designed to answer questions and "
            "complete tasks. You have access to a variety of tools to help you."
        )
        # ADDED_PROMPT comes from added_prompt.txt, loaded at module import.
        full_system_prompt = f"{base_system_prompt}\n\n{ADDED_PROMPT}"
        self.agent = Agent(
            model=self.model,
            tools=self.tools,
            instructions=full_system_prompt,
            name="GAIAAgent"
        )

    async def __call__(self, question: str, **kwargs: Any) -> str:
        """Asynchronously answer ``question`` and return the final text.

        Args:
            question: The task/question forwarded to ``Runner.run``.
            kwargs: Accepted for forward compatibility but currently ignored
                (``Runner.run`` is invoked with the agent and question only).

        Returns:
            The agent's final output, stringified and stripped.
        """
        response = await Runner.run(self.agent, question)
        final_answer = response.final_output
        # final_output may be a message-like object exposing ``.content``,
        # or already plain text — normalise either shape to a string.
        if hasattr(final_answer, 'content'):
            final_answer_text = str(final_answer.content)
        else:
            final_answer_text = str(final_answer)
        return final_answer_text.strip()
# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------
def gaia_agent(*, extra_tools: Sequence[Callable] | None = None) -> GAIAAgent:
    """Build a :class:`GAIAAgent` with the default tools plus any extras.

    Args:
        extra_tools: Optional additional tool callables, appended after the
            defaults.

    Returns:
        A ready-to-use ``GAIAAgent`` instance.
    """
    combined = list(DEFAULT_TOOLS) + list(extra_tools or [])
    return GAIAAgent(tools=combined)


__all__ = ["GAIAAgent", "gaia_agent"]
|