File size: 3,958 Bytes
313445a
 
 
 
 
b02b814
313445a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02b814
1e7240f
b02b814
313445a
 
3efeb97
313445a
0671ff1
799a947
b02b814
 
 
 
 
0671ff1
b02b814
0c203c2
 
 
 
b02b814
 
 
 
 
 
 
799a947
 
 
0671ff1
799a947
 
b02b814
 
0671ff1
 
 
799a947
0671ff1
b02b814
 
 
1e7240f
 
 
 
 
 
 
 
 
b02b814
 
 
3efeb97
0671ff1
b02b814
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import gradio as gr
import random
import chess
import chess.svg
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Access token for the model repository, read from the `auth_token`
# environment variable; a missing variable raises KeyError at import time.
# (Presumably the repo is private — TODO confirm whether the token is required.)
token = os.environ['auth_token']

# Tokenizer and sequence-classification model are loaded from the same Hub
# repository and wrapped in a text-classification pipeline whose predictions
# are dicts with 'label' and 'score' keys.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token=`; check against the pinned transformers version.
_model_repo = 'jrahn/chessv3'
tokenizer = AutoTokenizer.from_pretrained(_model_repo, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(
    _model_repo,
    use_auth_token=token,
)
pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
)

# Character standing for one empty square in the expanded board encoding.
empty_field = '0'
# Separator placed between the expanded board and the remaining FEN fields.
board_split = ' | '
# FEN run-length digit -> run of empty squares: '1' -> '0', ..., '8' -> '00000000'.
nums = {str(count): count * empty_field for count in range(1, 9)}
# Inverse mapping, ordered longest run first so that sequential str.replace
# calls in the decoder turn maximal runs into digits before shorter ones.
nums_rev = {
    run: digit
    for digit, run in sorted(nums.items(), key=lambda item: len(item[1]), reverse=True)
}


def encode_fen(fen):
    """Expand a FEN string into the model's sub-word-friendly input text.

    In the piece-placement field every run-length digit is replaced by the
    matching run of ``empty_field`` characters (via ``nums``), each rank is
    prefixed with ``'+'`` and ranks are separated by a space instead of
    ``'/'``. The remaining FEN fields are appended unchanged after
    ``board_split``.

    Raises:
        ValueError: if *fen* contains no space, i.e. has no fields after the
            piece placement (the two-way unpack fails).
    """
    placement, remainder = fen.split(' ', 1)
    expanded = placement
    for digit, run in nums.items():
        expanded = expanded.replace(digit, run)
    # '+' marks the start of every rank; ranks are joined by single spaces.
    board_repr = ' '.join('+' + rank for rank in expanded.split('/'))
    return board_split.join([board_repr, remainder])

def decode_fen_repr(fen_repr):
    """Rebuild a standard FEN string from the expanded form made by ``encode_fen``.

    Runs of empty squares are collapsed back into digits using ``nums_rev``
    (iterated in its insertion order, longest run first). The ``' +'`` rank
    markers become ``'/'`` and any remaining ``'+'`` is removed. The fields
    after ``board_split`` are re-attached with a single space.

    Raises:
        ValueError: if ``board_split`` does not occur in *fen_repr*.
    """
    board_repr, remainder = fen_repr.split(board_split, 1)
    for run, digit in nums_rev.items():
        board_repr = board_repr.replace(run, digit)
    placement = board_repr.replace(' +', '/').replace('+', '')
    return f'{placement} {remainder}'

