Spaces:
Sleeping
Sleeping
| """ | |
| Gradio interface for plotting input gradients. | |
| """ | |
| import chess | |
| import chess.pgn | |
| import io | |
| import gradio as gr | |
| import os | |
| from lczerolens import LczeroBoard, LczeroModel, Lens, InputEncoding | |
| from demo import constants | |
| from demo.utils import get_info | |
def get_model(model_name: str):
    """Load the ONNX model named ``model_name`` from the configured model directory."""
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    return LczeroModel.from_onnx_path(onnx_path)
def get_gradients(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding, target: str):
    """Compute input gradients of the chosen target with respect to the encoded board.

    Args:
        model: The loaded lczerolens model.
        board: Position to analyse.
        input_encoding: Encoding forwarded to the model via ``model_kwargs``.
        target: ``"best_move"`` to back-propagate from the maximum policy
            output, or one of ``"win"``, ``"draw"``, ``"loss"`` to use that
            column of the WDL output. Any other value raises ``KeyError``
            when the target is built.

    Returns:
        The ``"input_grad"`` entry of the gradient lens results.
    """

    def _select_target(traced):
        if target == "best_move":
            # Largest policy output along dim 1 (values only, not indices).
            policy = getattr(traced, "output/policy").output
            return policy.max(dim=1).values
        column = {"win": 0, "draw": 1, "loss": 2}[target]
        wdl = getattr(traced, "output/wdl").output
        return wdl[:, column]

    gradient_lens = Lens.from_name("gradient")
    analysis = gradient_lens.analyse(
        model,
        board,
        init_target=_select_target,
        model_kwargs={"input_encoding": input_encoding},
    )
    return analysis["input_grad"]
