hassenhamdi commited on
Commit
c38cd66
·
verified ·
1 Parent(s): 9137cdb

Create tools.py

Browse files
Files changed (1) hide show
  1. tools.py +363 -0
tools.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tools.py
2
+ import os
3
+ import json
4
+ import logging
5
+ import re
6
+ import requests
7
+ import shutil
8
+ import urllib.parse
9
+ import pandas as pd # For ExcelParsingTool
10
+ from board_to_fen.predict import get_fen_from_image_path # For ChessBoardFENTool
11
+ from google import genai
12
+ from google.genai import types
13
+ # from litellm import completion # Removed - no longer used for ConvertChessMoveTool
14
+ from smolagents import Tool
15
+ from settings import Settings
16
+ from models import GoogleModelID # Import GoogleModelID
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class BaseCustomTool(Tool):
21
+ """Base class for custom tools to easily pass settings."""
22
+ def __init__(self, settings: Settings):
23
+ super().__init__()
24
+ self.settings = settings
25
+
26
+ class GetTaskFileTool(BaseCustomTool):
27
+ name = "get_task_file_tool"
28
+ description = """If a file_name is provided in the task, use this tool to download the file associated with a given task_id. Returns the absolute file path to the downloaded file. This path can then be used by other tools like AudioUnderstandingTool or ExcelParsingTool. Example: get_task_file_tool(task_id="1234", file_name="example.mp3")"""
29
+ inputs = {
30
+ "task_id": {"type": "string", "description": "Task ID (required)"},
31
+ "file_name": {"type": "string", "description": "File name (required)"},
32
+ }
33
+ output_type = "string"
34
+
35
+ def __init__(self, settings: Settings):
36
+ super().__init__(settings)
37
+ self.directory_name = "downloads"
38
+ self.create_dir()
39
+
40
+ def forward(self, task_id: str, file_name: str) -> str:
41
+ try:
42
+ # Use the scoring API base URL for file downloads
43
+ response = requests.get(f"{self.settings.scoring_api_base_url}/files/{task_id}", timeout=15)
44
+ response.raise_for_status()
45
+
46
+ # Ensure the downloads directory exists
47
+ os.makedirs(self.directory_name, exist_ok=True)
48
+
49
+ file_path = os.path.join(self.directory_name, file_name)
50
+ with open(file_path, 'wb') as file:
51
+ file.write(response.content)
52
+
53
+ absolute_file_path = os.path.abspath(file_path)
54
+ logger.info(f"Successfully downloaded file '{file_name}' for task_id {task_id} to {absolute_file_path}")
55
+ return absolute_file_path
56
+ except requests.exceptions.RequestException as e:
57
+ logger.error(f"Error downloading file for task_id {task_id} from API: {e}")
58
+ # Fallback to local 'files' directory if API download fails
59
+ local_file_path = os.path.join("files", file_name)
60
+ if os.path.exists(local_file_path):
61
+ destination_path = os.path.join(self.directory_name, file_name)
62
+ os.makedirs(self.directory_name, exist_ok=True)
63
+ shutil.copy2(local_file_path, destination_path)
64
+ absolute_local_file_path = os.path.abspath(destination_path)
65
+ logger.info(f"Copied local fallback file '{file_name}' to {absolute_local_file_path}")
66
+ return absolute_local_file_path
67
+ else:
68
+ logger.error(f"Local fallback file '{local_file_path}' not found.")
69
+ return f"Error: Could not download or find file '{file_name}' for task_id {task_id}. {e}"
70
+ except Exception as e:
71
+ logger.error(f"An unexpected error occurred in GetTaskFileTool: {e}")
72
+ return f"Error: An unexpected error occurred while getting file '{file_name}'. {e}"
73
+
74
+ def create_dir(self):
75
+ """Creates the download directory if it doesn't exist."""
76
+ if not os.path.exists(self.directory_name):
77
+ os.makedirs(self.directory_name)
78
+ logger.info(f"Directory '{self.directory_name}' created successfully.")
79
+ else:
80
+ logger.debug(f"Directory '{self.directory_name}' already exists.")
81
+
82
+ class VideoUnderstandingTool(BaseCustomTool):
83
+ name = "video_understanding_tool"
84
+ description = "Analyzes a YouTube video given its URL and a specific prompt/question about its content. Returns a text description or answer from the video. Use this for tasks involving video content. Example: video_understanding_tool(youtube_url=\"https://www.youtube.com/watch?v=VIDEO_ID\", prompt=\"What is the main topic of this video?\")"
85
+ inputs = {
86
+ "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
87
+ "prompt": {"type": "string", "description": "A question or request regarding the video content"},
88
+ }
89
+ output_type = "string"
90
+
91
+ def __init__(self, settings: Settings, model: GoogleModelID):
92
+ super().__init__(settings)
93
+ self.model = model
94
+ # Initialize Google GenAI client with API key
95
+ genai.configure(api_key=self.settings.gemini_api_key.get_secret_value())
96
+ logger.info(f"VideoUnderstandingTool initialized with model: {self.model}")
97
+
98
+ def forward(self, youtube_url: str, prompt: str) -> str:
99
+ try:
100
+ # Use the genai.GenerativeModel for multimodal content
101
+ model_instance = genai.GenerativeModel(self.model)
102
+
103
+ # Create a FileData part from the YouTube URL
104
+ video_file_data = types.Part(
105
+ file_data=types.FileData(
106
+ file_uri=youtube_url,
107
+ mime_type="video/mp4" # Assuming common video type, adjust if needed
108
+ )
109
+ )
110
+
111
+ # Generate content with both video and text prompt
112
+ response = model_instance.generate_content(
113
+ contents=[video_file_data, types.Part(text=prompt)]
114
+ )
115
+
116
+ return response.text
117
+ except Exception as e:
118
+ logger.error(f"Error understanding video from URL '{youtube_url}': {e}")
119
+ return f"Error understanding video: {e}"
120
+
121
+ class AudioUnderstandingTool(BaseCustomTool):
122
+ name = "audio_understanding_tool"
123
+ description = "Analyzes a local audio file given its file path and a specific prompt/question about its content. Returns a text description or answer from the audio. Use this for tasks involving audio files. You must first download the audio file using 'get_task_file_tool'. Example: audio_understanding_tool(file_path=\"/tmp/audio.mp3\", prompt=\"What are the key ingredients mentioned?\")"
124
+ inputs = {
125
+ "file_path": {"type": "string", "description": "The local file path of the audio file (e.g., from get_task_file_tool)."},
126
+ "prompt": {"type": "string", "description": "A question or request regarding the audio content."},
127
+ }
128
+ output_type = "string"
129
+
130
+ def __init__(self, settings: Settings, model: GoogleModelID):
131
+ super().__init__(settings)
132
+ self.model = model
133
+ # Initialize Google GenAI client with API key
134
+ genai.configure(api_key=self.settings.gemini_api_key.get_secret_value())
135
+ logger.info(f"AudioUnderstandingTool initialized with model: {self.model}")
136
+
137
+ def forward(self, file_path: str, prompt: str) -> str:
138
+ try:
139
+ # Upload the local audio file to Gemini Files API
140
+ mp3_file = genai.upload_file(path=file_path)
141
+ logger.info(f"Uploaded audio file: {mp3_file.uri}")
142
+
143
+ # Use the genai.GenerativeModel for multimodal content
144
+ model_instance = genai.GenerativeModel(self.model)
145
+
146
+ # Generate content with both audio file and text prompt
147
+ response = model_instance.generate_content(
148
+ contents=[mp3_file, types.Part(text=prompt)]
149
+ )
150
+
151
+ # Delete the uploaded file from Gemini Files API (optional, but good practice)
152
+ # genai.delete_file(mp3_file.name) # This might require a separate API call or context manager
153
+
154
+ return response.text
155
+ except Exception as e:
156
+ logger.error(f"Error understanding audio from file '{file_path}': {e}")
157
+ return f"Error understanding audio: {e}"
158
+
159
+ class ExcelParsingTool(BaseCustomTool):
160
+ name = "excel_parsing_tool"
161
+ description = "Parses an Excel (.xlsx) file given its local file path. It reads the first sheet and returns its content as a CSV formatted string. Use this for tasks involving Excel data. You must first download the Excel file using 'get_task_file_tool'. Example: excel_parsing_tool(file_path=\"/tmp/sales_data.xlsx\")"
162
+ inputs = {"file_path": {"type": "string", "description": "The local path to the Excel file (e.g., from get_task_file_tool)."}}
163
+ output_type = "string"
164
+
165
+ def __init__(self, settings: Settings):
166
+ super().__init__(settings)
167
+ logger.info("ExcelParsingTool initialized.")
168
+
169
+ def forward(self, file_path: str) -> str:
170
+ """
171
+ Reads an Excel file and returns its content (first sheet) as a CSV string.
172
+ """
173
+ try:
174
+ # Ensure the file exists before trying to read
175
+ if not os.path.exists(file_path):
176
+ raise FileNotFoundError(f"Excel file not found at: {file_path}")
177
+
178
+ df = pd.read_excel(file_path)
179
+ csv_content = df.to_csv(index=False)
180
+ logger.info(f"Successfully parsed Excel file: {file_path}")
181
+ return csv_content
182
+ except Exception as e:
183
+ logger.error(f"Error parsing Excel file {file_path}: {e}")
184
+ return f"Error parsing Excel file: {e}"
185
+
186
+ class ConvertChessMoveTool(BaseCustomTool):
187
+ name = "convert_chess_move_tool"
188
+ description = "Converts a chess move from coordinate notation (e.g., 'e2e4') to standard algebraic notation. Requires the current piece placement as plain text. Example: convert_chess_move_tool(piece_placement=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR\", move=\"e2e4\")"
189
+ inputs = {
190
+ "piece_placement": {"type": "string", "description": "The chess piece placement in plain text (e.g., a FEN board part)."},
191
+ "move": {"type": "string", "description": "The move in coordinate notation (e.g., 'e2e4')"},
192
+ }
193
+ output_type = "string"
194
+
195
+ def __init__(self, settings: Settings, model: GoogleModelID): # Changed model type to GoogleModelID
196
+ super().__init__(settings)
197
+ self.model = model
198
+ genai.configure(api_key=self.settings.gemini_api_key.get_secret_value()) # Configure genai for this tool
199
+ logger.info(f"ConvertChessMoveTool initialized with model: {self.model}")
200
+
201
+ def forward(self, piece_placement: str, move: str) -> str:
202
+ move_message = (
203
+ f"Convert this chess move from coordinate notation to algebraic "
204
+ f"notation: {move}. Use the following board state for context: {piece_placement}. "
205
+ "Do not provide any additional thinking or commentary in the response, "
206
+ "return only the algebraic notation for the move."
207
+ )
208
+ messages = [{ "content": move_message, "role": "user"}]
209
+ try:
210
+ model_instance = genai.GenerativeModel(self.model) # Use genai.GenerativeModel
211
+ response = model_instance.generate_content(
212
+ contents=messages[0]['content'] # Pass content directly
213
+ )
214
+ return response.text
215
+ except Exception as e:
216
+ logger.error(f"Error converting chess move: {e}")
217
+ return f"Error converting chess move: {e}"
218
+
219
+ class BestChessMoveTool(BaseCustomTool):
220
+ name = "best_chess_move_tool"
221
+ description = "Gets the best chess move in coordinate notation (e.g., 'e2e4') based on a FEN (Forsyth-Edwards Notation) representation of the chess position. Example: best_chess_move_tool(fen=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\")"
222
+ inputs = {
223
+ "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) representation of the chess position. Example: 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'"},
224
+ }
225
+ output_type = "string"
226
+
227
+ def forward(self, fen: str) -> str:
228
+ try:
229
+ url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15" # Depth 15 for reasonable accuracy
230
+ response = requests.get(url, timeout=15)
231
+ response.raise_for_status() # Raise HTTPError for bad responses
232
+
233
+ response_json = response.json()
234
+ if response_json.get('success') == True and 'bestmove' in response_json:
235
+ # Stockfish API often returns "bestmove e2e4 ponder e7e5"
236
+ # We need only the move itself, which is the second part
237
+ best_move = response_json['bestmove'].split()[1]
238
+ logger.info(f"Successfully retrieved best chess move: {best_move} for FEN: {fen}")
239
+ return best_move
240
+ else:
241
+ raise ValueError(f"Stockfish API returned unsuccessful response or missing 'bestmove': {response_json}")
242
+ except Exception as e:
243
+ logger.error(f"Error getting best chess move for FEN '{fen}': {e}")
244
+ return f"Error getting best chess move: {e}"
245
+
246
+ class ChessBoardFENTool(Tool):
247
+ name = "chess_board_fen_tool"
248
+ description = "Generates the FEN (Forsyth-Edwards Notation) representation from a local image file of a chess board and the player whose turn it is. Returns the FEN string. You must first download the image file using 'get_task_file_tool'. Example: chess_board_fen_tool(image_path=\"/tmp/board.png\", player_turn=\"b\")"
249
+ inputs = {
250
+ "image_path": {"type": "string", "description": "The local file path of the chess board image (e.g., from get_task_file_tool)."},
251
+ "player_turn": {"type": "string", "description": "The player with the next turn in the match, must be 'w' (white) or 'b' (black)."}
252
+ }
253
+ output_type = "string"
254
+
255
+ def _expand_fen_rank(self, rank_str):
256
+ """Expands a single rank string from FEN notation into a list of 8 characters."""
257
+ expanded_rank = []
258
+ for char in rank_str:
259
+ if char.isdigit():
260
+ expanded_rank.extend([' '] * int(char))
261
+ else:
262
+ expanded_rank.append(char)
263
+ if len(expanded_rank) != 8:
264
+ raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
265
+ return expanded_rank
266
+
267
+ def _compress_fen_rank(self, rank_list):
268
+ """Compresses a list of 8 characters (representing a rank) back into FEN rank notation."""
269
+ if len(rank_list) != 8:
270
+ raise ValueError(f"Invalid rank list (length != 8): {rank_list}")
271
+
272
+ compressed_rank = ""
273
+ empty_count = 0
274
+ for char in rank_list:
275
+ if char == ' ':
276
+ empty_count += 1
277
+ else:
278
+ if empty_count > 0:
279
+ compressed_rank += str(empty_count)
280
+ empty_count = 0
281
+ compressed_rank += char
282
+ if empty_count > 0:
283
+ compressed_rank += str(empty_count)
284
+ return compressed_rank
285
+
286
+ def _invert_mirror_fen(self, fen_string: str) -> str:
287
+ """
288
+ Takes a FEN string, inverts the board vertically, mirrors it horizontally,
289
+ and returns the new FEN string representing this transformed view.
290
+ This is often needed to convert board_to_fen output to Stockfish compatible FEN.
291
+ """
292
+ try:
293
+ parts = fen_string.strip().split(' ')
294
+ if len(parts) != 6:
295
+ raise ValueError("FEN string must have 6 space-separated fields (board, turn, castling, ep, halfmove, fullmove).")
296
+ board_part = parts[0]
297
+ other_parts = parts[1:]
298
+
299
+ rank_strings = board_part.split('/')
300
+ if len(rank_strings) != 8:
301
+ raise ValueError("FEN board part must have 8 ranks separated by '/'.")
302
+
303
+ original_board = [self._expand_fen_rank(r) for r in rank_strings]
304
+ transformed_board = [[' ' for _ in range(8)] for _ in range(8)]
305
+
306
+ for r in range(8):
307
+ for c in range(8):
308
+ transformed_board[7 - r][7 - c] = original_board[r][c]
309
+
310
+ new_rank_strings = [self._compress_fen_rank(row) for row in transformed_board]
311
+ new_board_part = "/".join(new_rank_strings)
312
+
313
+ return " ".join([new_board_part] + other_parts)
314
+
315
+ except Exception as e:
316
+ logger.error(f"Error processing FEN for inversion/mirroring: {e}. Input: '{fen_string}'")
317
+ return f"Error processing FEN: {e}"
318
+
319
+ def _add_fen_game_state(self, board_placement: str,
320
+ side_to_move: str,
321
+ castling: str = "-",
322
+ en_passant: str = "-",
323
+ halfmove_clock: int = 0,
324
+ fullmove_number: int = 1) -> str:
325
+ """
326
+ Appends standard game state information to a FEN board placement string.
327
+ """
328
+ side_to_move_lower = str(side_to_move).lower()
329
+ if side_to_move_lower not in ['w', 'b']:
330
+ return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"
331
+
332
+ try:
333
+ halfmove_clock = int(halfmove_clock)
334
+ fullmove_number = int(fullmove_number)
335
+ if halfmove_clock < 0:
336
+ raise ValueError("halfmove_clock cannot be negative.")
337
+ if fullmove_number < 1:
338
+ raise ValueError("fullmove_number must be 1 or greater.")
339
+ except (ValueError, TypeError):
340
+ return (f"Error: halfmove_clock ('{halfmove_clock}') and "
341
+ f"fullmove_number ('{fullmove_number}') must be valid integers "
342
+ f"(non-negative and positive respectively).")
343
+
344
+ full_fen = (f"{board_placement} {side_to_move_lower} {castling} "
345
+ f"{en_passant} {halfmove_clock} {fullmove_number}")
346
+ return full_fen
347
+
348
+ def forward(self, image_path: str, player_turn: str) -> str:
349
+ try:
350
+ board_placement = get_fen_from_image_path(image_path)
351
+
352
+ # Add game state to the board placement
353
+ board_fen_with_state = self._add_fen_game_state(board_placement, player_turn)
354
+
355
+ # Inversion makes board_to_fen output Stockfish compatible
356
+ board_fen_inverted = self._invert_mirror_fen(board_fen_with_state)
357
+
358
+ logger.info(f"Generated FEN from image '{image_path}': {board_fen_inverted}")
359
+ return board_fen_inverted
360
+ except Exception as e:
361
+ logger.error(f"Error generating FEN from image '{image_path}': {e}")
362
+ return f"Error generating FEN from image: {e}"
363
+