hassenhamdi's picture
Create tools.py
c38cd66 verified
raw
history blame
18.5 kB
# tools.py
import os
import json
import logging
import re
import requests
import shutil
import urllib.parse
import pandas as pd # For ExcelParsingTool
from board_to_fen.predict import get_fen_from_image_path # For ChessBoardFENTool
from google import genai
from google.genai import types
# from litellm import completion # Removed - no longer used for ConvertChessMoveTool
from smolagents import Tool
from settings import Settings
from models import GoogleModelID # Import GoogleModelID
logger = logging.getLogger(__name__)
class BaseCustomTool(Tool):
    """Shared parent for this module's settings-aware smolagents tools.

    Keeps the application ``Settings`` object on ``self.settings`` so that
    subclasses can read configuration (API base URLs, API keys) from it.
    """

    def __init__(self, settings: Settings) -> None:
        # Initialise the smolagents Tool machinery first, then attach config.
        super().__init__()
        self.settings: Settings = settings
class GetTaskFileTool(BaseCustomTool):
    """Download a task's attachment into the local ``downloads`` directory.

    If the scoring API request fails, the tool falls back to copying
    ``files/<file_name>``. Failures are reported as ``"Error: ..."`` strings
    rather than raised, so the calling agent can read them.
    """

    name = "get_task_file_tool"
    description = """If a file_name is provided in the task, use this tool to download the file associated with a given task_id. Returns the absolute file path to the downloaded file. This path can then be used by other tools like AudioUnderstandingTool or ExcelParsingTool. Example: get_task_file_tool(task_id="1234", file_name="example.mp3")"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID (required)"},
        "file_name": {"type": "string", "description": "File name (required)"},
    }
    output_type = "string"

    def __init__(self, settings: Settings):
        super().__init__(settings)
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Fetch the file for ``task_id`` and return its absolute local path.

        Args:
            task_id: Identifier of the task whose attachment is requested.
            file_name: Name to save the file under. Only its final path
                component is used, so the file always lands inside
                ``self.directory_name``.

        Returns:
            The absolute path of the saved file, or an ``"Error: ..."`` message.
        """
        # file_name is supplied by the task/agent: drop any directory parts so
        # values like "../../x" or "/abs/x" cannot write outside the download dir.
        safe_name = os.path.basename(str(file_name or "").strip())
        if safe_name in ("", ".", ".."):
            logger.error(f"Invalid file_name '{file_name}' for task_id {task_id}.")
            return f"Error: Invalid file_name '{file_name}' for task_id {task_id}."
        try:
            # Use the scoring API base URL for file downloads; quote the id so it
            # is always exactly one URL path segment.
            task_segment = urllib.parse.quote(str(task_id), safe="")
            response = requests.get(
                f"{self.settings.scoring_api_base_url}/files/{task_segment}", timeout=15
            )
            response.raise_for_status()
            os.makedirs(self.directory_name, exist_ok=True)
            file_path = os.path.join(self.directory_name, safe_name)
            with open(file_path, 'wb') as file:
                file.write(response.content)
            absolute_file_path = os.path.abspath(file_path)
            logger.info(f"Successfully downloaded file '{safe_name}' for task_id {task_id} to {absolute_file_path}")
            return absolute_file_path
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading file for task_id {task_id} from API: {e}")
            return self._copy_local_fallback(task_id, safe_name, e)
        except Exception as e:
            logger.error(f"An unexpected error occurred in GetTaskFileTool: {e}")
            return f"Error: An unexpected error occurred while getting file '{safe_name}'. {e}"

    def _copy_local_fallback(self, task_id: str, file_name: str, download_error: Exception) -> str:
        """Copy ``files/<file_name>`` into the download dir; return its path or an error string."""
        local_file_path = os.path.join("files", file_name)
        if not os.path.exists(local_file_path):
            logger.error(f"Local fallback file '{local_file_path}' not found.")
            return f"Error: Could not download or find file '{file_name}' for task_id {task_id}. {download_error}"
        destination_path = os.path.join(self.directory_name, file_name)
        try:
            os.makedirs(self.directory_name, exist_ok=True)
            shutil.copy2(local_file_path, destination_path)
        except OSError as copy_error:
            # Report copy failures as a tool error instead of letting them
            # propagate out of forward() from inside its except-handler.
            logger.error(f"Failed to copy local fallback file '{local_file_path}': {copy_error}")
            return f"Error: Could not copy local fallback file '{local_file_path}' for task_id {task_id}. {copy_error}"
        absolute_local_file_path = os.path.abspath(destination_path)
        logger.info(f"Copied local fallback file '{file_name}' to {absolute_local_file_path}")
        return absolute_local_file_path

    def create_dir(self):
        """Create the download directory if it does not already exist."""
        if os.path.isdir(self.directory_name):
            logger.debug(f"Directory '{self.directory_name}' already exists.")
            return
        # exist_ok avoids a race if the directory appears between check and create.
        os.makedirs(self.directory_name, exist_ok=True)
        logger.info(f"Directory '{self.directory_name}' created successfully.")
