civerson916 commited on
Commit
bbccf68
·
verified ·
1 Parent(s): 1406343

Simplified tool selection for ChessAgent

Browse files
Files changed (2) hide show
  1. agent.py +2 -3
  2. tools.py +160 -38
agent.py CHANGED
@@ -9,7 +9,7 @@ from smolagents import LiteLLMModel, CodeAgent
9
  from smolagents import GoogleSearchTool, VisitWebpageTool, FinalAnswerTool
10
  from smolagents.local_python_executor import BASE_PYTHON_TOOLS
11
  from tools import GetTaskFileTool, VideoUnderstandingTool, AudioUnderstandingTool
12
- from tools import ChessPiecePlacementTool, ChessGameFenTool, BestChessMoveTool, ConvertChessMoveTool
13
 
14
 
15
  # Base tools may use these to process files
@@ -63,8 +63,7 @@ class ChessAgent:
63
  name="chess_player",
64
  description="Makes a chess move. Give it a query including board image filepath and player turn (black or white).",
65
  add_base_tools=False,
66
- tools=[ChessPiecePlacementTool(),
67
- ChessGameFenTool(settings, OpenRouterModelID.GPT_O4_MINI),
68
  BestChessMoveTool(settings),
69
  ConvertChessMoveTool(settings, OpenRouterModelID.QWEN_3_14B_FREE),
70
  ],
 
9
  from smolagents import GoogleSearchTool, VisitWebpageTool, FinalAnswerTool
10
  from smolagents.local_python_executor import BASE_PYTHON_TOOLS
11
  from tools import GetTaskFileTool, VideoUnderstandingTool, AudioUnderstandingTool
12
+ from tools import ChessBoardFENTool, BestChessMoveTool, ConvertChessMoveTool
13
 
14
 
15
  # Base tools may use these to process files
 
63
  name="chess_player",
64
  description="Makes a chess move. Give it a query including board image filepath and player turn (black or white).",
65
  add_base_tools=False,
66
+ tools=[ChessBoardFENTool(),
 
67
  BestChessMoveTool(settings),
68
  ConvertChessMoveTool(settings, OpenRouterModelID.QWEN_3_14B_FREE),
69
  ],
tools.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import json
3
  import logging
4
  logger = logging.getLogger(__name__)
 
5
  import requests
6
  import shutil
7
  from typing import Any
@@ -125,10 +126,12 @@ class ConvertChessMoveTool(BaseCustomTool):
125
  self.model = model
126
 
127
  def forward(self, piece_placement: str, move: str) -> str:
