Upload ChessBot Chess model
Browse files
- __pycache__/modeling_chessbot.cpython-311.pyc +0 -0
- config.json +0 -4
- model.safetensors +2 -2
- modeling_chessbot.py +392 -141
- usage_example.py +17 -9
__pycache__/modeling_chessbot.cpython-311.pyc
CHANGED
Binary files a/__pycache__/modeling_chessbot.cpython-311.pyc and b/__pycache__/modeling_chessbot.cpython-311.pyc differ
|
|
config.json
CHANGED
@@ -1,14 +1,10 @@
|
|
1 |
{
|
2 |
-
"architectures": [
|
3 |
-
"ChessBotModel"
|
4 |
-
],
|
5 |
"d_ff": 736,
|
6 |
"d_model": 512,
|
7 |
"max_position_embeddings": 64,
|
8 |
"model_type": "chessbot",
|
9 |
"num_heads": 8,
|
10 |
"num_layers": 10,
|
11 |
-
"torch_dtype": "float32",
|
12 |
"transformers_version": "4.53.1",
|
13 |
"vocab_size": 1929
|
14 |
}
|
|
|
1 |
{
|
|
|
|
|
|
|
2 |
"d_ff": 736,
|
3 |
"d_model": 512,
|
4 |
"max_position_embeddings": 64,
|
5 |
"model_type": "chessbot",
|
6 |
"num_heads": 8,
|
7 |
"num_layers": 10,
|
|
|
8 |
"transformers_version": "4.53.1",
|
9 |
"vocab_size": 1929
|
10 |
}
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:274e6c174ae963a3ad25960fb50de368c9a8fe937719d6d78d7ab55c262ae2c1
|
3 |
+
size 126985096
|
modeling_chessbot.py
CHANGED
@@ -1,15 +1,23 @@
|
|
1 |
"""
|
2 |
-
Standalone ChessBot Model
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
"""
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
|
|
|
|
9 |
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
|
10 |
from transformers.modeling_outputs import BaseModelOutput
|
11 |
-
import chess
|
12 |
-
import numpy as np
|
13 |
from typing import Optional, Tuple
|
14 |
import math
|
15 |
|
@@ -32,124 +40,66 @@ class ChessBotConfig(PretrainedConfig):
|
|
32 |
max_position_embeddings: int = 64,
|
33 |
**kwargs,
|
34 |
):
|
|
|
35 |
self.num_layers = num_layers
|
36 |
self.d_model = d_model
|
37 |
self.d_ff = d_ff
|
38 |
self.num_heads = num_heads
|
39 |
self.vocab_size = vocab_size
|
40 |
self.max_position_embeddings = max_position_embeddings
|
41 |
-
|
42 |
-
super().__init__(**kwargs)
|
43 |
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
"""
|
48 |
-
|
49 |
"""
|
50 |
-
def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
|
51 |
-
super(RelativeMultiHeadAttention2, self).__init__()
|
52 |
-
assert d_model % num_heads == 0
|
53 |
-
|
54 |
-
self.d_model = d_model
|
55 |
-
self.num_heads = num_heads
|
56 |
-
self.d_head = int(d_model / num_heads)
|
57 |
-
|
58 |
-
self.query_proj = nn.Linear(d_model, d_model)
|
59 |
-
self.key_proj = nn.Linear(d_model, d_model)
|
60 |
-
self.value_proj = nn.Linear(d_model, d_model)
|
61 |
-
self.pos_proj = nn.Linear(d_model, d_model, bias=False)
|
62 |
-
|
63 |
-
self.dropout = nn.Dropout(p=dropout_p)
|
64 |
-
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
65 |
-
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
66 |
-
|
67 |
-
torch.nn.init.xavier_uniform_(self.u_bias)
|
68 |
-
torch.nn.init.xavier_uniform_(self.v_bias)
|
69 |
-
|
70 |
-
self.out_proj = nn.Linear(d_model, d_model)
|
71 |
-
|
72 |
-
def forward(self, query, key, value, pos_embedding, mask=None):
|
73 |
-
batch_size = value.size(0)
|
74 |
-
|
75 |
-
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
76 |
-
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
77 |
-
value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
78 |
-
|
79 |
-
pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
|
80 |
-
|
81 |
-
query = query.permute(0, 2, 1, 3)
|
82 |
-
|
83 |
-
query_with_u_bias = query + self.u_bias.unsqueeze(1)
|
84 |
-
query_with_v_bias = query + self.v_bias.unsqueeze(1)
|
85 |
-
|
86 |
-
content_score = torch.matmul(query_with_u_bias, key.transpose(-1, -2))
|
87 |
-
pos_score = torch.matmul(query_with_v_bias, pos_embedding.permute(0, 2, 3, 1))
|
88 |
-
pos_score = self._compute_relative_positional_encoding(pos_score)
|
89 |
-
|
90 |
-
score = (content_score + pos_score) / math.sqrt(self.d_head)
|
91 |
-
|
92 |
-
if mask is not None:
|
93 |
-
score.masked_fill_(mask, -float('inf'))
|
94 |
-
|
95 |
-
attn = F.softmax(score, -1)
|
96 |
-
attn = self.dropout(attn)
|
97 |
-
|
98 |
-
context = torch.matmul(attn, value).transpose(1, 2)
|
99 |
-
context = context.contiguous().view(batch_size, -1, self.d_model)
|
100 |
-
|
101 |
-
return self.out_proj(context)
|
102 |
-
|
103 |
-
def _compute_relative_positional_encoding(self, pos_score):
|
104 |
-
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
|
105 |
-
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
|
106 |
-
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
|
107 |
-
|
108 |
-
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
109 |
-
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
110 |
-
|
111 |
-
return pos_score
|
112 |
-
|
113 |
-
|
114 |
-
# Utility functions
|
115 |
-
def fen_to_tensor(fen: str):
|
116 |
-
"""Convert FEN string to tensor representation"""
|
117 |
board = chess.Board(fen)
|
118 |
-
|
119 |
-
tensor = np.zeros((8, 8, P), dtype=np.float32)
|
120 |
|
|
|
121 |
piece_map = {
|
122 |
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
|
123 |
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
|
124 |
}
|
125 |
|
126 |
-
#
|
127 |
-
for square
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
131 |
|
132 |
-
#
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
139 |
|
140 |
-
# Castling rights
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
145 |
|
146 |
-
#
|
147 |
-
|
|
|
|
|
|
|
148 |
|
149 |
return tensor
|
150 |
|
151 |
|
152 |
-
#
|
153 |
policy_index = [
|
154 |
"a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
|
155 |
"a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
|
@@ -370,6 +320,68 @@ policy_index = [
|
|
370 |
"<thinking>","</thinking>","end_variation","end","padding_token"
|
371 |
]
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
# Model components
|
374 |
class MaGating(nn.Module):
|
375 |
def __init__(self, d_model):
|
@@ -425,48 +437,40 @@ class AbsolutePositionalEncoder(nn.Module):
|
|
425 |
class ValueHead(nn.Module):
|
426 |
def __init__(self, d_model):
|
427 |
super().__init__()
|
428 |
-
self.
|
429 |
-
self.
|
430 |
-
self.
|
431 |
-
|
432 |
-
self.layernorm1 = nn.LayerNorm(d_model)
|
433 |
-
self.layernorm2 = nn.LayerNorm(d_model)
|
434 |
-
|
435 |
def forward(self, x):
|
436 |
-
|
437 |
-
x = self.
|
438 |
-
x =
|
439 |
-
x =
|
440 |
-
x = self.
|
441 |
-
x =
|
442 |
-
x = self.
|
443 |
-
x = self.linear3(x)
|
444 |
return x
|
445 |
-
|
446 |
|
447 |
class ValueHeadQ(nn.Module):
|
448 |
def __init__(self, d_model):
|
449 |
super().__init__()
|
450 |
-
self.
|
451 |
-
self.
|
452 |
-
self.
|
453 |
-
|
454 |
-
self.layernorm1 = nn.LayerNorm(d_model)
|
455 |
-
self.layernorm2 = nn.LayerNorm(d_model)
|
456 |
-
|
457 |
def forward(self, x):
|
458 |
-
|
459 |
-
x = self.
|
460 |
-
x =
|
461 |
-
x =
|
462 |
-
x = self.
|
463 |
-
x =
|
464 |
-
x = self.
|
465 |
-
x = self.linear3(x)
|
466 |
return x
|
467 |
|
468 |
|
469 |
-
# Main model class
|
470 |
class ChessBotPreTrainedModel(PreTrainedModel):
|
471 |
"""
|
472 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
@@ -491,19 +495,19 @@ class ChessBotPreTrainedModel(PreTrainedModel):
|
|
491 |
|
492 |
class ChessBotModel(ChessBotPreTrainedModel):
|
493 |
"""
|
494 |
-
HuggingFace compatible ChessBot Chess model
|
495 |
"""
|
496 |
|
497 |
def __init__(self, config):
|
498 |
super().__init__(config)
|
499 |
self.config = config
|
500 |
|
501 |
-
# Initialize
|
502 |
self.is_thinking_model = False
|
503 |
self.d_model = config.d_model
|
504 |
self.num_layers = config.num_layers
|
505 |
|
506 |
-
# Model layers
|
507 |
self.layers = nn.ModuleList([
|
508 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
509 |
for _ in range(config.num_layers)
|
@@ -523,11 +527,90 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
523 |
# Initialize weights
|
524 |
self.post_init()
|
525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
527 |
"""
|
528 |
-
Forward pass compatible with
|
529 |
"""
|
530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
b, seq_len, _, _, emb = x.size()
|
532 |
x = x.view(b * seq_len, 64, emb)
|
533 |
|
@@ -537,9 +620,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
537 |
x = self.ma_gating(x)
|
538 |
|
539 |
pos_enc = self.positional(x)
|
540 |
-
|
541 |
-
|
542 |
-
x = layer(x, pos_enc)
|
543 |
|
544 |
value_h = self.value_head(x)
|
545 |
value_h = value_h.view(b, seq_len, 3)
|
@@ -561,12 +643,23 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
561 |
|
562 |
policy = self.policy_head(policy_attn_logits)
|
563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
return BaseModelOutput(
|
565 |
last_hidden_state=x,
|
566 |
hidden_states=None,
|
567 |
attentions=None,
|
568 |
), policy, value_h, value_h_q
|
569 |
-
|
570 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
571 |
"""
|
572 |
Get a move from FEN string without thinking
|
@@ -627,11 +720,169 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
627 |
|
628 |
return selected_move
|
629 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
630 |
|
631 |
# Register the configuration and model with transformers
|
632 |
AutoConfig.register("chessbot", ChessBotConfig)
|
633 |
AutoModel.register(ChessBotConfig, ChessBotModel)
|
634 |
|
635 |
-
# For backward compatibility
|
636 |
ChessBot = ChessBotModel
|
637 |
-
BT4Model = ChessBotModel
|
|
|
1 |
"""
|
2 |
+
Standalone ChessBot Chess Model
|
3 |
+
|
4 |
+
This file contains all the necessary code to run the ChessBot model
|
5 |
+
without requiring the HFChessRL package installation.
|
6 |
+
|
7 |
+
Requirements:
|
8 |
+
- torch>=2.0.0
|
9 |
+
- transformers>=4.30.0
|
10 |
+
- python-chess>=1.10.0
|
11 |
+
- numpy>=1.21.0
|
12 |
"""
|
13 |
|
14 |
import torch
|
15 |
import torch.nn as nn
|
16 |
import torch.nn.functional as F
|
17 |
+
import numpy as np
|
18 |
+
import chess
|
19 |
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
|
20 |
from transformers.modeling_outputs import BaseModelOutput
|
|
|
|
|
21 |
from typing import Optional, Tuple
|
22 |
import math
|
23 |
|
|
|
40 |
max_position_embeddings: int = 64,
|
41 |
**kwargs,
|
42 |
):
|
43 |
+
super().__init__(**kwargs)
|
44 |
self.num_layers = num_layers
|
45 |
self.d_model = d_model
|
46 |
self.d_ff = d_ff
|
47 |
self.num_heads = num_heads
|
48 |
self.vocab_size = vocab_size
|
49 |
self.max_position_embeddings = max_position_embeddings
|
|
|
|
|
50 |
|
51 |
|
52 |
+
# FEN encoding function
|
53 |
+
def fen_to_tensor(fen: str):
    """
    Encode a FEN position as an (8, 8, 19) float32 numpy array.

    Planes 0-11 are one-hot piece planes in the order P, N, B, R, Q, K,
    p, n, b, r, q, k. Plane 12 is filled when white is to move and plane 13
    when black is to move. Planes 14-17 are filled for white kingside, white
    queenside, black kingside and black queenside castling rights. Plane 18
    marks the en passant square, if any. Row 0 of the array is rank 8.
    """
    position = chess.Board(fen)
    planes = np.zeros((8, 8, 19), dtype=np.float32)

    # Plane index for every piece symbol: white pieces first, then black.
    symbol_to_plane = {symbol: idx for idx, symbol in enumerate("PNBRQKpnbrqk")}

    for sq in chess.SQUARES:
        occupant = position.piece_at(sq)
        if occupant is None:
            continue
        rank, file = divmod(sq, 8)
        # Rank 1 is square index 0..7, so flip it to put rank 8 in row 0.
        planes[7 - rank, file, symbol_to_plane[occupant.symbol()]] = 1.0

    # Side to move: exactly one of planes 12 / 13 is filled.
    planes[:, :, 12 if position.turn == chess.WHITE else 13] = 1.0

    # Castling rights occupy four consecutive planes starting at 14.
    castling_flags = (
        position.has_kingside_castling_rights(chess.WHITE),
        position.has_queenside_castling_rights(chess.WHITE),
        position.has_kingside_castling_rights(chess.BLACK),
        position.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, allowed in enumerate(castling_flags):
        if allowed:
            planes[:, :, 14 + offset] = 1.0

    # En passant target square, when the position has one.
    ep_square = position.ep_square
    if ep_square is not None:
        ep_rank, ep_file = divmod(ep_square, 8)
        planes[7 - ep_rank, ep_file, 18] = 1.0

    return planes
|
100 |
|
101 |
|
102 |
+
# Complete policy index with all 1929 moves
|
103 |
policy_index = [
|
104 |
"a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
|
105 |
"a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
|
|
|
320 |
"<thinking>","</thinking>","end_variation","end","padding_token"
|
321 |
]
|
322 |
|
323 |
+
|
324 |
+
|
325 |
+
# Attention mechanism
|
326 |
+
class RelativeMultiHeadAttention2(nn.Module):
    """
    Multi-head attention whose score is the sum of a content term and a
    relative-position term.

    The query is offset by two learned per-head biases: ``u_bias`` for the
    content term (query against key) and ``v_bias`` for the positional term
    (query against the projected positional embedding). The summed score is
    divided by ``sqrt(d_model)`` before the softmax.
    """

    def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.sqrt_dim = math.sqrt(d_model)

        # All five projections are d_model -> d_model; they are created in
        # this fixed order so parameter initialisation order is stable.
        for proj_name in ("query_proj", "key_proj", "value_proj", "pos_proj", "out_proj"):
            setattr(self, proj_name, nn.Linear(d_model, d_model))

        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        for bias in (self.u_bias, self.v_bias):
            torch.nn.init.xavier_uniform_(bias)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, query, key, value, pos_embedding, mask=None):
        """
        Apply attention and return the output projection of the context.

        ``mask``, when given, gets a head axis inserted at dim 1 and the
        masked score entries are filled with -1e9 in place.
        """
        n = value.size(0)
        per_head = (n, -1, self.num_heads, self.d_head)

        q = self.query_proj(query).view(*per_head)
        k = self.key_proj(key).view(*per_head).permute(0, 2, 1, 3)
        v = self.value_proj(value).view(*per_head).permute(0, 2, 1, 3)
        pos = self.pos_proj(pos_embedding).view(*per_head)

        q_content = (q + self.u_bias).transpose(1, 2)
        q_position = (q + self.v_bias).transpose(1, 2)

        content_score = torch.matmul(q_content, k.transpose(2, 3))
        pos_score = self._compute_relative_positional_encoding(
            torch.matmul(q_position, pos.permute(0, 2, 3, 1))
        )

        scores = (content_score + pos_score) / self.sqrt_dim
        if mask is not None:
            scores.masked_fill_(mask.unsqueeze(1), -1e9)

        weights = self.dropout(F.softmax(scores, -1))

        # Merge the heads back into a single d_model-wide feature axis.
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
        return self.out_proj(merged.view(n, -1, self.d_model))

    def _compute_relative_positional_encoding(self, pos_score):
        """
        Shift the positional scores by prepending a zero column, reshaping,
        and dropping the first row; the result has the input's shape.
        """
        n, heads, len_q, len_k = pos_score.size()
        zero_col = pos_score.new_zeros(n, heads, len_q, 1)
        widened = torch.cat([zero_col, pos_score], dim=-1)
        reinterpreted = widened.view(n, heads, len_k + 1, len_q)
        return reinterpreted[:, :, 1:].view_as(pos_score)