def get_board(game_pgn: str, board_fen: str):
    """Build the board to analyse from a PGN (preferred) or a FEN string.

    A non-blank ``game_pgn`` takes precedence over ``board_fen``. The PGN's
    own initial position (``[FEN]``/``[SetUp]`` headers) is honoured, and its
    mainline moves are replayed so the board keeps the move history. If the
    PGN or FEN cannot be used, a Gradio warning is shown and the standard
    starting position is returned instead.

    Args:
        game_pgn: PGN text of a game (may be empty).
        board_fen: FEN string, used only when ``game_pgn`` is blank.

    Returns:
        An ``LczeroBoard``.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                raise ValueError("No game found in PGN.")
            # read_game does not raise on bad movetext (illegal/ambiguous moves);
            # it records the problems in game.errors instead.
            if game.errors:
                raise ValueError(f"PGN parsing errors: {game.errors}")
            # Start from the game's own initial position, not the standard one:
            # push() does not check legality, so a wrong start would go unnoticed.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board
def render_gradients(board: LczeroBoard, gradients, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Draw a gradient heatmap on ``board`` and return the written image paths.

    Args:
        board: Board the heatmap is rendered on.
        gradients: Input gradients from ``get_gradients``. They are indexed as
            ``gradients[0, plane]`` and reshaped to 64 squares (presumably a
            ``(batch, planes, 8, 8)`` layout — TODO confirm).
        average_over_planes: If true, average the planes between
            ``begin_average_index`` and ``end_average_index`` (both inclusive,
            either order). Otherwise show the single plane ``plane_index``.
        begin_average_index: One bound of the averaged plane range.
        end_average_index: The other bound of the averaged plane range.
        plane_index: Plane shown when not averaging.

    Returns:
        ``(board_svg_path, colorbar_svg_path)``, or ``(None, None)`` when the
        session state has not been populated yet.
    """
    # Slider events can fire before the initial load has filled the session state.
    if board is None or gradients is None:
        return None, None
    if average_over_planes:
        # The sliders range over 0..111, so the end bound is treated as inclusive.
        # This makes the last plane reachable and stops begin == end from averaging
        # an empty slice (which yields NaN). Values may arrive as floats, hence int().
        lo, hi = sorted((int(begin_average_index), int(end_average_index)))
        heatmap = gradients[0, lo:hi + 1].mean(dim=0).view(64)
    else:
        heatmap = gradients[0, int(plane_index)].view(64)
    stem = f"{constants.FIGURE_DIRECTORY}/gradients"
    board.render_heatmap(
        heatmap,
        save_to=f"{stem}.svg",
    )
    # NOTE(review): render_heatmap is expected to derive these two files from
    # ``save_to`` — confirm against lczerolens.
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"
def initial_load(model_name: str, board_fen: str, game_pgn: str, input_encoding: InputEncoding, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Load the selected model, then build the board, gradients, info and plots.

    Returns:
        ``(model, board, gradients, info, board_svg_path, colorbar_svg_path)``.
    """
    loaded_model = get_model(model_name)
    # on_board_change runs get_board, get_gradients, get_info and
    # render_gradients in that order, exactly as a first load needs.
    board_outputs = on_board_change(
        loaded_model,
        game_pgn,
        board_fen,
        input_encoding,
        target,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
    return (loaded_model, *board_outputs)
def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, input_encoding: InputEncoding, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Rebuild the board from the PGN/FEN inputs and refresh gradients, info and plots.

    Returns:
        ``(board, gradients, info, board_svg_path, colorbar_svg_path)``.
    """
    new_board = get_board(game_pgn, board_fen)
    new_gradients = get_gradients(model, new_board, input_encoding, target)
    board_info = get_info(model, new_board)
    rendered = render_gradients(
        new_board,
        new_gradients,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
    return (new_board, new_gradients, board_info) + tuple(rendered)
def on_model_change(model_name: str, board: LczeroBoard, input_encoding: InputEncoding, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Load the newly selected model and recompute everything for the current board.

    Returns:
        ``(model, gradients, info, board_svg_path, colorbar_svg_path)``.
    """
    loaded = get_model(model_name)
    fresh_gradients = get_gradients(loaded, board, input_encoding, target)
    model_info = get_info(loaded, board)
    rendered = render_gradients(
        board,
        fresh_gradients,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
    return (loaded, fresh_gradients, model_info) + tuple(rendered)
def on_input_encoding_change(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Recompute gradients for the current model/board and re-render the heatmap.

    Returns:
        ``(gradients, board_svg_path, colorbar_svg_path)``.
    """
    fresh_gradients = get_gradients(model, board, input_encoding, target)
    rendered = render_gradients(
        board,
        fresh_gradients,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
    return (fresh_gradients, *rendered)
def on_target_change(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Recompute gradients for the newly selected target and re-render the heatmap.

    A target change needs exactly the same work as an encoding change
    (recompute gradients, then render), so this reuses that handler.

    Returns:
        ``(gradients, board_svg_path, colorbar_svg_path)``.
    """
    return on_input_encoding_change(
        model,
        board,
        input_encoding,
        target,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
# Gradio UI: input controls on the left, rendered heatmap and colorbar on the right.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                input_encoding = gr.Radio(
                    label="Input encoding",
                    choices=[
                        ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                        ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                        ("no history repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED),
                        ("no history zeros", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS),
                    ],
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                )
            with gr.Group():
                info = gr.Textbox(label="Info", lines=1, value="")
            with gr.Group():
                target = gr.Radio(
                    ["win", "draw", "loss", "best_move"],
                    label="Target",
                    value="win",
                )
                average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
                with gr.Accordion("Average over planes", open=False):
                    begin_average_index = gr.Slider(
                        label="Begin average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
                    end_average_index = gr.Slider(
                        label="End average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=111,
                    )
                plane_index = gr.Slider(
                    label="Plane index",
                    minimum=0,
                    maximum=111,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared by the event handlers below.
    model = gr.State(value=None)
    board = gr.State(value=None)
    gradients = gr.State(value=None)

    # Controls that only affect how the heatmap is rendered. They are passed, in
    # this order, as the trailing arguments of every handler that renders.
    render_controls = [average_over_planes, begin_average_index, end_average_index, plane_index]

    interface.load(
        initial_load,
        # The order must match initial_load(model_name, board_fen, game_pgn, ...).
        inputs=[model_name, board_fen, game_pgn, input_encoding, target, *render_controls],
        outputs=[model, board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, input_encoding, target, *render_controls],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, input_encoding, target, *render_controls],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, input_encoding, target, *render_controls],
        outputs=[model, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    input_encoding.change(
        on_input_encoding_change,
        inputs=[model, board, input_encoding, target, *render_controls],
        outputs=[gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    target.change(
        on_target_change,
        inputs=[model, board, input_encoding, target, *render_controls],
        outputs=[gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    for render_arg in render_controls:
        render_arg.change(
            render_gradients,
            # render_gradients(board, gradients, average_over_planes, begin, end, plane):
            # no input_encoding here — passing it shifted every argument by one.
            inputs=[board, gradients, *render_controls],
            outputs=[image_board, colorbar],
            # Same queue as the handlers above, since they all write the same figure files.
            concurrency_id="trace_queue",
        )