# agent.py
import contextlib
import io
import json
import logging
import os

import pandas as pd

from models import GoogleModelID
from settings import Settings
from smolagents import CodeAgent, DuckDuckGoSearchTool, FinalAnswerTool, OpenAIServerModel, VisitWebpageTool
from smolagents.local_python_executor import BASE_PYTHON_TOOLS
from tools import (
    AudioUnderstandingTool,
    BestChessMoveTool,
    ChessBoardFENTool,
    ConvertChessMoveTool,
    ExcelParsingTool,
    GetTaskFileTool,
    VideoUnderstandingTool,
)

logger = logging.getLogger(__name__)

# Extend BASE_PYTHON_TOOLS so that code executed by the agents' local Python
# interpreter has access to these builtins and modules.
BASE_PYTHON_TOOLS["open"] = open
BASE_PYTHON_TOOLS["os"] = os
BASE_PYTHON_TOOLS["io"] = io
BASE_PYTHON_TOOLS["contextlib"] = contextlib
BASE_PYTHON_TOOLS["exec"] = exec  # Note: exec is powerful; use with caution in production
BASE_PYTHON_TOOLS["json"] = json  # for parsing JSON in agent-generated code
BASE_PYTHON_TOOLS["pd"] = pd  # for pandas operations in agent-generated code
class ResearchAgent:
    """Sub-agent for web research, video analysis, and audio understanding."""

    def __init__(self, settings: Settings):
self.agent = CodeAgent(
name="researcher",
description="A specialized agent for web research, video analysis, and audio understanding. Give it your query as an argument. Use 'duckduckgo_search_tool' for web searches, 'visit_webpage_tool' to read web page content, 'video_understanding_tool' for YouTube videos, and 'audio_understanding_tool' for local audio files.",
add_base_tools=False,
tools=[
                DuckDuckGoSearchTool(),
                VisitWebpageTool(max_output_length=100000),
                # Gemini 2.0 Flash is kept for these two multimodal tools.
                VideoUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
                AudioUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
],
additional_authorized_imports=[
"unicodedata", "stat", "datetime", "random", "pandas", "itertools",
"math", "statistics", "queue", "time", "collections", "re", "os",
"json", "io", "urllib.parse"
],
max_steps=15,
verbosity_level=2,
            model=OpenAIServerModel(
                # Gemini accessed through its OpenAI-compatible endpoint.
                model_id=GoogleModelID.GEMINI_2_5_FLASH_PREVIEW,
                api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=settings.gemini_api_key.get_secret_value(),
                temperature=0.1,
                timeout=180,
            )
)
logger.info("ResearchAgent initialized.")


class ChessAgent:
    """Sub-agent that determines chess moves from a board image and the player to move."""

    def __init__(self, settings: Settings):
self.agent = CodeAgent(
name="chess_player",
description="Makes a chess move. Give it a query including board image filepath and player turn (black or white).",
add_base_tools=False,
tools=[
ChessBoardFENTool(),
BestChessMoveTool(settings),
                ConvertChessMoveTool(settings, GoogleModelID.GEMINI_2_5_FLASH_PREVIEW),
],
additional_authorized_imports=[
"unicodedata", "stat", "datetime", "random", "pandas", "itertools",
"math", "statistics", "queue", "time", "collections", "re", "os",
"json", "urllib.parse"
],
max_steps=10,
verbosity_level=2,
            model=OpenAIServerModel(
                # Gemini accessed through its OpenAI-compatible endpoint.
                model_id=GoogleModelID.GEMINI_2_5_FLASH_PREVIEW,
                api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=settings.gemini_api_key.get_secret_value(),
                temperature=0.0,
                timeout=180,
            )
)
logger.info("ChessAgent initialized.")


class ManagerAgent:
"""
The main orchestrating agent that routes questions to specialized sub-agents
or handles them directly with its own tools.
"""
def __init__(self, settings: Settings):
self.settings = settings
self.researcher = ResearchAgent(settings).agent
self.chess_player = ChessAgent(settings).agent
# Main manager agent
self.agent = CodeAgent(
name="manager",
description=(
"You are a highly capable AI assistant designed to solve complex GAIA benchmark questions. "
"Your primary role is to route tasks to the most appropriate specialized agent: "
"'researcher' for general knowledge, web browsing, video, and audio understanding tasks, "
"or 'chess_player' for chess-related tasks. "
"If a task involves downloading a file, use 'get_task_file_tool' first. "
"If you have the final answer, use 'final_answer_tool'.\n\n"
"**Available Tools:**\n"
"- `get_task_file_tool(task_id: str, file_name: str)`: Downloads a file associated with a task.\n"
"- `final_answer_tool(answer: str)`: Use this when you have the exact final answer.\n\n"
"**Managed Agents:**\n"
"- `researcher(query: str)`: Use for questions requiring web search, video analysis, or audio analysis.\n"
"- `chess_player(query: str)`: Use for questions related to chess positions or moves.\n\n"
"Think step-by-step. If a task involves a file, use `get_task_file_tool` first to download it, then pass the file path to the appropriate sub-agent or tool."
),
tools=[
GetTaskFileTool(settings),
FinalAnswerTool(),
                ExcelParsingTool(settings),  # on the manager because it operates on downloaded file paths
],
            model=OpenAIServerModel(
                # Gemini accessed through its OpenAI-compatible endpoint.
                model_id=GoogleModelID.GEMINI_2_5_FLASH_PREVIEW,
                api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=settings.gemini_api_key.get_secret_value(),
                temperature=0.0,
                timeout=180,
            ),
managed_agents=[self.researcher, self.chess_player],
verbosity_level=2,
max_steps=20
)
logger.info("ManagerAgent initialized.")
def __call__(self, question_data: dict) -> str:
task_id = question_data.get("task_id", "N/A")
question_text = question_data.get("question", "")
file_name = question_data.get("file_name", "")
enriched_question = (
f"{question_text} "
f"task_id: {task_id}. "
f"Your final answer should be a number or as few words as possible. "
f"Only use abbreviations when the question calls for abbreviations. "
f"If needed, use a comma separated list of values; the comma is always followed by a space. "
f"Critically review your answer before making it the final answer. "
f"Double check the answer to make sure it meets all format requirements stated in the question. "
)
if file_name:
            enriched_question = f"{enriched_question} file_name: {file_name} (use get_task_file_tool to fetch this file, then pass its path to the relevant tool/agent, or to excel_parsing_tool if it is an Excel file)."
logger.info(f"ManagerAgent received question (first 100 chars): {enriched_question[:100]}...")
try:
final_answer = self.agent.run(enriched_question)
logger.info(f"ManagerAgent returning final answer: {final_answer}")
            return str(final_answer)
except Exception as e:
logger.error(f"Error running ManagerAgent on task {task_id}: {e}")
return f"AGENT ERROR: {e}"