|
383 |
+
|
384 |
+
|
385 |
# Model components
|
386 |
class MaGating(nn.Module):
|
387 |
def __init__(self, d_model):
|
|
|
437 |
class ValueHead(nn.Module):
    """
    Three-layer MLP head that maps a 3-D input to 3 outputs per batch item.

    Every position is projected from ``d_model`` to 128 features, the result
    is flattened per batch item (it must total 128 * 64 features), reduced
    to 128 and then to 3. GELU follows the first two layers.
    """

    def __init__(self, d_model):
        super().__init__()
        self.dense1 = nn.Linear(d_model, 128)
        self.dense2 = nn.Linear(128 * 64, 128)
        self.dense3 = nn.Linear(128, 3)

    def forward(self, x):
        # Unpacking three sizes also rejects inputs that are not 3-D.
        batch, _, _ = x.size()
        hidden = F.gelu(self.dense1(x))
        hidden = F.gelu(self.dense2(hidden.view(batch, -1)))
        return self.dense3(hidden)
|
453 |
+
|
454 |
|
455 |
class ValueHeadQ(nn.Module):
    """
    Three-layer MLP head that maps a 3-D input to 3 outputs per batch item.

    Every position is projected from ``d_model`` to 128 features, the result
    is flattened per batch item (it must total 128 * 64 features), reduced
    to 128 and then to 3. GELU follows the first two layers.
    """

    def __init__(self, d_model):
        super().__init__()
        self.dense1 = nn.Linear(d_model, 128)
        self.dense2 = nn.Linear(128 * 64, 128)
        self.dense3 = nn.Linear(128, 3)

    def forward(self, x):
        # Unpacking three sizes also rejects inputs that are not 3-D.
        batch, _, _ = x.size()
        hidden = F.gelu(self.dense1(x))
        hidden = F.gelu(self.dense2(hidden.view(batch, -1)))
        return self.dense3(hidden)