class VideoUnderstandingTool(BaseCustomTool):
    """Answer a prompt about a YouTube video using a Gemini model."""

    name = "video_understanding_tool"
    description = "Analyzes a YouTube video given its URL and a specific prompt/question about its content. Returns a text description or answer from the video. Use this for tasks involving video content. Example: video_understanding_tool(youtube_url=\"https://www.youtube.com/watch?v=VIDEO_ID\", prompt=\"What is the main topic of this video?\")"
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video content"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        self.model = model
        # `google.genai` (the google-genai SDK imported by this file) is
        # client-based: it has no module-level `configure` / `GenerativeModel`
        # (those belong to the legacy `google.generativeai` package).
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"VideoUnderstandingTool initialized with model: {self.model}")

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Send the video URL plus ``prompt`` to Gemini and return the text reply.

        Returns:
            The model's answer, or ``"Error understanding video: ..."`` on failure.
        """
        try:
            # Gemini accepts a YouTube URL directly as file_data.
            video_part = types.Part(
                file_data=types.FileData(
                    file_uri=youtube_url,
                    mime_type="video/mp4",  # Assuming common video type, adjust if needed
                )
            )
            response = self.client.models.generate_content(
                # GoogleModelID is defined elsewhere; if it is an Enum, pass its
                # string value to the SDK (a plain str passes through unchanged).
                model=getattr(self.model, "value", self.model),
                contents=types.Content(role="user", parts=[video_part, types.Part(text=prompt)]),
            )
            text = response.text
            if not text:
                raise ValueError("model returned an empty response")
            return text
        except Exception as e:
            logger.error(f"Error understanding video from URL '{youtube_url}': {e}")
            return f"Error understanding video: {e}"
class AudioUnderstandingTool(BaseCustomTool):
    """Answer a prompt about a local audio file using a Gemini model."""

    name = "audio_understanding_tool"
    description = "Analyzes a local audio file given its file path and a specific prompt/question about its content. Returns a text description or answer from the audio. Use this for tasks involving audio files. You must first download the audio file using 'get_task_file_tool'. Example: audio_understanding_tool(file_path=\"/tmp/audio.mp3\", prompt=\"What are the key ingredients mentioned?\")"
    inputs = {
        "file_path": {"type": "string", "description": "The local file path of the audio file (e.g., from get_task_file_tool)."},
        "prompt": {"type": "string", "description": "A question or request regarding the audio content."},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        self.model = model
        # The google-genai SDK is client-based; `genai.configure`,
        # `genai.upload_file` and `genai.GenerativeModel` only exist in the
        # legacy `google.generativeai` package.
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"AudioUnderstandingTool initialized with model: {self.model}")

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload ``file_path`` to the Gemini Files API and ask ``prompt`` about it.

        The uploaded file is deleted afterwards on a best-effort basis.

        Returns:
            The model's answer, or ``"Error understanding audio: ..."`` on failure.
        """
        uploaded_file = None
        try:
            uploaded_file = self.client.files.upload(file=file_path)
            logger.info(f"Uploaded audio file: {uploaded_file.uri}")
            response = self.client.models.generate_content(
                # GoogleModelID is defined elsewhere; if it is an Enum, pass its
                # string value to the SDK (a plain str passes through unchanged).
                model=getattr(self.model, "value", self.model),
                contents=[uploaded_file, prompt],
            )
            text = response.text
            if not text:
                raise ValueError("model returned an empty response")
            return text
        except Exception as e:
            logger.error(f"Error understanding audio from file '{file_path}': {e}")
            return f"Error understanding audio: {e}"
        finally:
            # Cleanup must never mask the real result or error.
            if uploaded_file is not None:
                try:
                    self.client.files.delete(name=uploaded_file.name)
                except Exception as cleanup_error:
                    logger.warning(f"Could not delete uploaded audio file '{uploaded_file.name}': {cleanup_error}")
