File size: 5,991 Bytes
73bb16b
52d1305
 
 
 
73bb16b
52d1305
 
73bb16b
52d1305
 
 
73bb16b
 
52d1305
 
 
73bb16b
 
 
 
 
 
52d1305
73bb16b
75a272e
73bb16b
 
 
 
 
 
bb49a20
52d1305
e4ed116
db0abac
e4ed116
db0abac
 
 
e4ed116
52d1305
 
 
 
73bb16b
52d1305
73bb16b
 
52d1305
 
73bb16b
 
 
52d1305
 
73bb16b
 
 
 
 
 
 
 
 
 
52d1305
 
 
73bb16b
 
 
 
e4ed116
52d1305
 
 
 
 
 
 
 
 
 
73bb16b
 
 
 
 
 
 
52d1305
 
73bb16b
db0abac
 
73bb16b
db0abac
73bb16b
 
 
 
 
 
 
 
 
 
 
eea77dd
73bb16b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52d1305
 
 
 
 
73bb16b
 
 
 
bb49a20
52d1305
bb49a20
 
52d1305
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""GAIA benchmark agent using OpenAI Agents SDK.

This module exposes:

* ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
* ``GAIAAgent``  – a class that wraps ``openai_agents.Agent``.

The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
environment variable (``hf`` or ``openai``).
"""

import os
import asyncio # Added for potential direct asyncio.run if needed, and for async def
from typing import Any, Sequence, Callable, Union # Added Callable and Union

from dotenv import load_dotenv

# OpenAI Agents SDK imports
from openai_agents import Agent, Runner
from openai_agents.models.openai_chat_completions import OpenAIChatCompletionsModel
from openai_agents.extensions.models.litellm_model import LitellmModel
# FunctionToolType could be imported if it's a public type, for now using Callable
# from openai_agents import FunctionToolType # Example if such type exists

# Custom Tools from tools.py (now functions)
from tools import (
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
    duckduckgo_search, # Added the new tool
)

# ---------------------------------------------------------------------------
# Load the added system prompt from added_prompt.txt (located in the same directory)
# ---------------------------------------------------------------------------
ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
    ADDED_PROMPT = f.read().strip()

# ---------------------------------------------------------------------------
# Model selection helper
# ---------------------------------------------------------------------------

load_dotenv()  # Make sure we read credentials from .env

def _select_model() -> Union[OpenAIChatCompletionsModel, LitellmModel]:
    """Return the LLM backend model instance configured by environment variables.

    The backend is chosen via ``MODEL_PROVIDER``:

    * ``"hf"`` (default) – a HuggingFace model routed through LiteLLM.
      ``HF_MODEL`` selects the model id; ``HF_API_KEY`` supplies the token
      (LiteLLM may also read provider keys from its own env variables).
    * ``"openai"`` – an OpenAI chat-completions model. ``OPENAI_MODEL``
      selects the model id; ``OPENAI_API_KEY`` supplies the API key.

    Returns:
        A ready-to-use model object for the OpenAI Agents SDK.

    Raises:
        ValueError: If ``MODEL_PROVIDER`` is set to an unsupported value.
    """
    provider = os.getenv("MODEL_PROVIDER", "hf").lower()

    if provider == "hf":
        hf_model_id = os.getenv("HF_MODEL", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
        # LiteLLM addresses HuggingFace-hosted models via a "huggingface/" prefix.
        if not hf_model_id.startswith("huggingface/"):
            hf_model_id = f"huggingface/{hf_model_id}"
        # os.getenv already returns None when HF_API_KEY is unset, so the
        # value can be passed straight through without a truthiness check.
        return LitellmModel(model=hf_model_id, api_key=os.getenv("HF_API_KEY"))

    if provider == "openai":
        return OpenAIChatCompletionsModel(
            model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
            # Explicit for clarity, though the client also falls back to the
            # OPENAI_API_KEY environment variable on its own.
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    raise ValueError(
        f"Unsupported MODEL_PROVIDER: {provider!r}. "
        "Use 'hf' (default) or 'openai'."
    )

# ---------------------------------------------------------------------------
# Core Agent implementation
# ---------------------------------------------------------------------------

DEFAULT_TOOLS: Sequence[Callable] = [
    duckduckgo_search,
    python_run,
    load_spreadsheet,
    youtube_transcript,
    transcribe_audio,
    image_ocr,
]

class GAIAAgent:
    """Thin wrapper around ``openai_agents.Agent`` for the GAIA benchmark.

    Builds an agent with the configured model backend, the GAIA tool
    functions, and a system prompt composed of a base instruction plus the
    contents of ``added_prompt.txt``.
    """

    def __init__(
        self,
        tools: Sequence[Callable] | None = None
    ):
        """Create the underlying SDK agent.

        Args:
            tools: Tool functions to expose to the agent. When ``None``,
                ``DEFAULT_TOOLS`` is used. An explicitly passed empty
                sequence is honoured (agent gets no tools).
        """
        self.model = _select_model()
        # NOTE: `tools or DEFAULT_TOOLS` would silently replace an
        # intentionally empty tool list with the defaults; test for None
        # explicitly, and copy to avoid aliasing the shared DEFAULT_TOOLS.
        self.tools = list(DEFAULT_TOOLS) if tools is None else list(tools)

        base_system_prompt = "You are a helpful assistant designed to answer questions and complete tasks. You have access to a variety of tools to help you."
        full_system_prompt = f"{base_system_prompt}\n\n{ADDED_PROMPT}"

        self.agent = Agent(
            model=self.model,
            tools=self.tools,
            instructions=full_system_prompt,
            name="GAIAAgent"
        )

    async def __call__(self, question: str, **kwargs: Any) -> str:
        """Run the agent on *question* and return its final answer as text.

        Args:
            question: The task or question to answer.
            **kwargs: Accepted for forward-compatibility; currently ignored
                (``Runner.run`` is invoked with the agent and question only).

        Returns:
            The agent's final output, stringified and stripped of
            surrounding whitespace.
        """
        response = await Runner.run(self.agent, question)

        final_answer = response.final_output
        # The SDK may hand back a message-like object; prefer its ``content``
        # attribute when present, otherwise stringify the output directly.
        if hasattr(final_answer, 'content'):
            final_answer_text = str(final_answer.content)
        else:
            final_answer_text = str(final_answer)

        return final_answer_text.strip()

# ---------------------------------------------------------------------------
# Factory helpers expected by app.py
# ---------------------------------------------------------------------------

def gaia_agent(*, extra_tools: Sequence[Callable] | None = None) -> GAIAAgent:
    """Build a :class:`GAIAAgent` equipped with the default toolset.

    Args:
        extra_tools: Additional tool functions appended after the defaults.

    Returns:
        A ready-to-use ``GAIAAgent`` instance.
    """
    toolset = [*DEFAULT_TOOLS, *(extra_tools or [])]
    return GAIAAgent(tools=toolset)

__all__ = ["GAIAAgent", "gaia_agent"]