|
471 |
|
472 |
|
473 |
+
# Main HuggingFace compatible model class
|
474 |
class ChessBotPreTrainedModel(PreTrainedModel):
|
475 |
"""
|
476 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
|
495 |
|
496 |
class ChessBotModel(ChessBotPreTrainedModel):
|
497 |
"""
|
498 |
+
HuggingFace compatible ChessBot Chess model with ALL original functionality
|
499 |
"""
|
500 |
|
501 |
def __init__(self, config):
|
502 |
super().__init__(config)
|
503 |
self.config = config
|
504 |
|
505 |
+
# Initialize exactly like the original BT4 model
|
506 |
self.is_thinking_model = False
|
507 |
self.d_model = config.d_model
|
508 |
self.num_layers = config.num_layers
|
509 |
|
510 |
+
# Model layers - same as original
|
511 |
self.layers = nn.ModuleList([
|
512 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
513 |
for _ in range(config.num_layers)
|
|
|
527 |
# Initialize weights
|
528 |
self.post_init()
|
529 |
|
530 |
+
@classmethod
def from_pretrained(cls, model_path, **kwargs):
    """
    Load a ChessBot model from a local directory.

    This overrides ``PreTrainedModel.from_pretrained`` with a minimal
    loader that only understands a local directory.

    Args:
        model_path: Directory that holds the weights as either
            ``pytorch_model.bin`` or ``model.safetensors`` (checked in that
            order) and, optionally, ``config.json``. Without a config file
            the default ``ChessBotConfig()`` is used.
        **kwargs: Accepted for call compatibility; currently ignored.

    Returns:
        The model with the weights loaded, on CPU and in evaluation mode.

    Raises:
        FileNotFoundError: If neither weight file exists in ``model_path``.
        ImportError: If a ``.safetensors`` file has to be read but the
            ``safetensors`` package is not installed.
    """
    import os
    import warnings

    # Load config
    config_path = os.path.join(model_path, "config.json")
    if os.path.exists(config_path):
        config = ChessBotConfig.from_pretrained(model_path)
    else:
        config = ChessBotConfig()

    # Create model instance
    model = cls(config)

    # Locate the weight file; the first existing candidate wins.
    candidates = (
        os.path.join(model_path, filename)
        for filename in ("pytorch_model.bin", "model.safetensors")
    )
    model_file = next((path for path in candidates if os.path.exists(path)), None)
    if model_file is None:
        raise FileNotFoundError(f"No model file found in {model_path}")

    if model_file.endswith('.safetensors'):
        # Only the import is guarded, so an error raised while reading the
        # file is not misreported as a missing dependency.
        try:
            from safetensors import safe_open
        except ImportError as err:
            raise ImportError(
                "safetensors library is required to load .safetensors files. Install with: pip install safetensors"
            ) from err
        with safe_open(model_file, framework="pt", device="cpu") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
    else:
        # A .bin checkpoint is a pickle. weights_only=True restricts
        # unpickling to tensors and plain containers, so a tampered file
        # cannot execute arbitrary code while loading.
        state_dict = torch.load(model_file, map_location="cpu", weights_only=True)

    # strict=False tolerates key mismatches, but report them instead of
    # silently leaving parameters at their random initialisation.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_file} does not fully match the model: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}"
        )

    # Match PreTrainedModel.from_pretrained, which returns the model in
    # evaluation mode; otherwise dropout stays active during inference.
    model.eval()

    return model
|
576 |
+
|
577 |
+
def save_pretrained(self, save_directory, safe_serialization=False):
    """
    Write the config and the model weights into ``save_directory``.

    The directory is created when missing. With ``safe_serialization`` the
    weights go to ``model.safetensors``; if the safetensors package cannot
    be imported a warning is printed and ``pytorch_model.bin`` is written
    instead. Without ``safe_serialization`` the weights always go to
    ``pytorch_model.bin``.
    """
    import os

    os.makedirs(save_directory, exist_ok=True)
    self.config.save_pretrained(save_directory)

    wrote_safetensors = False
    if safe_serialization:
        try:
            from safetensors.torch import save_file
            save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
            wrote_safetensors = True
        except ImportError:
            print("⚠ Warning: safetensors not available, falling back to pytorch_model.bin")

    # Single fallback path for both "not requested" and "not available".
    if not wrote_safetensors:
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
|
600 |
+
|
601 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
602 |
"""
|
603 |
+
Forward pass compatible with both HuggingFace interface and original interface
|
604 |
"""
|
605 |
+
# Handle both HF interface (input_ids) and original interface (tuple)
|
606 |
+
if isinstance(input_ids, tuple):
|
607 |
+
inp = input_ids
|
608 |
+
x = inp[0]
|
609 |
+
compute_loss = compute_loss or len(inp) > 1
|
610 |
+
else:
|
611 |
+
x = input_ids
|
612 |
+
inp = (x,)
|
613 |
+
|
614 |
b, seq_len, _, _, emb = x.size()
|
615 |
x = x.view(b * seq_len, 64, emb)
|
616 |
|
|
|
620 |
x = self.ma_gating(x)
|
621 |
|
622 |
pos_enc = self.positional(x)
|
623 |
+
for i in range(self.num_layers):
|
624 |
+
x = self.layers[i](x, pos_enc)
|
|
|
625 |
|
626 |
value_h = self.value_head(x)
|
627 |
value_h = value_h.view(b, seq_len, 3)
|
|
|
643 |
|
644 |
policy = self.policy_head(policy_attn_logits)
|
645 |
|
646 |
+
if compute_loss:
|
647 |
+
targets = inp[1]
|
648 |
+
true_values = inp[3]
|
649 |
+
q_values = inp[4]
|
650 |
+
loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
|
651 |
+
z = torch.argmax(true_values, dim=-1)
|
652 |
+
loss_value = F.cross_entropy(value_h.view(-1, value_h.size(-1)), z.view(-1), ignore_index=3)
|
653 |
+
value_h_q = torch.softmax(value_h_q, dim=-1)
|
654 |
+
loss_q = F.mse_loss(value_h_q.view(-1, value_h_q.size(-1)), q_values.view(-1, 3))
|
655 |
+
return policy, value_h, loss_policy, loss_value, loss_q, targets, z
|
656 |
+
|
657 |
return BaseModelOutput(
|
658 |
last_hidden_state=x,
|
659 |
hidden_states=None,
|
660 |
attentions=None,
|
661 |
), policy, value_h, value_h_q
|
662 |
+
|
663 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
664 |
"""
|
665 |
Get a move from FEN string without thinking
|
|
|
720 |
|
721 |
return selected_move
|
722 |
|
723 |
+
def get_position_value(self, fen, device="cuda"):
    """
    Evaluate a single FEN position with the Q value head.

    The position is run through ``get_batch_position_values`` as a batch of
    one, so the single-position and batched evaluations share one code path
    and cannot drift apart.

    Args:
        fen: FEN string of the position to evaluate.
        device: Device to run the computation on.

    Returns:
        1-D tensor of 3 softmax probabilities. The original documentation
        gives the order as [black_win_prob, draw_prob, white_win_prob].
    """
    # (1, 3) -> (3,): drop the batch dimension of the one-element batch.
    return self.get_batch_position_values([fen], device).squeeze()