def predict_move(fen, top_k=3):
    """Sample an engine move for the side to move in *fen*.

    The encoded position is classified by ``pipe``. Predictions whose label
    is not a legal move in *fen* are discarded, and the ``top_k``
    highest-scoring remaining candidates are kept. One of them is drawn at
    random, weighted by its score. The caller parses the returned label as
    a UCI move.

    Args:
        fen: Position in Forsyth-Edwards Notation, including the fields after
            the board (required by ``encode_fen``).
        top_k: Number of best legal candidates to sample from. Values below 1
            are treated as 1; ``None`` keeps every candidate.

    Returns:
        The label of the sampled prediction. If none of the model's labels is
        legal in *fen* (e.g. the game is already over), the raw ``top_k``
        predictions are sampled instead, as before this filtering existed.
    """
    if top_k is not None:
        top_k = max(1, int(top_k))
    fen_prep = encode_fen(fen)
    legal_ucis = {move.uci() for move in chess.Board(fen).legal_moves}
    # Request every label so that dropping illegal high-scoring predictions
    # cannot shrink the legal candidate pool below top_k. The pipeline
    # returns the predictions sorted by descending score.
    preds = pipe(fen_prep, top_k=pipe.model.config.num_labels)
    candidates = [p for p in preds if p['label'] in legal_ucis][:top_k]
    if not candidates:
        candidates = preds[:top_k]
    weights = [p['score'] for p in candidates]
    if sum(weights) <= 0:
        # random.choices rejects all-zero weights (possible after softmax
        # underflow); fall back to the best-ranked candidate.
        return candidates[0]['label']
    return random.choices(candidates, weights=weights)[0]['label']

def btn_load(inp_fen):
    """Reset the game to the standard starting position and render it.

    *inp_fen* is accepted because the FEN textbox is wired as this handler's
    input, but it is not used: the board always starts from ``chess.Board()``.

    Returns:
        A 3-tuple: the path of the written SVG image, the starting FEN, and
        ``''`` (which clears the move textbox).
    """
    # NOTE(review): the fixed file name is shared by every session of the app,
    # so concurrent users may overwrite each other's image.
    svg_path = 'board.svg'
    start = chess.Board()
    with open(svg_path, 'w') as svg_file:
        svg_file.write(str(chess.svg.board(start)))
    return svg_path, start.fen(), ''

def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Play one half-move on the position in *inp_fen* and render the result.

    If *inp_move* is non-empty after stripping whitespace, it is parsed as a
    human move in *inp_notation* ('UCI' or 'SAN'). Otherwise the engine picks
    a move through ``predict_move``, sampling from its top *inp_k*
    predictions.

    Returns:
        A 3-tuple: the path of the written SVG image, the FEN after the move,
        and ``''`` (which clears the move textbox).

    Raises:
        ValueError: for an unknown notation, an illegal move, or a move that
            python-chess cannot parse (its parse errors derive from ValueError).
    """
    board = chess.Board(inp_fen)
    move_text = (inp_move or '').strip()

    if move_text:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(move_text)
        elif inp_notation == 'SAN':
            # SAN is resolved against the current position.
            mv = board.parse_san(move_text)
        else:
            # Previously this fell through with `mv` unbound (UnboundLocalError).
            raise ValueError(f'Unknown move notation: {inp_notation!r}')
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))

    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {mv} @ {board.fen()}')
    board.push(mv)

    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))

    return 'board.svg', board.fen(), ''

# Gradio UI. Components are laid out in the order they are created inside the
# nested Row/Column contexts: controls and help text in the left column, the
# rendered board image in the right column.
with gr.Blocks() as block:
    gr.Markdown(
    '''
    # Play YoloChess - Policy Network v0.3
    110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation. 
    '''
    )
    with gr.Row() as row:
        # Left column: move input, notation choice, FEN, top_k and buttons.
        with gr.Column():
            with gr.Row():
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # The FEN textbox holds the current position; every handler reads it
            # and writes the updated FEN back.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # precision=0 makes the Number component yield integers for top_k.
            top_k = gr.Number(value=3, label='pick from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
                '''
                - Click "Load" button to start and reset board.  
                - Click "Play" button to get Engine move.  
                - Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.  
                - Output "ERROR" generally occurs on illegal moves (Human or Engine).
                - Enter "FEN" to start from a custom position.
                '''
            )
        # Right column: the rendered board image.
        with gr.Column():
            position_output = gr.Image(label='board')

    # Both handlers return (image path, new FEN, '') mapped onto
    # (board image, FEN textbox, move textbox). btn_load ignores the FEN it is
    # given and always resets to the start position, so a custom FEN typed
    # into the box takes effect on the next "Play".
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
    

# Starts the Gradio server; runs at import time (no __main__ guard).
block.launch()