File size: 8,147 Bytes
476b409 |
|
import os
import json
import logging
logger = logging.getLogger(__name__)
import requests
import shutil
from typing import Any
import urllib.parse
from board_to_fen.predict import get_fen_from_image_path
from google import genai
from google.genai import types
from litellm import completion
from smolagents import Tool
from settings import Settings
class BaseCustomTool(Tool):
    """Shared base for this module's tools: a ``Tool`` that carries the settings object."""

    def __init__(self, settings: Settings) -> None:
        """Initialise the underlying ``Tool`` and keep a reference to *settings*.

        Args:
            settings: Application settings; stored as ``self.settings`` so that
                subclasses can read URLs and API keys from it.
        """
        super().__init__()
        self.settings = settings
class GetTaskFileTool(BaseCustomTool):
    """Fetch the file attached to a task and make it available on local disk.

    The file is requested from ``{evaluation_api_base_url}/files/{task_id}``
    and written into the ``downloads`` directory. If that fails for any
    reason, a copy is taken from the local ``files`` directory instead.
    """

    name = "get_task_file_tool"
    description = """Download the file_name associated with a given task_id. Get absolute file path"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID"},
        "file_name": {"type": "string", "description": "File name"},
    }
    output_type = "string"

    def __init__(self, settings):
        """Store *settings* and make sure the download directory exists."""
        super().__init__(settings)
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Download the task file and return its absolute local path.

        Args:
            task_id: Identifier used to build the download URL.
            file_name: Name under which the file is stored in the download
                directory (and looked up in ``files/`` on fallback).

        Returns:
            Absolute path of the stored file, whether it was downloaded or
            copied from the local fallback directory.

        Raises:
            OSError: If the download failed and the local fallback copy
                cannot be made either (e.g. the file does not exist).
        """
        target_path = f"{self.directory_name}/{file_name}"
        try:
            response = requests.get(
                f"{self.settings.evaluation_api_base_url}/files/{task_id}",
                timeout=15,
            )
            response.raise_for_status()
            with open(target_path, "wb") as file:
                file.write(response.content)
        except Exception as exc:
            # Best-effort by design (rate limits, network errors, ...): fall
            # back to the local copy, but leave a trace of why we had to.
            logger.warning(
                "Download of task file %r failed (%s); using local copy instead.",
                file_name,
                exc,
            )
            shutil.copy2(f"files/{file_name}", target_path)
        # Both branches return an absolute path, as the description promises.
        return os.path.abspath(target_path)

    def create_dir(self):
        """Create the download directory if it does not exist yet."""
        if os.path.exists(self.directory_name):
            logger.debug(f"Directory '{self.directory_name}' already exists.")
            return
        # exist_ok guards against another process creating it in between.
        os.makedirs(self.directory_name, exist_ok=True)
        logger.info(f"Directory '{self.directory_name}' created successfully.")
class VideoUnderstandingTool(BaseCustomTool):
    """Ask a question about a YouTube video through the google-genai client.

    The video URL is passed as file data alongside the text prompt in a
    single ``generate_content`` request to the configured model.
    """

    name = "VideoUnderstanding"
    description = "Prompt a YouTube video with questions to understand its content."
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        """Store *settings* and the name of the model used for requests."""
        super().__init__(settings)
        self.model = model

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Return the model's text answer to *prompt* about the video.

        Args:
            youtube_url: URL of the video, sent as the file URI.
            prompt: Question or instruction regarding the video.

        Returns:
            The response text. On failure an error message string is
            returned instead, so the result always matches the declared
            ``output_type`` of ``"string"``.
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            video_description = client.models.generate_content(
                model=self.model,
                contents=types.Content(
                    parts=[
                        types.Part(
                            file_data=types.FileData(file_uri=youtube_url)
                        ),
                        types.Part(text=prompt),
                    ]
                ),
            )
            answer = video_description.text
        except Exception as e:
            logger.error(f"Error understanding video: {e}")
            # Previously returned False, which violated the string contract.
            return f"Error understanding video: {e}"
        # NOTE(review): `.text` appears to be Optional in google-genai (None
        # when the response has no text parts) -- confirm against the SDK.
        if answer is None:
            return "Error understanding video: the model returned no text."
        return answer