|
750 |
+
|
751 |
+
def get_batch_position_values(self, fens, device="cuda"):
    """
    Evaluate a batch of FEN positions with the Q value head in one pass.

    Args:
        fens: List of FEN strings.
        device: Device to run computations on.

    Returns:
        Tensor of shape [batch_size, 3] holding softmax probabilities per
        position. The original documentation gives the order as
        [black_win_prob, draw_prob, white_win_prob]. An empty input yields
        an empty [0, 3] tensor without running the model.
    """
    if len(fens) == 0:
        return torch.empty(0, 3, device=device)

    # Encode every position on the host and move the whole batch to the
    # device with a single copy instead of one transfer per FEN.
    encoded = np.stack([fen_to_tensor(fen) for fen in fens], axis=0)
    batch_x = torch.from_numpy(encoded).to(device=device, dtype=torch.float32)
    # [batch_size, 8, 8, 19] -> [batch_size, 1, 8, 8, 19] for the model
    batch_x = batch_x.unsqueeze(1)

    # Forward pass through the model to get values
    with torch.no_grad():
        b, seq_len, _, _, emb = batch_x.size()
        # Flatten each board into 64 square tokens.
        x_processed = batch_x.view(b * seq_len, 64, emb)
        x_processed = self.linear1(x_processed)
        x_processed = F.gelu(x_processed)
        x_processed = self.layernorm1(x_processed)
        x_processed = self.ma_gating(x_processed)

        pos_enc = self.positional(x_processed)
        for layer in self.layers[:self.num_layers]:
            x_processed = layer(x_processed, pos_enc)

        value_logits = self.value_head_q(x_processed)
        value_logits = value_logits.view(b, seq_len, 3)
        value_logits = torch.softmax(value_logits, dim=-1)

    return value_logits.squeeze(1)  # Remove sequence dimension, keep batch dimension