128
- move_message = f"""Convert this chess move from coordinate notation to algebraic
129
- notation: {move}. Use the following {piece_placement}. Do not provide any additional
130
- thinking or commentary in the response, the algebraic notation only."""
131
- messages = [{ "content": move_message,"role": "user"}]
 
 
132
  response = completion(
133
  model=self.model,
134
  temperature=0.0,
@@ -158,45 +161,164 @@ class BestChessMoveTool(BaseCustomTool):
158
  except Exception as e:
159
  logger.error(f"Error getting chess evaluation: {e}")
160
 
161
- class ChessGameFenTool(BaseCustomTool):
162
- name = "ChessGameFen"
163
- description = "Get a FEN representation given chess piece placement and a move."
164
  inputs = {
165
- "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
166
  "player_turn": {"type": "string",
167
- "description": "The player with the next turn in the match, black or white"},
168
  }
169
  output_type = "string"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- def __init__(self, settings, model):
172
- super().__init__(settings)
173
- self.model = model
 
 
 
 
 
174
 
175
- def forward(self, piece_placement: str, player_turn: str) -> str:
176
- """Use the tool."""
177
- fen_message = f"""First, invert the FEN string making it (rank 1 at top)
178
- and mirrored (file 'a' at right), but formatted according to the strict FEN
179
- standard (reading from the top-left of the standard board layout). Then,
180
- assuming {player_turn} has the next turn, Use the following placement
181
- {piece_placement} and provide the board state as FEN. Do not provide any
182
- additional thinking or commentary in the response, the FEN only."""
183
- messages = [{ "content": fen_message,"role": "user"}]
184
- response = completion(
185
- model=self.model,
186
- temperature=0.0,
187
- messages=messages,
188
- api_key=self.settings.openrouter_api_key.get_secret_value()
189
- )
190
- return response.choices[0].message.content
191
 
192
- class ChessPiecePlacementTool(Tool):
193
- name = "ChessPiecePlacement"
194
- description = "Get chess piece placement information from an image of a board."
195
- inputs = {
196
- "image_path": {"type": "string", "description": "The local file of the chess board image"},
197
- }
198
- output_type = "string"
199
-
200
- def forward(self, image_path: str) -> str:
201
- return get_fen_from_image_path(image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
 
2
  import json
3
  import logging
4
  logger = logging.getLogger(__name__)
5
+ import re
6
  import requests
7
  import shutil
8
  from typing import Any
 
126
  self.model = model
127
 
128
  def forward(self, piece_placement: str, move: str) -> str:
129
+ move_message = (
130
+ f"Convert this chess move from coordinate notation to algebraic "
131
+ f"notation: {move}. Use the following {piece_placement}. Do not provide any additional "
132
+ "thinking or commentary in the response, the algebraic notation only."
133
+ )
134
+ messages = [{ "content": move_message, "role": "user"}]
135
  response = completion(
136
  model=self.model,
137
  temperature=0.0,
 
161
  except Exception as e:
162
  logger.error(f"Error getting chess evaluation: {e}")
163
 
164
+ class ChessBoardFENTool(Tool):
165
+ name = "ChessBoardFEN"
166
+ description = "Get the FEN representation from an image of a chess board and a player turn."
167
  inputs = {
168
+ "image_path": {"type": "string", "description": "The local file of the chess board image"},
169
  "player_turn": {"type": "string",
170
+ "description": "The player with the next turn in the match, black or white"}
171
  }
172
  output_type = "string"
173
+
174
+ def _expand_fen_rank(self, rank_str):
175
+ """
176
+ Expands a single rank string from FEN notation (e.g., 'p2b4')
177
+ into a list of 8 characters representing the squares.
178
+ Uses ' ' for empty squares.
179
+ """
180
+ expanded_rank = []
181
+ for char in rank_str:
182
+ if char.isdigit():
183
+ # Add number of empty squares specified by the digit
184
+ expanded_rank.extend([' '] * int(char))
185
+ else:
186
+ # Add the piece character
187
+ expanded_rank.append(char)
188
+ # Validate rank length
189
+ if len(expanded_rank) != 8:
190
+ raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
191
+ return expanded_rank
192
 
193
+ def _compress_fen_rank(self, rank_list):
194
+ """
195
+ Compresses a list of 8 characters (representing a rank)
196
+ back into FEN rank notation (e.g., turns [' ', 'K', ...] into '1K6').
197
+ Assumes ' ' represents an empty square.
198
+ """
199
+ if len(rank_list) != 8:
200
+ raise ValueError(f"Invalid rank list (length != 8): {rank_list}")
201
 
202
+ compressed_rank = ""
203
+ empty_count = 0
204
+ for char in rank_list:
205
+ if char == ' ':
206
+ empty_count += 1
207
+ else:
208
+ # If we encountered a piece after empty squares, add the count
209
+ if empty_count > 0:
210
+ compressed_rank += str(empty_count)
211
+ empty_count = 0
212
+ # Add the piece
213
+ compressed_rank += char
214
+ # If the rank ends with empty squares, add the final count
215
+ if empty_count > 0:
216
+ compressed_rank += str(empty_count)
217
+ return compressed_rank
218
 
219
+ def _invert_mirror_fen(self, fen_string):
220
+ """
221
+ Takes a FEN string, inverts the board vertically, mirrors it horizontally,
222
+ and returns the new FEN string representing this transformed view.
223
+ The other FEN fields (turn, castling, etc.) are preserved.
224
+ """
225
+ try:
226
+ # 1. Split FEN into parts
227
+ parts = fen_string.strip().split(' ')
228
+ if len(parts) != 6:
229
+ raise ValueError("FEN string must have 6 space-separated fields.")
230
+ board_part = parts[0]
231
+ other_parts = parts[1:] # Side-to-move, castling, ep, halfmove, fullmove
232
+
233
+ # 2. Parse the board part into an 8x8 representation
234
+ rank_strings = board_part.split('/')
235
+ if len(rank_strings) != 8:
236
+ raise ValueError("FEN board part must have 8 ranks separated by '/'.")
237
+
238
+ # original_board[0] corresponds to rank 8, original_board[7] to rank 1
239
+ original_board = [self._expand_fen_rank(r) for r in rank_strings]
240
+
241
+ # 3. Create a new empty 8x8 board for the transformed state
242
+ # Using ' ' as the placeholder for empty squares
243
+ transformed_board = [[' ' for _ in range(8)] for _ in range(8)]
244
+
245
+ # 4. Apply the inversion (vertical flip) and mirror (horizontal flip)
246
+ for r in range(8): # Iterate through original rows (ranks 8 down to 1)
247
+ for c in range(8): # Iterate through original columns (files a to h)
248
+ # The piece at original [r][c] moves to transformed [7-r][7-c]
249
+ transformed_board[7 - r][7 - c] = original_board[r][c]
250
+
251
+ # 5. Generate the new FEN board string from the transformed board
252
+ # Read ranks from top (index 0 = rank 8) to bottom (index 7 = rank 1)
253
+ new_rank_strings = [self._compress_fen_rank(row) for row in transformed_board]
254
+ new_board_part = "/".join(new_rank_strings)
255
+
256
+ # 6. Reassemble the full FEN string
257
+ return " ".join([new_board_part] + other_parts)
258
+
259
+ except Exception as e:
260
+ # Return error message if parsing or processing fails
261
+ return f"Error processing FEN: {e}. Input: '{fen_string}'"
262
+
263
+ def _add_fen_game_state(self, board_placement,
264
+ side_to_move,
265
+ castling="-",
266
+ en_passant="-",
267
+ halfmove_clock=0,
268
+ fullmove_number=1):
269
+ """
270
+ Appends standard game state information to a FEN board placement string.
271
+
272
+ Args:
273
+ board_placement (str): The board layout part of the FEN string
274
+ (e.g., "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR").
275
+ side_to_move (str): The active color ('w' for White, 'b' for Black).
276
+ Case-insensitive, will be converted to lowercase.
277
+ castling (str, optional): Castling availability string (e.g., "KQkq", "-").
278
+ Defaults to "-".
279
+ en_passant (str, optional): En passant target square string (e.g., "e3", "-").
280
+ Defaults to "-".
281
+ halfmove_clock (int, optional): The number of halfmoves since the last
282
+ capture or pawn advance. Defaults to 0.
283
+ fullmove_number (int, optional): The number of the full move. Starts at 1
284
+ and increments after Black's move. Defaults to 1.
285
+
286
+ Returns:
287
+ str: The complete FEN string including the game state,
288
+ or an error message string if inputs are invalid.
289
+ """
290
+ # Validate side_to_move
291
+ side_to_move_lower = str(side_to_move).lower()
292
+ if side_to_move_lower not in ['w', 'b']:
293
+ return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"
294
+
295
+ # Validate clock values (should be non-negative integers, fullmove >= 1)
296
+ try:
297
+ halfmove_clock = int(halfmove_clock)
298
+ fullmove_number = int(fullmove_number)
299
+ if halfmove_clock < 0:
300
+ raise ValueError("halfmove_clock cannot be negative.")
301
+ if fullmove_number < 1:
302
+ raise ValueError("fullmove_number must be 1 or greater.")
303
+ except (ValueError, TypeError):
304
+ return (f"Error: halfmove_clock ('{halfmove_clock}') and "
305
+ f"fullmove_number ('{fullmove_number}') must be valid integers "
306
+ f"(non-negative and positive respectively).")
307
+
308
+ # Assemble the full FEN string using the validated/defaulted values
309
+ # Note: castling and en_passant strings are used directly as passed or defaulted.
310
+ # More complex validation could be added for them if needed.
311
+ full_fen = (f"{board_placement} {side_to_move_lower} {castling} "
312
+ f"{en_passant} {halfmove_clock} {fullmove_number}")
313
+
314
+ return full_fen
315
+
316
+ def forward(self, image_path: str, player_turn: str) -> str:
317
+ board_placement = get_fen_from_image_path(image_path)
318
+
319
+ # Inversion makes board_to_fen output Stockfish compatible
320
+ board_fen = self._add_fen_game_state(board_placement, player_turn)
321
+ board_fen_inverted = self._invert_mirror_fen(board_fen)
322
+
323
+ return board_fen_inverted
324