class ExcelParsingTool(BaseCustomTool):
    """Expose the first worksheet of a local Excel workbook as CSV text."""

    name = "excel_parsing_tool"
    description = "Parses an Excel (.xlsx) file given its local file path. It reads the first sheet and returns its content as a CSV formatted string. Use this for tasks involving Excel data. You must first download the Excel file using 'get_task_file_tool'. Example: excel_parsing_tool(file_path=\"/tmp/sales_data.xlsx\")"
    inputs = {"file_path": {"type": "string", "description": "The local path to the Excel file (e.g., from get_task_file_tool)."}}
    output_type = "string"

    def __init__(self, settings: Settings):
        super().__init__(settings)
        logger.info("ExcelParsingTool initialized.")

    def forward(self, file_path: str) -> str:
        """Return the workbook's first sheet as CSV (no index column).

        Any failure, including a missing file, is logged and returned as an
        ``"Error parsing Excel file: ..."`` string instead of being raised.
        """
        try:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Excel file not found at: {file_path}")
            # pandas reads sheet 0 by default.
            csv_text = pd.read_excel(file_path).to_csv(index=False)
            logger.info(f"Successfully parsed Excel file: {file_path}")
            return csv_text
        except Exception as exc:
            logger.error(f"Error parsing Excel file {file_path}: {exc}")
            return f"Error parsing Excel file: {exc}"
class ConvertChessMoveTool(BaseCustomTool):
    """Ask a Gemini model to convert a coordinate-notation move to algebraic notation."""

    name = "convert_chess_move_tool"
    description = "Converts a chess move from coordinate notation (e.g., 'e2e4') to standard algebraic notation. Requires the current piece placement as plain text. Example: convert_chess_move_tool(piece_placement=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR\", move=\"e2e4\")"
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text (e.g., a FEN board part)."},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., 'e2e4')"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        self.model = model
        # The google-genai SDK is client-based; `genai.configure` and
        # `genai.GenerativeModel` only exist in the legacy `google.generativeai` package.
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"ConvertChessMoveTool initialized with model: {self.model}")

    def forward(self, piece_placement: str, move: str) -> str:
        """Return the model's algebraic-notation rendering of ``move``.

        Returns:
            The move text with surrounding whitespace removed, or
            ``"Error converting chess move: ..."`` on failure.
        """
        move_message = (
            f"Convert this chess move from coordinate notation to algebraic "
            f"notation: {move}. Use the following board state for context: {piece_placement}. "
            "Do not provide any additional thinking or commentary in the response, "
            "return only the algebraic notation for the move."
        )
        try:
            response = self.client.models.generate_content(
                # GoogleModelID is defined elsewhere; if it is an Enum, pass its
                # string value to the SDK (a plain str passes through unchanged).
                model=getattr(self.model, "value", self.model),
                contents=move_message,
            )
            text = response.text
            if not text or not text.strip():
                raise ValueError("model returned an empty response")
            # Models often append a trailing newline; the caller wants just the move.
            return text.strip()
        except Exception as e:
            logger.error(f"Error converting chess move: {e}")
            return f"Error converting chess move: {e}"