|
791 |
+
|
792 |
+
def calculate_move_values(self, fen, device="cuda"):
    """
    Score every legal move by evaluating the position it leads to.

    All resulting positions are evaluated in one call to
    ``get_batch_position_values``. A move's value is column 2 minus column 0
    of that output (white_win_prob - black_win_prob per the batch method's
    documentation), negated when black is to move so that higher is always
    better for the side to move.

    Returns:
        ``(legal_moves, values)``: the list of legal moves and a 1-D tensor
        with one value per move. With no legal moves this is an empty list
        and an empty tensor.
    """
    position = chess.Board()
    position.set_fen(fen)

    white_to_move = position.turn == chess.WHITE

    candidates = list(position.legal_moves)
    if not candidates:
        return [], torch.empty(0, device=device)

    def _fen_after(move):
        # Play the move, record the resulting FEN, then restore the board.
        position.push(move)
        resulting = position.fen()
        position.pop()
        return resulting

    child_fens = [_fen_after(move) for move in candidates]

    # One batched inference over all child positions.
    outcome_probs = self.get_batch_position_values(child_fens, device)

    # outcome_probs[:, 0] = black_win_prob, [:, 1] = draw_prob, [:, 2] = white_win_prob
    white_edge = outcome_probs[:, 2] - outcome_probs[:, 0]

    # Flip the sign for black so the value is from the mover's perspective.
    player_values = white_edge if white_to_move else -white_edge
    return candidates, player_values
