File size: 18,500 Bytes
c38cd66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# tools.py
import os
import json
import logging
import re
import requests
import shutil
import urllib.parse
import pandas as pd # For ExcelParsingTool
from board_to_fen.predict import get_fen_from_image_path # For ChessBoardFENTool
from google import genai
from google.genai import types
# from litellm import completion # Removed - no longer used for ConvertChessMoveTool
from smolagents import Tool
from settings import Settings
from models import GoogleModelID # Import GoogleModelID

# Module-level logger named after this module; the tool classes below log through it.
logger = logging.getLogger(__name__)

class BaseCustomTool(Tool):
    """Shared parent for tools in this module that need the application settings.

    Subclasses call ``super().__init__(settings)`` and can then read
    configuration through ``self.settings``.
    """

    def __init__(self, settings: Settings) -> None:
        """Run the ``Tool`` initializer, then keep a reference to ``settings``."""
        super().__init__()
        self.settings = settings
        
class GetTaskFileTool(BaseCustomTool):
    """Fetch the file attached to a task and store it in the download directory.

    The file is requested from ``{settings.scoring_api_base_url}/files/{task_id}``.
    When that request fails, a file with the same name in the local ``files``
    directory is copied instead. ``forward`` reports every failure as a string
    starting with ``"Error:"`` instead of raising.
    """

    name = "get_task_file_tool"
    description = """If a file_name is provided in the task, use this tool to download the file associated with a given task_id. Returns the absolute file path to the downloaded file. This path can then be used by other tools like AudioUnderstandingTool or ExcelParsingTool. Example: get_task_file_tool(task_id="1234", file_name="example.mp3")"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID (required)"},
        "file_name": {"type": "string", "description": "File name (required)"},
    }
    output_type = "string"

    def __init__(self, settings: Settings):
        """Store the settings and make sure the ``downloads`` directory exists."""
        super().__init__(settings)
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Download (or locally copy) the task file and return its absolute path.

        Args:
            task_id: Identifier appended to the scoring API ``/files/`` endpoint.
            file_name: Name under which the file is stored locally. Only its
                final path component is used.

        Returns:
            The absolute path of the stored file, or a message starting with
            ``"Error:"`` when the file could not be obtained.
        """
        # file_name originates from the task payload. Keep only the last path
        # component so a value such as "../../x" cannot escape the download
        # directory (or the local fallback directory).
        safe_name = os.path.basename(str(file_name).replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            logger.error(f"Invalid file name '{file_name}' for task_id {task_id}.")
            return f"Error: Invalid file name '{file_name}' for task_id {task_id}."

        try:
            return self._download_from_api(task_id, safe_name)
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading file for task_id {task_id} from API: {e}")
            # The name bound by "except ... as" is cleared when the handler
            # ends, so keep the exception for the final error message.
            download_error = e
        except Exception as e:
            logger.error(f"An unexpected error occurred in GetTaskFileTool: {e}")
            return f"Error: An unexpected error occurred while getting file '{file_name}'. {e}"

        # The API download failed. The fallback runs outside the handler above
        # so that a failing copy is also turned into an error string instead of
        # escaping from forward().
        try:
            fallback_path = self._copy_local_fallback(safe_name)
        except Exception as e:
            logger.error(f"An unexpected error occurred in GetTaskFileTool: {e}")
            return f"Error: An unexpected error occurred while getting file '{file_name}'. {e}"

        if fallback_path is None:
            return f"Error: Could not download or find file '{file_name}' for task_id {task_id}. {download_error}"
        return fallback_path

    def _download_from_api(self, task_id: str, safe_name: str) -> str:
        """Download the task file from the scoring API and return its absolute path."""
        # Use the scoring API base URL for file downloads
        response = requests.get(f"{self.settings.scoring_api_base_url}/files/{task_id}", timeout=15)
        response.raise_for_status()

        # Ensure the downloads directory exists
        os.makedirs(self.directory_name, exist_ok=True)

        file_path = os.path.join(self.directory_name, safe_name)
        with open(file_path, "wb") as file:
            file.write(response.content)

        absolute_file_path = os.path.abspath(file_path)
        logger.info(f"Successfully downloaded file '{safe_name}' for task_id {task_id} to {absolute_file_path}")
        return absolute_file_path

    def _copy_local_fallback(self, safe_name: str):
        """Copy ``files/<safe_name>`` into the download directory.

        Returns the absolute path of the copy, or ``None`` when the local
        fallback file does not exist.
        """
        local_file_path = os.path.join("files", safe_name)
        if not os.path.exists(local_file_path):
            logger.error(f"Local fallback file '{local_file_path}' not found.")
            return None

        os.makedirs(self.directory_name, exist_ok=True)
        destination_path = os.path.join(self.directory_name, safe_name)
        shutil.copy2(local_file_path, destination_path)
        absolute_local_file_path = os.path.abspath(destination_path)
        logger.info(f"Copied local fallback file '{safe_name}' to {absolute_local_file_path}")
        return absolute_local_file_path

    def create_dir(self):
        """Creates the download directory if it doesn't exist."""
        already_present = os.path.exists(self.directory_name)
        # exist_ok=True keeps this safe if the directory appears between the
        # existence check and the creation call.
        os.makedirs(self.directory_name, exist_ok=True)
        if already_present:
            logger.debug(f"Directory '{self.directory_name}' already exists.")
        else:
            logger.info(f"Directory '{self.directory_name}' created successfully.")

class VideoUnderstandingTool(BaseCustomTool):
    """Ask a Gemini model a question about a YouTube video.

    This module imports the google-genai SDK (``from google import genai``),
    whose entry point is ``genai.Client``. The module-level ``configure`` and
    ``GenerativeModel`` helpers belong to the older ``google.generativeai``
    package and are not available on this import, so the tool talks to the API
    through a client instance.
    """

    name = "video_understanding_tool"
    description = "Analyzes a YouTube video given its URL and a specific prompt/question about its content. Returns a text description or answer from the video. Use this for tasks involving video content. Example: video_understanding_tool(youtube_url=\"https://www.youtube.com/watch?v=VIDEO_ID\", prompt=\"What is the main topic of this video?\")"
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video content"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        """Store the model identifier and create the Gemini API client.

        Args:
            settings: Application settings; ``gemini_api_key`` is read from it.
            model: Identifier of the Gemini model used for the requests.
        """
        super().__init__(settings)
        self.model = model
        # Initialize Google GenAI client with API key
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"VideoUnderstandingTool initialized with model: {self.model}")

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Send the video URL and the prompt to the model and return its answer.

        Returns:
            The text of the model response, or a message starting with
            ``"Error understanding video:"`` when the request fails or the
            response contains no text.
        """
        try:
            # GoogleModelID is declared outside this file. If it is an Enum,
            # the API needs its string value; a plain string passes through.
            model_name = getattr(self.model, "value", self.model)

            # Create a FileData part from the YouTube URL
            video_part = types.Part(
                file_data=types.FileData(
                    file_uri=youtube_url,
                    mime_type="video/mp4",  # Assuming common video type, adjust if needed
                )
            )

            # Generate content with both video and text prompt
            response = self.client.models.generate_content(
                model=model_name,
                contents=types.Content(parts=[video_part, types.Part(text=prompt)]),
            )

            answer = response.text
            if not answer:
                # The tool declares a string output, so an empty or missing
                # text payload is reported as an error rather than returned.
                raise ValueError("The model returned no text content.")
            return answer
        except Exception as e:
            logger.error(f"Error understanding video from URL '{youtube_url}': {e}")
            return f"Error understanding video: {e}"

class AudioUnderstandingTool(BaseCustomTool):
    """Ask a Gemini model a question about a local audio file.

    The file is uploaded through the Gemini Files API, referenced in a
    ``generate_content`` request together with the prompt, and removed again
    afterwards. This module imports the google-genai SDK
    (``from google import genai``), so all calls go through a ``genai.Client``
    instance; the module-level ``configure`` / ``upload_file`` /
    ``GenerativeModel`` helpers of the older ``google.generativeai`` package
    are not available on this import.
    """

    name = "audio_understanding_tool"
    description = "Analyzes a local audio file given its file path and a specific prompt/question about its content. Returns a text description or answer from the audio. Use this for tasks involving audio files. You must first download the audio file using 'get_task_file_tool'. Example: audio_understanding_tool(file_path=\"/tmp/audio.mp3\", prompt=\"What are the key ingredients mentioned?\")"
    inputs = {
        "file_path": {"type": "string", "description": "The local file path of the audio file (e.g., from get_task_file_tool)."},
        "prompt": {"type": "string", "description": "A question or request regarding the audio content."},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        """Store the model identifier and create the Gemini API client.

        Args:
            settings: Application settings; ``gemini_api_key`` is read from it.
            model: Identifier of the Gemini model used for the requests.
        """
        super().__init__(settings)
        self.model = model
        # Initialize Google GenAI client with API key
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"AudioUnderstandingTool initialized with model: {self.model}")

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload the audio file, query the model and return its answer.

        Returns:
            The text of the model response, or a message starting with
            ``"Error understanding audio:"`` when the upload or the request
            fails or the response contains no text.
        """
        uploaded_file = None
        try:
            # Upload the local audio file to Gemini Files API
            uploaded_file = self.client.files.upload(file=file_path)
            logger.info(f"Uploaded audio file: {uploaded_file.uri}")

            # GoogleModelID is declared outside this file. If it is an Enum,
            # the API needs its string value; a plain string passes through.
            model_name = getattr(self.model, "value", self.model)

            # Generate content with both audio file and text prompt
            response = self.client.models.generate_content(
                model=model_name,
                contents=[uploaded_file, prompt],
            )

            answer = response.text
            if not answer:
                # The tool declares a string output, so an empty or missing
                # text payload is reported as an error rather than returned.
                raise ValueError("The model returned no text content.")
            return answer
        except Exception as e:
            logger.error(f"Error understanding audio from file '{file_path}': {e}")
            return f"Error understanding audio: {e}"
        finally:
            # Best-effort cleanup of the uploaded copy; a failure here must not
            # replace the result (or error message) produced above.
            if uploaded_file is not None:
                try:
                    self.client.files.delete(name=uploaded_file.name)
                except Exception as cleanup_error:
                    logger.warning(f"Could not delete uploaded audio file '{uploaded_file.name}': {cleanup_error}")

class ExcelParsingTool(BaseCustomTool):
    """Turn the first sheet of a local Excel workbook into CSV text."""

    name = "excel_parsing_tool"
    description = "Parses an Excel (.xlsx) file given its local file path. It reads the first sheet and returns its content as a CSV formatted string. Use this for tasks involving Excel data. You must first download the Excel file using 'get_task_file_tool'. Example: excel_parsing_tool(file_path=\"/tmp/sales_data.xlsx\")"
    inputs = {
        "file_path": {
            "type": "string",
            "description": "The local path to the Excel file (e.g., from get_task_file_tool).",
        }
    }
    output_type = "string"

    def __init__(self, settings: Settings):
        """Store the settings and log that the tool is ready."""
        super().__init__(settings)
        logger.info("ExcelParsingTool initialized.")

    def forward(self, file_path: str) -> str:
        """Load the workbook at ``file_path`` and return its first sheet as CSV.

        A missing file or any parsing problem is logged and reported as a
        string starting with ``"Error parsing Excel file:"``.
        """
        try:
            # A missing file is raised here so it goes through the same
            # logging/reporting path as every other failure.
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Excel file not found at: {file_path}")
            sheet_as_csv = pd.read_excel(file_path).to_csv(index=False)
        except Exception as err:
            logger.error(f"Error parsing Excel file {file_path}: {err}")
            return f"Error parsing Excel file: {err}"

        logger.info(f"Successfully parsed Excel file: {file_path}")
        return sheet_as_csv

class ConvertChessMoveTool(BaseCustomTool):
    """Ask a Gemini model to rewrite a coordinate-notation move in algebraic notation.

    This module imports the google-genai SDK (``from google import genai``),
    so requests go through a ``genai.Client`` instance; the module-level
    ``configure`` / ``GenerativeModel`` helpers of the older
    ``google.generativeai`` package are not available on this import.
    """

    name = "convert_chess_move_tool"
    description = "Converts a chess move from coordinate notation (e.g., 'e2e4') to standard algebraic notation. Requires the current piece placement as plain text. Example: convert_chess_move_tool(piece_placement=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR\", move=\"e2e4\")"
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text (e.g., a FEN board part)."},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., 'e2e4')"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        """Store the model identifier and create the Gemini API client.

        Args:
            settings: Application settings; ``gemini_api_key`` is read from it.
            model: Identifier of the Gemini model used for the conversion.
        """
        super().__init__(settings)
        self.model = model
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"ConvertChessMoveTool initialized with model: {self.model}")

    def forward(self, piece_placement: str, move: str) -> str:
        """Return ``move`` in algebraic notation as produced by the model.

        Args:
            piece_placement: Board description handed to the model as context.
            move: The move in coordinate notation (e.g. ``"e2e4"``).

        Returns:
            The model's answer with surrounding whitespace removed, or a
            message starting with ``"Error converting chess move:"`` when the
            request fails or the model returns no text.
        """
        move_message = (
            f"Convert this chess move from coordinate notation to algebraic "
            f"notation: {move}. Use the following board state for context: {piece_placement}. "
            "Do not provide any additional thinking or commentary in the response, "
            "return only the algebraic notation for the move."
        )
        try:
            # GoogleModelID is declared outside this file. If it is an Enum,
            # the API needs its string value; a plain string passes through.
            model_name = getattr(self.model, "value", self.model)
            response = self.client.models.generate_content(
                model=model_name,
                contents=move_message,
            )
            # The prompt asks for the bare move, so drop incidental whitespace
            # such as a trailing newline.
            notation = (response.text or "").strip()
            if not notation:
                raise ValueError("The model returned no text content.")
            return notation
        except Exception as e:
            logger.error(f"Error converting chess move: {e}")
            return f"Error converting chess move: {e}"

class BestChessMoveTool(BaseCustomTool):
    """Query the chess evaluation endpoint from the settings for the best move.

    The endpoint ``settings.chess_eval_url`` is called with the FEN and a
    search depth of 15; its JSON reply is expected to carry a ``success`` flag
    and a ``bestmove`` entry.
    """

    name = "best_chess_move_tool"
    description = "Gets the best chess move in coordinate notation (e.g., 'e2e4') based on a FEN (Forsyth-Edwards Notation) representation of the chess position. Example: best_chess_move_tool(fen=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\")"
    inputs = {
        "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) representation of the chess position. Example: 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'"},
    }
    output_type = "string"

    @staticmethod
    def _extract_move(raw_bestmove) -> str:
        """Pull the move out of the ``bestmove`` value of the API reply.

        Accepts both the engine-style form ``"bestmove e2e4 ponder e7e5"`` and
        a bare move such as ``"e2e4"`` (optionally followed by ponder data).

        Raises:
            ValueError: If the value contains no move token.
        """
        tokens = str(raw_bestmove).split()
        if tokens and tokens[0] == "bestmove":
            tokens = tokens[1:]
        if not tokens:
            raise ValueError(f"Could not find a move in 'bestmove' value: {raw_bestmove!r}")
        return tokens[0]

    def forward(self, fen: str) -> str:
        """Return the best move for ``fen`` in coordinate notation.

        Returns:
            The move (e.g. ``"e2e4"``), or a message starting with
            ``"Error getting best chess move:"`` when the request fails or the
            reply is unsuccessful or malformed.
        """
        try:
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"  # Depth 15 for reasonable accuracy
            response = requests.get(url, timeout=15)
            response.raise_for_status()  # Raise HTTPError for bad responses

            response_json = response.json()
            if response_json.get("success") is not True or "bestmove" not in response_json:
                raise ValueError(f"Stockfish API returned unsuccessful response or missing 'bestmove': {response_json}")

            # Stockfish API often returns "bestmove e2e4 ponder e7e5";
            # only the move itself is wanted.
            best_move = self._extract_move(response_json["bestmove"])
            logger.info(f"Successfully retrieved best chess move: {best_move} for FEN: {fen}")
            return best_move
        except Exception as e:
            logger.error(f"Error getting best chess move for FEN '{fen}': {e}")
            return f"Error getting best chess move: {e}"

class ChessBoardFENTool(Tool):
    """Build a full FEN string from a chess board image and the side to move.

    The board placement is obtained from ``get_fen_from_image_path``,
    completed with default game-state fields, and then rotated by 180 degrees.
    """

    name = "chess_board_fen_tool"
    description = "Generates the FEN (Forsyth-Edwards Notation) representation from a local image file of a chess board and the player whose turn it is. Returns the FEN string. You must first download the image file using 'get_task_file_tool'. Example: chess_board_fen_tool(image_path=\"/tmp/board.png\", player_turn=\"b\")"
    inputs = {
        "image_path": {"type": "string", "description": "The local file path of the chess board image (e.g., from get_task_file_tool)."},
        "player_turn": {"type": "string", "description": "The player with the next turn in the match, must be 'w' (white) or 'b' (black)."}
    }
    output_type = "string"

    def _expand_fen_rank(self, rank_str):
        """Expands a single rank string from FEN notation into a list of 8 characters.

        Digits become that many ``' '`` entries; every other character is kept.

        Raises:
            ValueError: If the expanded rank does not contain exactly 8 squares.
        """
        expanded_rank = []
        for char in rank_str:
            if char.isdigit():
                expanded_rank.extend([' '] * int(char))
            else:
                expanded_rank.append(char)
        if len(expanded_rank) != 8:
            raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
        return expanded_rank

    def _compress_fen_rank(self, rank_list):
        """Compresses a list of 8 characters (representing a rank) back into FEN rank notation.

        Runs of ``' '`` entries are replaced by their length.

        Raises:
            ValueError: If ``rank_list`` does not contain exactly 8 entries.
        """
        if len(rank_list) != 8:
            raise ValueError(f"Invalid rank list (length != 8): {rank_list}")

        pieces = []
        empty_count = 0
        for char in rank_list:
            if char == ' ':
                empty_count += 1
                continue
            if empty_count:
                # Flush the run of empty squares that precedes this piece.
                pieces.append(str(empty_count))
                empty_count = 0
            pieces.append(char)
        if empty_count:
            pieces.append(str(empty_count))
        return "".join(pieces)

    def _rotate_fen(self, fen_string: str) -> str:
        """Return ``fen_string`` with its board rotated by 180 degrees.

        Only the board field changes; the remaining five fields are copied.

        Raises:
            ValueError: If the FEN does not have 6 fields, the board does not
                have 8 ranks, or a rank does not describe 8 squares.
        """
        fields = fen_string.strip().split(' ')
        if len(fields) != 6:
            raise ValueError("FEN string must have 6 space-separated fields (board, turn, castling, ep, halfmove, fullmove).")
        board_part, *other_parts = fields

        rank_strings = board_part.split('/')
        if len(rank_strings) != 8:
            raise ValueError("FEN board part must have 8 ranks separated by '/'.")

        original_board = [self._expand_fen_rank(rank) for rank in rank_strings]
        # Square (r, c) moves to (7 - r, 7 - c): reverse the order of the
        # ranks and reverse the squares inside each rank.
        rotated_board = [rank[::-1] for rank in reversed(original_board)]
        new_board_part = "/".join(self._compress_fen_rank(rank) for rank in rotated_board)
        return " ".join([new_board_part, *other_parts])

    def _invert_mirror_fen(self, fen_string: str) -> str:
        """
        Takes a FEN string, inverts the board vertically, mirrors it horizontally,
        and returns the new FEN string representing this transformed view.
        This is often needed to convert board_to_fen output to Stockfish compatible FEN.

        Failures are logged and returned as a string starting with
        ``"Error processing FEN:"`` instead of being raised.
        """
        try:
            return self._rotate_fen(fen_string)
        except Exception as e:
            logger.error(f"Error processing FEN for inversion/mirroring: {e}. Input: '{fen_string}'")
            return f"Error processing FEN: {e}"

    def _add_fen_game_state(self, board_placement: str,
                        side_to_move: str,
                        castling: str = "-",
                        en_passant: str = "-",
                        halfmove_clock: int = 0,
                        fullmove_number: int = 1) -> str:
        """
        Appends standard game state information to a FEN board placement string.

        Returns the six-field FEN, or a string starting with ``"Error:"`` when
        ``side_to_move`` is not ``'w'``/``'b'`` or the clocks are not valid
        integers (non-negative half-move clock, full-move number of at least 1).
        """
        side_to_move_lower = str(side_to_move).lower()
        if side_to_move_lower not in ['w', 'b']:
            return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"

        try:
            halfmove_clock = int(halfmove_clock)
            fullmove_number = int(fullmove_number)
            if halfmove_clock < 0:
                raise ValueError("halfmove_clock cannot be negative.")
            if fullmove_number < 1:
                raise ValueError("fullmove_number must be 1 or greater.")
        except (ValueError, TypeError):
            return (f"Error: halfmove_clock ('{halfmove_clock}') and "
                    f"fullmove_number ('{fullmove_number}') must be valid integers "
                    f"(non-negative and positive respectively).")

        return (f"{board_placement} {side_to_move_lower} {castling} "
                f"{en_passant} {halfmove_clock} {fullmove_number}")

    def forward(self, image_path: str, player_turn: str) -> str:
        """Generate the FEN for the board shown in ``image_path``.

        Args:
            image_path: Local path of the chess board image.
            player_turn: ``'w'`` or ``'b'``; surrounding whitespace and letter
                case are ignored.

        Returns:
            The six-field FEN string, or a message starting with ``"Error"``
            when the turn is invalid or the image/FEN processing fails.
        """
        # Validate the turn before running the image model, and report the
        # real problem: passing the helper's error text on to the rotation step
        # would replace it with an unrelated "6 fields" message.
        side_to_move = str(player_turn).strip().lower()
        if side_to_move not in ("w", "b"):
            message = f"Error: player_turn must be 'w' or 'b', received '{player_turn}'"
            logger.error(message)
            return message

        try:
            board_placement = get_fen_from_image_path(image_path)

            # Add game state to the board placement
            board_fen_with_state = self._add_fen_game_state(board_placement, side_to_move)

            # Inversion makes board_to_fen output Stockfish compatible.
            # The raising helper is used so a malformed board ends up in the
            # except branch instead of being logged as a generated FEN.
            board_fen_inverted = self._rotate_fen(board_fen_with_state)
        except Exception as e:
            logger.error(f"Error generating FEN from image '{image_path}': {e}")
            return f"Error generating FEN from image: {e}"

        logger.info(f"Generated FEN from image '{image_path}': {board_fen_inverted}")
        return board_fen_inverted