Spaces:
Sleeping
Sleeping
| """ | |
Gradio interface for plotting activations.
| """ | |
| import chess | |
| import chess.pgn | |
| import io | |
| import gradio as gr | |
| import os | |
| import torch | |
| from lczerolens import LczeroBoard, LczeroModel, Lens, InputEncoding | |
| from demo import constants | |
| from demo.utils import get_info | |
def get_model(model_name: str):
    """Load an ``LczeroModel`` from an ONNX file.

    Args:
        model_name: File name of the model inside ``constants.ONNX_MODEL_DIRECTORY``.

    Returns:
        The model built by ``LczeroModel.from_onnx_path``.
    """
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    return LczeroModel.from_onnx_path(onnx_path)
def get_activations(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding):
    """Collect the ``block{i}/conv2/relu`` output of every block for *board*.

    Args:
        model: The loaded model to analyse.
        board: The position fed to the model.
        input_encoding: Forwarded to the model as ``model_kwargs["input_encoding"]``.

    Returns:
        A list indexed by block number; element ``i`` is
        ``results["block{i}/conv2/relu_output"][0]`` (the first entry along the
        leading axis, presumably the batch dimension -- confirm).
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). "\d+" rather than "\d" so that
    # multi-digit block names such as "block10/conv2/relu" are matched too;
    # with "\d" only blocks 0-9 were ever captured.
    lens = Lens.from_name("activation", r"block\d+/conv2/relu")
    with torch.no_grad():
        results = lens.analyse(model, board, model_kwargs={"input_encoding": input_encoding})
    # NOTE(review): assumes the lens returns exactly one "..._output" entry per
    # matched block and that block indices are contiguous from 0 -- confirm.
    return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]
def get_board(game_pgn: str, board_fen: str):
    """Build the board to analyse from a PGN (preferred) or a FEN string.

    If ``game_pgn`` is non-empty it takes precedence: the game's mainline is
    replayed from the game's own initial position, so the move history is kept
    on the board. Otherwise ``board_fen`` is used. Unparseable input falls
    back to the standard starting position and shows a UI warning.

    Args:
        game_pgn: PGN text of a game (may be empty).
        board_fen: FEN string used when no PGN is given.

    Returns:
        An ``LczeroBoard``.
    """
    if game_pgn:
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                # read_game returns None when the text contains no game at all.
                raise ValueError("No game found in PGN.")
            # Start from the game's own initial position so PGNs carrying a
            # [SetUp]/[FEN] header are replayed correctly. board.push() does not
            # check legality, so replaying onto the wrong start position would
            # silently produce a corrupted board.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            return LczeroBoard()
        if game.errors:
            # python-chess records illegal/ambiguous moves in game.errors
            # instead of raising, so surface them rather than silently
            # analysing a shorter game than the user entered.
            print(game.errors)
            gr.Warning("PGN contained errors, the analysed position may be incomplete.")
        return board

    try:
        return LczeroBoard(board_fen)
    except Exception as e:
        print(e)
        gr.Warning("Invalid FEN, using starting position.")
        return LczeroBoard()
def render_activations(board: LczeroBoard, activations, layer_index:int, channel_index:int):
    """Render one channel of one layer's activations as a board heatmap.

    Indices beyond the available layers/channels are replaced by the last
    valid one, and a UI warning reports the substitution.

    Args:
        board: Board on which the heatmap is drawn.
        activations: Per-layer activations; each entry is indexed first by
            channel, and a single channel must hold 64 values (it is viewed
            as shape ``(64,)``).
        layer_index: Requested layer.
        channel_index: Requested channel within that layer.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` under ``constants.FIGURE_DIRECTORY``.
    """
    n_layers = len(activations)
    layer = layer_index
    if layer_index >= n_layers:
        layer = n_layers - 1
        gr.Warning(f"Layer index {layer_index} out of range, using last layer ({layer}).")

    layer_activations = activations[layer]
    n_channels = layer_activations.shape[0]
    channel = channel_index
    if channel_index >= n_channels:
        channel = n_channels - 1
        gr.Warning(f"Channel index {channel_index} out of range, using last channel ({channel}).")

    figure_dir = constants.FIGURE_DIRECTORY
    board.render_heatmap(
        layer_activations[channel].view(64),
        save_to=f"{figure_dir}/activations.svg",
    )
    # NOTE(review): assumes render_heatmap derives "activations_board.svg" and
    # "activations_colorbar.svg" from the save_to path -- confirm.
    board_svg = f"{figure_dir}/activations_board.svg"
    colorbar_svg = f"{figure_dir}/activations_colorbar.svg"
    return board_svg, colorbar_svg
def _analyse_and_render(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding, layer_index: int, channel_index: int):
    """Shared pipeline: compute activations, info text, then render the heatmap.

    Returns:
        ``(activations, info, board_svg_path, colorbar_svg_path)``.
    """
    activations = get_activations(model, board, input_encoding)
    info = get_info(model, board)
    plots = render_activations(board, activations, layer_index, channel_index)
    return (activations, info, *plots)


def initial_load(model_name: str, board_fen: str, game_pgn: str, input_encoding: InputEncoding, layer_index: int, channel_index: int):
    """Page-load handler: load the model and the board, then analyse and render.

    NOTE: the parameter order is (model_name, board_fen, game_pgn, ...), which
    differs from ``on_board_change``; the event wiring must pass inputs in
    exactly this order.

    Returns:
        ``(model, board, activations, info, board_svg_path, colorbar_svg_path)``.
    """
    model = get_model(model_name)
    board = get_board(game_pgn, board_fen)
    return (model, board, *_analyse_and_render(model, board, input_encoding, layer_index, channel_index))


def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, input_encoding: InputEncoding, layer_index: int, channel_index: int):
    """Rebuild the board from the PGN/FEN inputs and re-analyse it.

    Returns:
        ``(board, activations, info, board_svg_path, colorbar_svg_path)``.
    """
    board = get_board(game_pgn, board_fen)
    return (board, *_analyse_and_render(model, board, input_encoding, layer_index, channel_index))


def on_model_change(model_name: str, board: LczeroBoard, input_encoding: InputEncoding, layer_index: int, channel_index: int):
    """Load the newly selected model and re-analyse the current board.

    Returns:
        ``(model, activations, info, board_svg_path, colorbar_svg_path)``.
    """
    model = get_model(model_name)
    return (model, *_analyse_and_render(model, board, input_encoding, layer_index, channel_index))


def on_input_encoding_change(model: LczeroModel, board: LczeroBoard, input_encoding: InputEncoding, layer_index: int, channel_index: int):
    """Re-analyse the current board with a different input encoding.

    Returns:
        ``(activations, info, board_svg_path, colorbar_svg_path)``.
    """
    return _analyse_and_render(model, board, input_encoding, layer_index, channel_index)
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                input_encoding = gr.Radio(
                    label="Input encoding",
                    choices=[
                        ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                        ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                        ("no history repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED),
                        ("no history zeros", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS),
                    ],
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                )
            with gr.Group():
                info = gr.Textbox(label="Info", lines=1, value="")
            with gr.Group():
                layer_index = gr.Slider(
                    label="Layer index",
                    minimum=0,
                    maximum=19,
                    step=1,
                    value=0,
                )
                channel_index = gr.Slider(
                    label="Channel index",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared between the event handlers below.
    model = gr.State(value=None)
    board = gr.State(value=None)
    activations = gr.State(value=None)

    # initial_load's signature is (model_name, board_fen, game_pgn, ...), so the
    # inputs must be listed in that order; previously game_pgn and board_fen
    # were swapped and the FEN text was parsed as a PGN on page load.
    interface.load(
        initial_load,
        inputs=[model_name, board_fen, game_pgn, input_encoding, layer_index, channel_index],
        outputs=[model, board, activations, info, image_board, colorbar],
        concurrency_limit=1,
        concurrency_id="trace_queue",
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, input_encoding, layer_index, channel_index],
        outputs=[board, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, input_encoding, layer_index, channel_index],
        outputs=[board, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, input_encoding, layer_index, channel_index],
        outputs=[model, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    input_encoding.change(
        on_input_encoding_change,
        inputs=[model, board, input_encoding, layer_index, channel_index],
        outputs=[activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    # render_activations takes (board, activations, layer_index, channel_index):
    # passing input_encoding as well made every slider move fail with a
    # TypeError. The slider events also join the "trace_queue" group because
    # render_activations always writes the same figure files, so it must not
    # run concurrently with the other handlers that render.
    layer_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
        concurrency_id="trace_queue",
    )
    channel_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
        concurrency_id="trace_queue",
    )