class BestChessMoveTool(BaseCustomTool):
    """Query the HTTP chess engine at ``settings.chess_eval_url`` for the best move."""

    name = "best_chess_move_tool"
    description = "Gets the best chess move in coordinate notation (e.g., 'e2e4') based on a FEN (Forsyth-Edwards Notation) representation of the chess position. Example: best_chess_move_tool(fen=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\")"
    inputs = {
        "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) representation of the chess position. Example: 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'"},
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Return the engine's best move for ``fen`` in coordinate notation.

        Returns:
            The move (e.g. ``"e2e4"``), or ``"Error getting best chess move: ..."``
            when the request fails, the response is unsuccessful or malformed,
            or the engine reports no legal move.
        """
        try:
            # Strip agent-added whitespace/newlines so they are not URL-encoded
            # into the query string.
            clean_fen = str(fen).strip()
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(clean_fen)}&depth=15"  # Depth 15 for reasonable accuracy
            response = requests.get(url, timeout=15)
            response.raise_for_status()  # Raise HTTPError for bad responses
            response_json = response.json()
            if (
                not isinstance(response_json, dict)
                or response_json.get('success') is not True
                or 'bestmove' not in response_json
            ):
                raise ValueError(f"Stockfish API returned unsuccessful response or missing 'bestmove': {response_json}")
            # UCI engines report "bestmove <move> [ponder <move>]"; also accept a
            # bare "<move>" instead of assuming the move is always token #2.
            tokens = str(response_json['bestmove']).split()
            if tokens and tokens[0] == "bestmove":
                tokens = tokens[1:]
            # UCI engines print "(none)" when the side to move has no legal move.
            if not tokens or tokens[0] == "(none)":
                raise ValueError(f"No best move available for FEN '{clean_fen}': {response_json['bestmove']!r}")
            best_move = tokens[0]
            logger.info(f"Successfully retrieved best chess move: {best_move} for FEN: {clean_fen}")
            return best_move
        except Exception as e:
            logger.error(f"Error getting best chess move for FEN '{fen}': {e}")
            return f"Error getting best chess move: {e}"
class ChessBoardFENTool(Tool):
    """Build a full FEN string from a chess-board image via ``board_to_fen``."""

    name = "chess_board_fen_tool"
    description = "Generates the FEN (Forsyth-Edwards Notation) representation from a local image file of a chess board and the player whose turn it is. Returns the FEN string. You must first download the image file using 'get_task_file_tool'. Example: chess_board_fen_tool(image_path=\"/tmp/board.png\", player_turn=\"b\")"
    inputs = {
        "image_path": {"type": "string", "description": "The local file path of the chess board image (e.g., from get_task_file_tool)."},
        "player_turn": {"type": "string", "description": "The player with the next turn in the match, must be 'w' (white) or 'b' (black)."}
    }
    output_type = "string"

    def _expand_fen_rank(self, rank_str):
        """Expand one FEN rank (e.g. ``"3p4"``) into a list of 8 squares.

        Digits become that many ``' '`` placeholders; piece letters are kept.

        Raises:
            ValueError: if the expansion does not yield exactly 8 squares.
        """
        expanded_rank = []
        for char in rank_str:
            if char.isdigit():
                expanded_rank.extend([' '] * int(char))
            else:
                expanded_rank.append(char)
        if len(expanded_rank) != 8:
            raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
        return expanded_rank

    def _compress_fen_rank(self, rank_list):
        """Compress 8 squares (``' '`` = empty) back into FEN rank notation.

        Raises:
            ValueError: if ``rank_list`` does not have exactly 8 entries.
        """
        if len(rank_list) != 8:
            raise ValueError(f"Invalid rank list (length != 8): {rank_list}")
        compressed_rank = ""
        empty_count = 0
        for char in rank_list:
            if char == ' ':
                empty_count += 1
            else:
                # Flush the pending run of empty squares before the piece.
                if empty_count > 0:
                    compressed_rank += str(empty_count)
                    empty_count = 0
                compressed_rank += char
        if empty_count > 0:
            compressed_rank += str(empty_count)
        return compressed_rank

    def _invert_mirror_fen(self, fen_string: str) -> str:
        """Rotate the FEN board 180° (ranks flipped top-to-bottom, files left-to-right).

        The non-board fields are kept unchanged. The original author notes this
        is needed to make board_to_fen output Stockfish-compatible (presumably
        board_to_fen reads the image from the opposite side; not verifiable here).

        Returns:
            The transformed FEN, or ``"Error processing FEN: ..."`` if the input
            is not a well-formed 6-field FEN.
        """
        try:
            parts = fen_string.split()
            if len(parts) != 6:
                raise ValueError("FEN string must have 6 space-separated fields (board, turn, castling, ep, halfmove, fullmove).")
            board_part, other_parts = parts[0], parts[1:]
            rank_strings = board_part.split('/')
            if len(rank_strings) != 8:
                raise ValueError("FEN board part must have 8 ranks separated by '/'.")
            original_board = [self._expand_fen_rank(r) for r in rank_strings]
            # Vertical flip + horizontal mirror: new rank i is old rank 7-i reversed.
            transformed_board = [row[::-1] for row in reversed(original_board)]
            new_board_part = "/".join(self._compress_fen_rank(row) for row in transformed_board)
            return " ".join([new_board_part] + other_parts)
        except Exception as e:
            logger.error(f"Error processing FEN for inversion/mirroring: {e}. Input: '{fen_string}'")
            return f"Error processing FEN: {e}"

    def _add_fen_game_state(self, board_placement: str,
                            side_to_move: str,
                            castling: str = "-",
                            en_passant: str = "-",
                            halfmove_clock: int = 0,
                            fullmove_number: int = 1) -> str:
        """Append side to move, castling, en passant and move clocks to a placement.

        Returns:
            The 6-field FEN, or an ``"Error: ..."`` string (not an exception)
            when ``side_to_move`` or the clocks are invalid.
        """
        side_to_move_lower = str(side_to_move).strip().lower()
        if side_to_move_lower not in ['w', 'b']:
            return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"
        try:
            halfmove_clock = int(halfmove_clock)
            fullmove_number = int(fullmove_number)
            if halfmove_clock < 0:
                raise ValueError("halfmove_clock cannot be negative.")
            if fullmove_number < 1:
                raise ValueError("fullmove_number must be 1 or greater.")
        except (ValueError, TypeError):
            return (f"Error: halfmove_clock ('{halfmove_clock}') and "
                    f"fullmove_number ('{fullmove_number}') must be valid integers "
                    f"(non-negative and positive respectively).")
        full_fen = (f"{board_placement} {side_to_move_lower} {castling} "
                    f"{en_passant} {halfmove_clock} {fullmove_number}")
        return full_fen

    def forward(self, image_path: str, player_turn: str) -> str:
        """Recognise the board in ``image_path`` and return a full FEN for it.

        Returns:
            The FEN string, or ``"Error generating FEN from image: ..."`` on failure.
        """
        # Validate the turn before running the image model. An invalid value used
        # to come back from _add_fen_game_state as an "Error: ..." string, which
        # _invert_mirror_fen then failed to parse, hiding the real cause.
        turn = str(player_turn).strip().lower()
        if turn not in ('w', 'b'):
            message = f"Error generating FEN from image: player_turn must be 'w' or 'b', received '{player_turn}'"
            logger.error(message)
            return message
        try:
            raw_placement = get_fen_from_image_path(image_path)
            # Keep only the first whitespace-separated field. A bare placement is
            # unchanged; this tolerates trailing whitespace (or a full FEN, should
            # the library ever return one — unverified).
            fields = str(raw_placement).split()
            if not fields:
                raise ValueError("board_to_fen returned an empty board placement.")
            # Add game state to the board placement
            board_fen_with_state = self._add_fen_game_state(fields[0], turn)
            if board_fen_with_state.startswith("Error"):
                raise ValueError(board_fen_with_state)
            # Inversion makes board_to_fen output Stockfish compatible
            board_fen_inverted = self._invert_mirror_fen(board_fen_with_state)
            if board_fen_inverted.startswith("Error"):
                raise ValueError(board_fen_inverted)
            logger.info(f"Generated FEN from image '{image_path}': {board_fen_inverted}")
            return board_fen_inverted
        except Exception as e:
            logger.error(f"Error generating FEN from image '{image_path}': {e}")
            return f"Error generating FEN from image: {e}"