import os
import random

import chess
import chess.svg
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Access token passed to `from_pretrained`; raises KeyError if the env var is unset.
token = os.environ['auth_token']
# NOTE(review): newer transformers releases prefer `token=` over the deprecated
# `use_auth_token=`; kept as-is for compatibility with the pinned version.
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained('jrahn/chessv3', use_auth_token=token)
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)

# Modified-FEN encoding: every run of N empty squares is spelled out as N '0's.
empty_field = '0'
board_split = ' | '
nums = {str(n): empty_field * n for n in range(1, 9)}
# Built from the reversed items so iteration visits the longest runs first:
# '00000000' must collapse to '8', not to eight '1's.
nums_rev = {v: k for k, v in reversed(nums.items())}

# Output path for the rendered board.
# NOTE(review): a single shared file may race between concurrent sessions.
BOARD_SVG = 'board.svg'


def encode_fen(fen):
    """Convert a standard FEN string into the model's modified-FEN input.

    Digit run-lengths in the board field are expanded into '0' characters,
    each rank is prefixed with '+', and ranks are separated by spaces (to suit
    sub-word tokenization). The board is joined to the remaining FEN fields
    with ' | '.
    """
    fen_board, fen_rest = fen.split(' ', 1)
    for n in nums:
        fen_board = fen_board.replace(n, nums[n])
    fen_board = '+' + fen_board
    fen_board = fen_board.replace('/', ' +')
    return board_split.join([fen_board, fen_rest])


def decode_fen_repr(fen_repr):
    """Invert `encode_fen`: turn a modified-FEN string back into standard FEN."""
    fen_board, fen_rest = fen_repr.split(board_split, 1)
    for n in nums_rev:
        fen_board = fen_board.replace(n, nums_rev[n])
    fen_board = fen_board.replace(' +', '/')
    fen_board = fen_board.replace('+', '')
    return ' '.join([fen_board, fen_rest])


def predict_move(fen, top_k=3):
    """Sample an engine move for the position `fen`.

    The policy network scores every move label. Labels that are not legal
    moves in the position are discarded. One of the `top_k` highest-scoring
    legal moves is then drawn with probability proportional to its score.

    Args:
        fen: position in standard FEN.
        top_k: number of best legal moves to sample from; values below 1
            (or None) are treated as 1.

    Returns:
        The chosen move as a UCI string (a model label).

    Raises:
        ValueError: if the position has no legal moves, or none of the
            model's labels is a legal move.
    """
    top_k = max(1, int(top_k or 1))
    legal = {mv.uci() for mv in chess.Board(fen).legal_moves}
    if not legal:
        raise ValueError(f'No legal moves @ {fen}')

    # Ask for every label's score (sorted best-first) so illegal moves can be
    # filtered out before truncating to top_k.
    preds = pipe(encode_fen(fen), top_k=model.config.num_labels)
    candidates = [p for p in preds if p['label'] in legal][:top_k]
    if not candidates:
        raise ValueError(f'No legal move among model predictions @ {fen}')

    weights = [p['score'] for p in candidates]
    if sum(weights) <= 0:
        # random.choices rejects all-zero weights; fall back to uniform choice.
        weights = None
    return random.choices(candidates, weights=weights)[0]['label']


def _render_board(board, lastmove=None):
    """Write `board` (optionally highlighting `lastmove`) as SVG; return the path."""
    with open(BOARD_SVG, 'w', encoding='utf-8') as f:
        f.write(str(chess.svg.board(board, lastmove=lastmove)))
    return BOARD_SVG


def btn_load(inp_fen):
    """Reset to the standard starting position.

    `inp_fen` is accepted only because the Gradio event passes it; it is
    ignored. Returns (board image path, starting FEN, cleared move text).
    """
    board = chess.Board()
    return _render_board(board), board.fen(), ''


def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply a human move, or an engine move if no move was entered.

    Args:
        inp_fen: current position in FEN.
        inp_move: human move text; empty/whitespace means "let the engine move".
        inp_notation: 'UCI' or 'SAN' (anything other than 'UCI' is parsed as SAN).
        inp_k: top_k forwarded to `predict_move`.

    Returns:
        (board image path, new FEN, cleared move text).

    Raises:
        ValueError: on unparsable or illegal moves, or an invalid FEN.
    """
    board = chess.Board(inp_fen)
    inp_move = (inp_move or '').strip()
    if inp_move:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(inp_move)
        else:
            # The notation radio only offers 'SAN' and 'UCI'.
            mv = board.parse_san(inp_move)
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))

    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {str(mv)} @ {board.fen()}')
    board.push(mv)
    return _render_board(board, lastmove=mv), board.fen(), ''


with gr.Blocks() as block:
    gr.Markdown(
        '''
        # Play YoloChess - Policy Network v0.3
        110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation.
        '''
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            top_k = gr.Number(value=3, label='pick from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
                '''
                - Click "Load" button to start and reset board.
                - Click "Play" button to get Engine move.
                - Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
                - Output "ERROR" generally occurs on illegal human moves, or when the engine has no legal move.
                - Enter "FEN" to start from a custom position.
                '''
            )
        with gr.Column():
            position_output = gr.Image(label='board')

    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])

block.launch()