|
829 |
+
|
830 |
+
def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False):
    """
    Pick a move by the value of the position each legal move leads to.

    Args:
        fen: FEN string of the position (works for both white and black to move)
        T: Temperature for sampling (T=0 for greedy, T>0 for stochastic)
        device: Device to run computations on
        return_probs: Whether to return the probability distribution

    Returns:
        move: UCI string of the selected move
        probs (optional): dict mapping each legal move's UCI string to its
            selection probability, returned as ``(move, probs)`` when
            ``return_probs=True``. For T=0 this is a one-hot distribution.

    Raises:
        ValueError: If ``T`` is negative or the position has no legal moves.
    """
    # A negative temperature would invert the softmax and favour the
    # worst moves, so reject it instead of sampling from that.
    if T < 0:
        raise ValueError("Temperature T must be non-negative")

    legal_moves, move_values = self.calculate_move_values(fen, device)

    if len(legal_moves) == 0:
        raise ValueError("No legal moves available")

    if T == 0:
        # Greedy selection - choose move with highest value
        chosen_idx = int(torch.argmax(move_values))
        probs = None
    else:
        # Softmax with temperature, computed once and reused below.
        probs = F.softmax(move_values / T, dim=0)
        chosen_idx = torch.multinomial(probs, num_samples=1).item()

    move_uci = legal_moves[chosen_idx].uci()

    if not return_probs:
        return move_uci

    if probs is None:
        # Create one-hot distribution for greedy case
        probs = torch.zeros_like(move_values)
        probs[chosen_idx] = 1.0

    # One tensor-to-list conversion instead of one .item() call per move.
    move_dict = {
        move.uci(): prob for move, prob in zip(legal_moves, probs.tolist())
    }
    return move_uci, move_dict