class AudioUnderstandingTool(BaseCustomTool):
    """Ask a question about a local audio file through the google-genai client.

    The file is uploaded with ``client.files.upload`` and then passed,
    together with the prompt, to ``generate_content`` on the configured model.
    """

    name = "AudioUnderstanding"
    description = "Prompt a local audio file with questions to understand its content."
    inputs = {
        "file_path": {"type": "string", "description": "The local file of the audio"},
        "prompt": {"type": "string", "description": "A question or request regarding the audio"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        """Store *settings* and the name of the model used for requests."""
        super().__init__(settings)
        self.model = model

    def forward(self, file_path: str, prompt: str) -> str:
        """Return the model's text answer to *prompt* about the audio file.

        Args:
            file_path: Path of the local audio file to upload.
            prompt: Question or instruction regarding the audio.

        Returns:
            The response text. On failure an error message string is
            returned instead, so the result always matches the declared
            ``output_type`` of ``"string"``.
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            uploaded_audio = client.files.upload(file=f"{file_path}")
            audio_description = client.models.generate_content(
                model=self.model,
                contents=[prompt, uploaded_audio],
            )
            answer = audio_description.text
        except Exception as e:
            logger.error(f"Error understanding audio: {e}")
            # Previously returned False, which violated the string contract.
            return f"Error understanding audio: {e}"
        # NOTE(review): `.text` appears to be Optional in google-genai (None
        # when the response has no text parts) -- confirm against the SDK.
        if answer is None:
            return "Error understanding audio: the model returned no text."
        return answer
class ConvertChessMoveTool(BaseCustomTool):
    """Translate a coordinate-notation chess move into algebraic notation via an LLM.

    The conversion is delegated to the configured model through
    ``litellm.completion``; the model's reply is returned as-is.
    """

    name = "ConvertChessMove"
    description = "Convert a chess move from coordinate notation to algebraic notation."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., e2e4)"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        """Store *settings* and the name of the model used for the conversion."""
        super().__init__(settings)
        self.model = model

    def forward(self, piece_placement: str, move: str) -> str:
        """Ask the model for the algebraic notation of *move*.

        Args:
            piece_placement: Plain-text description of where the pieces are.
            move: The move in coordinate notation.

        Returns:
            The content of the first choice in the completion response.
        """
        prompt_text = f"""Convert this chess move from coordinate notation to algebraic
notation: {move}. Use the following {piece_placement}. Do not provide any additional
thinking or commentary in the response, the algebraic notation only."""
        user_turn = {"content": prompt_text, "role": "user"}
        reply = completion(
            model=self.model,
            temperature=0.0,
            messages=[user_turn],
            api_key=self.settings.openrouter_api_key.get_secret_value(),
        )
        first_choice = reply.choices[0]
        return first_choice.message.content
class BestChessMoveTool(BaseCustomTool):
    """Query the chess evaluation service for the best move in a position.

    Sends a GET request to ``settings.chess_eval_url`` with the FEN and a
    fixed search depth of 15, then extracts the move from the ``bestmove``
    field of the JSON reply.
    """

    name = "BestChessMove"
    description = "Get best chess move in coordinate notation based on a FEN representation."
    inputs = {
        # Adjacent literals instead of backslash continuations, so source
        # indentation can never leak into the description text.
        "fen": {
            "type": "string",
            "description": (
                "The FEN (Forsyth-Edwards Notation) "
                "representation of the chess position. Example "
                "rn1q1rk1/pp2b1pp/2p2n2/3p1pB1/3P4/1QP2N2/PP1N1PPP/R4RK1 b - - 1 11"
            ),
        },
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Return the best move for the position described by *fen*.

        Args:
            fen: The position in Forsyth-Edwards Notation.

        Returns:
            The second whitespace-separated token of the reply's
            ``bestmove`` field. On any failure an error message string is
            returned, so the result always matches the declared
            ``output_type`` of ``"string"``.
        """
        try:
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"
            response = requests.get(url, timeout=15)
            if response.status_code != 200:
                raise ValueError(f"Error getting chess evaluation: {response.status_code}")
            # Decode the body once and reuse it.
            payload = json.loads(response.text)
            if payload.get("success") is not True:
                # The HTTP status is 200 here, so report the payload instead.
                raise ValueError(f"unsuccessful response {payload!r}")
            # presumably formatted like "bestmove e2e4 ponder ..." -- the move
            # is taken as the second token; verify against the service docs.
            return payload["bestmove"].split()[1]
        except Exception as e:
            logger.error(f"Error getting chess evaluation: {e}")
            # Previously fell through and returned None implicitly.
            return f"Error getting chess evaluation: {e}"
class ChessGameFenTool(BaseCustomTool):
    """Produce a FEN string from a piece placement and the side to move via an LLM.

    The conversion is delegated to the configured model through
    ``litellm.completion``; the model's reply is returned as-is.
    """

    name = "ChessGameFen"
    # The tool takes the side to move, not a move: the description now
    # matches the declared inputs.
    description = "Get a FEN representation given chess piece placement and the player to move."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "player_turn": {"type": "string",
                        "description": "The player with the next turn in the match, black or white"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        """Store *settings* and the name of the model used for the conversion."""
        super().__init__(settings)
        self.model = model

    def forward(self, piece_placement: str, player_turn: str) -> str:
        """Ask the model for the FEN of the described board state.

        Args:
            piece_placement: Plain-text description of where the pieces are.
            player_turn: The side with the next turn, black or white.

        Returns:
            The content of the first choice in the completion response.
        """
        fen_message = f"""Assuming {player_turn} has the next turn, Use the following placement
{piece_placement} and provide the board state as FEN. Do not provide any
additional thinking or commentary in the response, the FEN only."""
        messages = [{"content": fen_message, "role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value(),
        )
        return response.choices[0].message.content
class ChessPiecePlacementTool(Tool):
    """Read a chess board image from disk and report what is on the board.

    Thin wrapper around ``board_to_fen.predict.get_fen_from_image_path``.
    """

    name = "ChessPiecePlacement"
    description = "Get chess piece placement information from an image of a board."
    inputs = {
        "image_path": {"type": "string", "description": "The local file of the chess board image"},
    }
    output_type = "string"

    def forward(self, image_path: str) -> str:
        """Return whatever ``get_fen_from_image_path`` yields for *image_path*.

        Args:
            image_path: Path of the local board image.
        """
        placement = get_fen_from_image_path(image_path)
        return placement
|