|
880 |
+
|
881 |
|
882 |
# Register the configuration and model with transformers
|
883 |
AutoConfig.register("chessbot", ChessBotConfig)
|
884 |
AutoModel.register(ChessBotConfig, ChessBotModel)
|
885 |
|
886 |
+
# For backward compatibility
|
887 |
ChessBot = ChessBotModel
|
888 |
+
BT4Model = ChessBotModel
|
usage_example.py
CHANGED
@@ -10,25 +10,33 @@ This model can be used without installing any external packages except:
|
|
10 |
|
11 |
import torch
|
12 |
import sys
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
from modeling_chessbot import ChessBotModel, ChessBotConfig
|
15 |
|
16 |
# Load the model
|
17 |
config = ChessBotConfig()
|
18 |
-
model = ChessBotModel.from_pretrained(
|
19 |
-
|
20 |
-
# Alternative: You can also try AutoModel (may require additional setup)
|
21 |
-
# from transformers import AutoModel
|
22 |
-
# model = AutoModel.from_pretrained("./", trust_remote_code=True)
|
23 |
|
24 |
# Example usage
|
25 |
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
model = model.to(device)
|
28 |
|
29 |
-
# Get the best move
|
30 |
-
|
31 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Get move probabilities
|
34 |
probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
|
|
|
10 |
|
11 |
import torch
|
12 |
import sys
|
13 |
+
import os
|
14 |
+
|
15 |
+
# Get the directory of this script (the model directory)
|
16 |
+
model_dir = os.path.dirname(os.path.abspath(__file__))
|
17 |
+
sys.path.append(model_dir) # Add the model directory to path
|
18 |
from modeling_chessbot import ChessBotModel, ChessBotConfig
|
19 |
|
20 |
# Load the model
|
21 |
config = ChessBotConfig()
|
22 |
+
model = ChessBotModel.from_pretrained(model_dir)
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# Example usage
|
25 |
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
model = model.to(device)
|
28 |
|
29 |
+
# Get the best move using policy
|
30 |
+
policy_move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
|
31 |
+
print(f"Policy-based move: {policy_move}")
|
32 |
+
|
33 |
+
# Get the best move using value analysis
|
34 |
+
value_move = model.get_best_move_value(fen, T=0.1, device=device)
|
35 |
+
print(f"Value-based move: {value_move}")
|
36 |
+
|
37 |
+
# Get position evaluation
|
38 |
+
position_value = model.get_position_value(fen, device=device)
|
39 |
+
print(f"Position value [black_win, draw, white_win]: {position_value}")
|
40 |
|
41 |
# Get move probabilities
|
42 |
probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
|