Maxlegrec committed on
Commit
26b86c7
·
verified ·
1 Parent(s): 311cfe1

Upload ChessBot Chess model

Browse files
__pycache__/modeling_chessbot.cpython-311.pyc CHANGED
Binary files a/__pycache__/modeling_chessbot.cpython-311.pyc and b/__pycache__/modeling_chessbot.cpython-311.pyc differ
 
config.json CHANGED
@@ -1,14 +1,10 @@
1
  {
2
- "architectures": [
3
- "ChessBotModel"
4
- ],
5
  "d_ff": 736,
6
  "d_model": 512,
7
  "max_position_embeddings": 64,
8
  "model_type": "chessbot",
9
  "num_heads": 8,
10
  "num_layers": 10,
11
- "torch_dtype": "float32",
12
  "transformers_version": "4.53.1",
13
  "vocab_size": 1929
14
  }
 
1
  {
 
 
 
2
  "d_ff": 736,
3
  "d_model": 512,
4
  "max_position_embeddings": 64,
5
  "model_type": "chessbot",
6
  "num_heads": 8,
7
  "num_layers": 10,
 
8
  "transformers_version": "4.53.1",
9
  "vocab_size": 1929
10
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18bfb31a333bcc46e2d747315626a030855f913c1e3b129ee08d8d979659fd14
3
- size 122277600
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:274e6c174ae963a3ad25960fb50de368c9a8fe937719d6d78d7ab55c262ae2c1
3
+ size 126985096
modeling_chessbot.py CHANGED
@@ -1,15 +1,23 @@
1
  """
2
- Standalone ChessBot Model for HuggingFace Hub
3
- Contains all necessary code to run the model without external dependencies
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
 
9
  from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
10
  from transformers.modeling_outputs import BaseModelOutput
11
- import chess
12
- import numpy as np
13
  from typing import Optional, Tuple
14
  import math
15
 
@@ -32,124 +40,66 @@ class ChessBotConfig(PretrainedConfig):
32
  max_position_embeddings: int = 64,
33
  **kwargs,
34
  ):
 
35
  self.num_layers = num_layers
36
  self.d_model = d_model
37
  self.d_ff = d_ff
38
  self.num_heads = num_heads
39
  self.vocab_size = vocab_size
40
  self.max_position_embeddings = max_position_embeddings
41
-
42
- super().__init__(**kwargs)
43
 
44
 
45
- # Attention modules
46
- class RelativeMultiHeadAttention2(nn.Module):
47
  """
48
- Relative Multi-Head Attention mechanism
49
  """
50
- def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
51
- super(RelativeMultiHeadAttention2, self).__init__()
52
- assert d_model % num_heads == 0
53
-
54
- self.d_model = d_model
55
- self.num_heads = num_heads
56
- self.d_head = int(d_model / num_heads)
57
-
58
- self.query_proj = nn.Linear(d_model, d_model)
59
- self.key_proj = nn.Linear(d_model, d_model)
60
- self.value_proj = nn.Linear(d_model, d_model)
61
- self.pos_proj = nn.Linear(d_model, d_model, bias=False)
62
-
63
- self.dropout = nn.Dropout(p=dropout_p)
64
- self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
65
- self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
66
-
67
- torch.nn.init.xavier_uniform_(self.u_bias)
68
- torch.nn.init.xavier_uniform_(self.v_bias)
69
-
70
- self.out_proj = nn.Linear(d_model, d_model)
71
-
72
- def forward(self, query, key, value, pos_embedding, mask=None):
73
- batch_size = value.size(0)
74
-
75
- query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
76
- key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
77
- value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
78
-
79
- pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
80
-
81
- query = query.permute(0, 2, 1, 3)
82
-
83
- query_with_u_bias = query + self.u_bias.unsqueeze(1)
84
- query_with_v_bias = query + self.v_bias.unsqueeze(1)
85
-
86
- content_score = torch.matmul(query_with_u_bias, key.transpose(-1, -2))
87
- pos_score = torch.matmul(query_with_v_bias, pos_embedding.permute(0, 2, 3, 1))
88
- pos_score = self._compute_relative_positional_encoding(pos_score)
89
-
90
- score = (content_score + pos_score) / math.sqrt(self.d_head)
91
-
92
- if mask is not None:
93
- score.masked_fill_(mask, -float('inf'))
94
-
95
- attn = F.softmax(score, -1)
96
- attn = self.dropout(attn)
97
-
98
- context = torch.matmul(attn, value).transpose(1, 2)
99
- context = context.contiguous().view(batch_size, -1, self.d_model)
100
-
101
- return self.out_proj(context)
102
-
103
- def _compute_relative_positional_encoding(self, pos_score):
104
- batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
105
- zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
106
- padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
107
-
108
- padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
109
- pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
110
-
111
- return pos_score
112
-
113
-
114
- # Utility functions
115
- def fen_to_tensor(fen: str):
116
- """Convert FEN string to tensor representation"""
117
  board = chess.Board(fen)
118
- P = 19 # 12 planes for pieces + 1 for side to play + 1 for en passant + 4 for castling + 1 for 50-move rule
119
- tensor = np.zeros((8, 8, P), dtype=np.float32)
120
 
 
121
  piece_map = {
122
  'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
123
  'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
124
  }
125
 
126
- # Populate piece planes
127
- for square, piece in board.piece_map().items():
128
- rank, file = divmod(square, 8)
129
- plane = piece_map[piece.symbol()]
130
- tensor[7 - rank, file, plane] = 1.0 # Flip rank to align with standard board representation
 
 
131
 
132
- # Side to play plane
133
- tensor[:, :, 12] = 1.0 if board.turn == chess.WHITE else 0.0
134
-
135
- # En passant plane
136
- if board.ep_square is not None:
137
- rank, file = divmod(board.ep_square, 8)
138
- tensor[7 - rank, file, 13] = 1.0
 
139
 
140
- # Castling rights planes (4 total: white kingside, white queenside, black kingside, black queenside)
141
- tensor[:, :, 14] = 1.0 if board.has_kingside_castling_rights(chess.WHITE) else 0.0
142
- tensor[:, :, 15] = 1.0 if board.has_queenside_castling_rights(chess.WHITE) else 0.0
143
- tensor[:, :, 16] = 1.0 if board.has_kingside_castling_rights(chess.BLACK) else 0.0
144
- tensor[:, :, 17] = 1.0 if board.has_queenside_castling_rights(chess.BLACK) else 0.0
 
 
 
 
145
 
146
- # 50-move rule plane (normalized to [0,1])
147
- tensor[:, :, 18] = min(board.halfmove_clock / 100.0, 1.0)
 
 
 
148
 
149
  return tensor
150
 
151
 
152
- # Policy index (chess moves vocabulary)
153
  policy_index = [
154
  "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
155
  "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
@@ -370,6 +320,68 @@ policy_index = [
370
  "<thinking>","</thinking>","end_variation","end","padding_token"
371
  ]
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  # Model components
374
  class MaGating(nn.Module):
375
  def __init__(self, d_model):
@@ -425,48 +437,40 @@ class AbsolutePositionalEncoder(nn.Module):
425
  class ValueHead(nn.Module):
426
  def __init__(self, d_model):
427
  super().__init__()
428
- self.linear1 = nn.Linear(d_model, d_model)
429
- self.linear2 = nn.Linear(d_model, d_model)
430
- self.linear3 = nn.Linear(d_model, 3)
431
- self.gelu = nn.GELU()
432
- self.layernorm1 = nn.LayerNorm(d_model)
433
- self.layernorm2 = nn.LayerNorm(d_model)
434
-
435
  def forward(self, x):
436
- x = x.mean(dim=-2)
437
- x = self.linear1(x)
438
- x = self.gelu(x)
439
- x = self.layernorm1(x)
440
- x = self.linear2(x)
441
- x = self.gelu(x)
442
- x = self.layernorm2(x)
443
- x = self.linear3(x)
444
  return x
445
-
446
 
447
  class ValueHeadQ(nn.Module):
448
  def __init__(self, d_model):
449
  super().__init__()
450
- self.linear1 = nn.Linear(d_model, d_model)
451
- self.linear2 = nn.Linear(d_model, d_model)
452
- self.linear3 = nn.Linear(d_model, 3)
453
- self.gelu = nn.GELU()
454
- self.layernorm1 = nn.LayerNorm(d_model)
455
- self.layernorm2 = nn.LayerNorm(d_model)
456
-
457
  def forward(self, x):
458
- x = x.mean(dim=-2)
459
- x = self.linear1(x)
460
- x = self.gelu(x)
461
- x = self.layernorm1(x)
462
- x = self.linear2(x)
463
- x = self.gelu(x)
464
- x = self.layernorm2(x)
465
- x = self.linear3(x)
466
  return x
467
 
468
 
469
- # Main model class
470
  class ChessBotPreTrainedModel(PreTrainedModel):
471
  """
472
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
@@ -491,19 +495,19 @@ class ChessBotPreTrainedModel(PreTrainedModel):
491
 
492
  class ChessBotModel(ChessBotPreTrainedModel):
493
  """
494
- HuggingFace compatible ChessBot Chess model
495
  """
496
 
497
  def __init__(self, config):
498
  super().__init__(config)
499
  self.config = config
500
 
501
- # Initialize the same components as the original BT4 model
502
  self.is_thinking_model = False
503
  self.d_model = config.d_model
504
  self.num_layers = config.num_layers
505
 
506
- # Model layers
507
  self.layers = nn.ModuleList([
508
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
509
  for _ in range(config.num_layers)
@@ -523,11 +527,90 @@ class ChessBotModel(ChessBotPreTrainedModel):
523
  # Initialize weights
524
  self.post_init()
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  def forward(self, input_ids, attention_mask=None, compute_loss=False):
527
  """
528
- Forward pass compatible with Hugging Face interface
529
  """
530
- x = input_ids
 
 
 
 
 
 
 
 
531
  b, seq_len, _, _, emb = x.size()
532
  x = x.view(b * seq_len, 64, emb)
533
 
@@ -537,9 +620,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
537
  x = self.ma_gating(x)
538
 
539
  pos_enc = self.positional(x)
540
-
541
- for layer in self.layers:
542
- x = layer(x, pos_enc)
543
 
544
  value_h = self.value_head(x)
545
  value_h = value_h.view(b, seq_len, 3)
@@ -561,12 +643,23 @@ class ChessBotModel(ChessBotPreTrainedModel):
561
 
562
  policy = self.policy_head(policy_attn_logits)
563
 
 
 
 
 
 
 
 
 
 
 
 
564
  return BaseModelOutput(
565
  last_hidden_state=x,
566
  hidden_states=None,
567
  attentions=None,
568
  ), policy, value_h, value_h_q
569
-
570
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
571
  """
572
  Get a move from FEN string without thinking
@@ -627,11 +720,169 @@ class ChessBotModel(ChessBotPreTrainedModel):
627
 
628
  return selected_move
629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
 
631
  # Register the configuration and model with transformers
632
  AutoConfig.register("chessbot", ChessBotConfig)
633
  AutoModel.register(ChessBotConfig, ChessBotModel)
634
 
635
- # For backward compatibility, create aliases
636
  ChessBot = ChessBotModel
637
- BT4Model = ChessBotModel # Keep for backward compatibility
 
1
  """
2
+ Standalone ChessBot Chess Model
3
+
4
+ This file contains all the necessary code to run the ChessBot model
5
+ without requiring the HFChessRL package installation.
6
+
7
+ Requirements:
8
+ - torch>=2.0.0
9
+ - transformers>=4.30.0
10
+ - python-chess>=1.10.0
11
+ - numpy>=1.21.0
12
  """
13
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.nn.functional as F
17
+ import numpy as np
18
+ import chess
19
  from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
20
  from transformers.modeling_outputs import BaseModelOutput
 
 
21
  from typing import Optional, Tuple
22
  import math
23
 
 
40
  max_position_embeddings: int = 64,
41
  **kwargs,
42
  ):
43
+ super().__init__(**kwargs)
44
  self.num_layers = num_layers
45
  self.d_model = d_model
46
  self.d_ff = d_ff
47
  self.num_heads = num_heads
48
  self.vocab_size = vocab_size
49
  self.max_position_embeddings = max_position_embeddings
 
 
50
 
51
 
52
+ # FEN encoding function
53
def fen_to_tensor(fen: str):
    """
    Encode a FEN position as an ``(8, 8, 19)`` float32 NumPy array.

    Planes along the last axis (every written value is 1.0, the rest 0.0):

    * 0-5   white P, N, B, R, Q, K -- one cell per occupied square
    * 6-11  black p, n, b, r, q, k -- one cell per occupied square
    * 12    whole plane set when White is to move
    * 13    whole plane set when Black is to move
    * 14-17 whole plane set for castling rights, in the order
            white kingside, white queenside, black kingside, black queenside
    * 18    single cell at the en passant square, when the board has one

    Rows are flipped (``row = 7 - square // 8``), so with python-chess square
    numbering (a1 == 0) row 0 holds rank 8.
    """
    board = chess.Board(fen)
    planes = np.zeros((8, 8, 19), dtype=np.float32)

    # Plane index of each piece symbol: white pieces first, then black.
    plane_of = {symbol: idx for idx, symbol in enumerate("PNBRQKpnbrqk")}

    # One-hot piece placement; only occupied squares are visited.
    for square, piece in board.piece_map().items():
        rank_idx, file_idx = divmod(square, 8)
        planes[7 - rank_idx, file_idx, plane_of[piece.symbol()]] = 1.0

    # Side to move: exactly one of planes 12 (White) / 13 (Black) is filled.
    planes[:, :, 12 if board.turn == chess.WHITE else 13] = 1.0

    # Castling rights, one constant plane per (colour, side) pair.
    castling_planes = (
        (14, board.has_kingside_castling_rights, chess.WHITE),
        (15, board.has_queenside_castling_rights, chess.WHITE),
        (16, board.has_kingside_castling_rights, chess.BLACK),
        (17, board.has_queenside_castling_rights, chess.BLACK),
    )
    for plane, has_right, colour in castling_planes:
        if has_right(colour):
            planes[:, :, plane] = 1.0

    # En passant square, flipped the same way as the piece planes.
    ep_square = board.ep_square
    if ep_square is not None:
        ep_rank, ep_file = divmod(ep_square, 8)
        planes[7 - ep_rank, ep_file, 18] = 1.0

    return planes
100
 
101
 
102
+ # Complete policy index: 1929 entries (UCI moves followed by the special tokens at the end of the list)
103
  policy_index = [
104
  "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
105
  "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
 
320
  "<thinking>","</thinking>","end_variation","end","padding_token"
321
  ]
322
 
323
+
324
+
325
+ # Attention mechanism
326
class RelativeMultiHeadAttention2(nn.Module):
    """
    Multi-head attention whose score is a content term plus a position term.

    The content term is ``(q + u_bias) @ k^T`` and the position term is
    ``(q + v_bias) @ pos^T`` passed through ``_compute_relative_positional_encoding``;
    ``u_bias`` / ``v_bias`` are learned per-head vectors.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout_p: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
        super().__init__()
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # NOTE(review): scores are scaled by sqrt(d_model), not sqrt(d_head).
        # Presumably the released weights were trained with this scaling, so it
        # is kept as-is -- confirm before changing.
        self.sqrt_dim = math.sqrt(d_model)

        # Attribute names and creation order are kept so state-dict keys and
        # seeded initialisation stay the same.
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.pos_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # torch.empty is the supported way to allocate uninitialised storage;
        # both tensors are overwritten by Xavier init right below.
        self.u_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, query, key, value, pos_embedding, mask=None):
        """
        Compute attention output of shape ``(batch, query_len, d_model)``.

        Args:
            query, key, value: Tensors whose first dimension is the batch and
                whose trailing features reshape to ``(num_heads, d_head)``.
            pos_embedding: Positional features; reshaped with the same batch
                size as ``value``.
            mask: Optional boolean tensor. A head axis is inserted at dim 1 and
                positions where it is ``True`` are excluded from the softmax.
        """
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        # (batch, heads, q_len, k_len) for both terms.
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        pos_score = self._compute_relative_positional_encoding(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            # -1e9 is not representable in float16, so clamp the fill value to
            # the dtype's finite minimum. In float32 this is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(score.dtype).min)
            score = score.masked_fill(mask.unsqueeze(1), fill_value)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _compute_relative_positional_encoding(self, pos_score):
        """
        Re-index the position scores by pad / reshape / slice.

        A zero column is prepended on the last axis, the tensor is viewed as
        ``(batch, heads, len2 + 1, len1)``, the first row is dropped, and the
        result is viewed back to the input shape.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)

        return pos_score
383
+
384
+
385
  # Model components
386
  class MaGating(nn.Module):
387
  def __init__(self, d_model):
 
437
class ValueHead(nn.Module):
    """
    Three-way classification head over a full set of per-square embeddings.

    Each token is projected to ``hidden_dim`` features, all tokens of one
    position are concatenated, and two further dense layers produce 3 logits.

    Args:
        d_model: Feature width of each incoming token.
        hidden_dim: Width of the hidden projections (default 128).
        num_squares: Number of tokens per position that get concatenated
            (default 64). The input's second dimension must equal this.
    """

    def __init__(self, d_model, hidden_dim=128, num_squares=64):
        super().__init__()
        # Attribute names are unchanged so existing checkpoints keep loading.
        self.dense1 = nn.Linear(d_model, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim * num_squares, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, 3)

    def forward(self, x):
        """Map ``(batch, num_squares, d_model)`` to ``(batch, 3)`` raw logits."""
        b, _, _ = x.size()  # also enforces a 3-D input
        x = F.gelu(self.dense1(x))
        # Concatenate every square's features into one vector per position.
        x = x.reshape(b, -1)
        x = F.gelu(self.dense2(x))
        return self.dense3(x)
453
+
454
 
455
class ValueHeadQ(nn.Module):
    """
    Second three-output head with the same layer layout as ``ValueHead``.

    Each token is projected to ``hidden_dim`` features, all tokens of one
    position are concatenated, and two further dense layers produce 3 logits.

    Args:
        d_model: Feature width of each incoming token.
        hidden_dim: Width of the hidden projections (default 128).
        num_squares: Number of tokens per position that get concatenated
            (default 64). The input's second dimension must equal this.
    """

    def __init__(self, d_model, hidden_dim=128, num_squares=64):
        super().__init__()
        # Attribute names are unchanged so existing checkpoints keep loading.
        self.dense1 = nn.Linear(d_model, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim * num_squares, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, 3)

    def forward(self, x):
        """Map ``(batch, num_squares, d_model)`` to ``(batch, 3)`` raw logits."""
        b, _, _ = x.size()  # also enforces a 3-D input
        x = F.gelu(self.dense1(x))
        # Concatenate every square's features into one vector per position.
        x = x.reshape(b, -1)
        x = F.gelu(self.dense2(x))
        return self.dense3(x)
471
 
472
 
473
+ # Main HuggingFace compatible model class
474
  class ChessBotPreTrainedModel(PreTrainedModel):
475
  """
476
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
 
495
 
496
  class ChessBotModel(ChessBotPreTrainedModel):
497
  """
498
+ HuggingFace compatible ChessBot Chess model with ALL original functionality
499
  """
500
 
501
  def __init__(self, config):
502
  super().__init__(config)
503
  self.config = config
504
 
505
+ # Initialize exactly like the original BT4 model
506
  self.is_thinking_model = False
507
  self.d_model = config.d_model
508
  self.num_layers = config.num_layers
509
 
510
+ # Model layers - same as original
511
  self.layers = nn.ModuleList([
512
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
513
  for _ in range(config.num_layers)
 
527
  # Initialize weights
528
  self.post_init()
529
 
530
@classmethod
def from_pretrained(cls, model_path, **kwargs):
    """
    Load a model from a local directory.

    This shadows the inherited ``from_pretrained`` and only reads from the
    local filesystem: ``config.json`` (optional) plus ``pytorch_model.bin``
    or ``model.safetensors`` inside ``model_path``.

    Args:
        model_path: Directory containing the config and weight file.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        The model with weights loaded on CPU, in evaluation mode.

    Raises:
        FileNotFoundError: If neither weight file exists in ``model_path``.
        ImportError: If a ``.safetensors`` file is found but the
            ``safetensors`` package is not installed.
    """
    import os
    import warnings

    # Fall back to the default hyper-parameters when no config is shipped.
    config_path = os.path.join(model_path, "config.json")
    if os.path.exists(config_path):
        config = ChessBotConfig.from_pretrained(model_path)
    else:
        config = ChessBotConfig()

    model = cls(config)

    # First existing file wins; pytorch_model.bin is checked first.
    candidate_paths = (
        os.path.join(model_path, filename)
        for filename in ("pytorch_model.bin", "model.safetensors")
    )
    model_file = next((path for path in candidate_paths if os.path.exists(path)), None)
    if model_file is None:
        raise FileNotFoundError(f"No model file found in {model_path}")

    if model_file.endswith('.safetensors'):
        # Only the import is guarded, so a read error is not mistaken for a
        # missing dependency.
        try:
            from safetensors.torch import load_file
        except ImportError as err:
            raise ImportError("safetensors library is required to load .safetensors files. Install with: pip install safetensors") from err
        state_dict = load_file(model_file, device="cpu")
    else:
        # weights_only=True refuses to unpickle arbitrary objects from the file.
        state_dict = torch.load(model_file, map_location="cpu", weights_only=True)

    # strict=False is kept for compatibility, but mismatches are now reported
    # instead of silently leaving layers randomly initialised.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_file} does not fully match the model: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}"
        )

    # The stock Hugging Face loader returns models in eval mode; do the same
    # so dropout is inactive for inference unless the caller calls .train().
    model.eval()

    return model
576
+
577
def save_pretrained(self, save_directory, safe_serialization=False):
    """
    Write the config and the weights into ``save_directory``.

    The directory is created if needed. Weights go to ``model.safetensors``
    when ``safe_serialization`` is true and ``safetensors`` can be imported;
    otherwise they go to ``pytorch_model.bin`` via ``torch.save``.
    """
    import os

    os.makedirs(save_directory, exist_ok=True)
    self.config.save_pretrained(save_directory)

    if safe_serialization:
        try:
            from safetensors.torch import save_file
            save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
            return
        except ImportError:
            # Fall through to the pickle-based format below.
            print("⚠ Warning: safetensors not available, falling back to pytorch_model.bin")

    torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
600
+
601
  def forward(self, input_ids, attention_mask=None, compute_loss=False):
602
  """
603
+ Forward pass compatible with both HuggingFace interface and original interface
604
  """
605
+ # Handle both HF interface (input_ids) and original interface (tuple)
606
+ if isinstance(input_ids, tuple):
607
+ inp = input_ids
608
+ x = inp[0]
609
+ compute_loss = compute_loss or len(inp) > 1
610
+ else:
611
+ x = input_ids
612
+ inp = (x,)
613
+
614
  b, seq_len, _, _, emb = x.size()
615
  x = x.view(b * seq_len, 64, emb)
616
 
 
620
  x = self.ma_gating(x)
621
 
622
  pos_enc = self.positional(x)
623
+ for i in range(self.num_layers):
624
+ x = self.layers[i](x, pos_enc)
 
625
 
626
  value_h = self.value_head(x)
627
  value_h = value_h.view(b, seq_len, 3)
 
643
 
644
  policy = self.policy_head(policy_attn_logits)
645
 
646
+ if compute_loss:
647
+ targets = inp[1]
648
+ true_values = inp[3]
649
+ q_values = inp[4]
650
+ loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
651
+ z = torch.argmax(true_values, dim=-1)
652
+ loss_value = F.cross_entropy(value_h.view(-1, value_h.size(-1)), z.view(-1), ignore_index=3)
653
+ value_h_q = torch.softmax(value_h_q, dim=-1)
654
+ loss_q = F.mse_loss(value_h_q.view(-1, value_h_q.size(-1)), q_values.view(-1, 3))
655
+ return policy, value_h, loss_policy, loss_value, loss_q, targets, z
656
+
657
  return BaseModelOutput(
658
  last_hidden_state=x,
659
  hidden_states=None,
660
  attentions=None,
661
  ), policy, value_h, value_h_q
662
+
663
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
664
  """
665
  Get a move from FEN string without thinking
 
720
 
721
  return selected_move
722
 
723
def get_position_value(self, fen, device="cuda"):
    """
    Evaluate a single FEN position.

    Delegates to ``get_batch_position_values`` with a one-element batch, so
    the single and batched paths cannot drift apart. The probabilities are
    the softmax of the ``value_head_q`` output.

    Args:
        fen: FEN string of the position.
        device: Device the input tensor is moved to.

    Returns:
        1-D tensor with 3 probabilities. The original docstring gives the
        order as [black_win_prob, draw_prob, white_win_prob].
    """
    batch_values = self.get_batch_position_values([fen], device)
    # (1, 3) -> (3,): drop the batch dimension.
    return batch_values.squeeze()
750
+
751
def get_batch_position_values(self, fens, device="cuda"):
    """
    Evaluate a batch of FEN positions in one forward pass.

    Args:
        fens: List of FEN strings.
        device: Device to run computations on.

    Returns:
        Tensor of shape ``[batch_size, 3]`` holding the softmax of the
        ``value_head_q`` output for each position. The original docstring
        gives the order as [black_win_prob, draw_prob, white_win_prob].
        An empty input yields an empty ``(0, 3)`` tensor.
    """
    if len(fens) == 0:
        return torch.empty(0, 3, device=device)

    # Build the whole batch on the host and copy it over once, rather than
    # issuing one host-to-device transfer per position.
    boards = np.stack([fen_to_tensor(fen) for fen in fens], axis=0)
    batch_x = torch.from_numpy(boards).to(device=device, dtype=torch.float32)
    # [batch_size, 8, 8, 19] -> [batch_size, 1, 8, 8, 19] (sequence length 1).
    batch_x = batch_x.unsqueeze(1)

    # no_grad only disables autograd; it does not switch off dropout, so the
    # model is expected to be in eval mode for deterministic results.
    with torch.no_grad():
        b, seq_len, _, _, emb = batch_x.size()
        x_processed = batch_x.view(b * seq_len, 64, emb)
        x_processed = self.linear1(x_processed)
        x_processed = F.gelu(x_processed)
        x_processed = self.layernorm1(x_processed)
        x_processed = self.ma_gating(x_processed)

        pos_enc = self.positional(x_processed)
        for i in range(self.num_layers):
            x_processed = self.layers[i](x_processed, pos_enc)

        value_logits = self.value_head_q(x_processed)
        value_logits = value_logits.view(b, seq_len, 3)
        value_probs = torch.softmax(value_logits, dim=-1)

    # Remove the sequence dimension, keep the batch dimension.
    return value_probs.squeeze(1)
791
+
792
def calculate_move_values(self, fen, device="cuda"):
    """
    Score every legal move by evaluating the position it leads to.

    All successor positions are evaluated in one call to
    ``get_batch_position_values``. Each score is column 2 minus column 0 of
    that result, negated when Black is to move in ``fen``.

    Returns:
        ``(legal_moves, values)``: the list of legal moves and a 1-D tensor
        with one score per move, or ``([], empty tensor)`` when there are
        no legal moves.
    """
    position = chess.Board()
    position.set_fen(fen)

    # Side to move in the position being analysed.
    white_to_move = position.turn == chess.WHITE

    candidates = list(position.legal_moves)
    if not candidates:
        return [], torch.empty(0, device=device)

    # Play each move, record the resulting FEN, then take the move back.
    successor_fens = []
    for candidate in candidates:
        position.push(candidate)
        successor_fens.append(position.fen())
        position.pop()

    # Single batched inference over every successor position.
    outcome_probs = self.get_batch_position_values(successor_fens, device)

    # Column 2 minus column 0, then flipped for Black.
    white_edge = outcome_probs[:, 2] - outcome_probs[:, 0]
    return candidates, (white_edge if white_to_move else -white_edge)
829
+
830
def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False):
    """
    Pick a move from the values of the positions each legal move leads to.

    Args:
        fen: FEN string of the position (works for both white and black to move)
        T: Temperature for sampling (T=0 for greedy, T>0 for stochastic)
        device: Device to run computations on
        return_probs: Whether to return the probability distribution

    Returns:
        move: UCI string of the selected move
        probs (optional): dict mapping each legal move's UCI string to its
            selection probability, returned as ``(move, probs)`` when
            ``return_probs=True``. For ``T=0`` the distribution is one-hot.

    Raises:
        ValueError: If ``T`` is negative or the position has no legal moves.
    """
    # A negative temperature would invert the ranking and favour the worst moves.
    if T < 0:
        raise ValueError(f"Temperature T must be >= 0, got {T}")

    legal_moves, move_values = self.calculate_move_values(fen, device)

    if len(legal_moves) == 0:
        raise ValueError("No legal moves available")

    if T == 0:
        # Greedy selection - choose move with highest value.
        chosen_idx = int(torch.argmax(move_values))
        probs = None  # built lazily, only when the caller asks for it
    else:
        # Softmax with temperature, computed once and reused for the report.
        probs = F.softmax(move_values / T, dim=0)
        chosen_idx = int(torch.multinomial(probs, num_samples=1))

    move_uci = legal_moves[chosen_idx].uci()

    if not return_probs:
        return move_uci

    if probs is None:
        # One-hot distribution for the greedy case.
        probs = torch.zeros_like(move_values)
        probs[chosen_idx] = 1.0

    # tolist() converts in a single transfer instead of one .item() per move.
    move_dict = {move.uci(): prob for move, prob in zip(legal_moves, probs.tolist())}
    return move_uci, move_dict
880
+
881
 
882
  # Register the configuration and model with transformers
883
  AutoConfig.register("chessbot", ChessBotConfig)
884
  AutoModel.register(ChessBotConfig, ChessBotModel)
885
 
886
+ # For backward compatibility
887
  ChessBot = ChessBotModel
888
+ BT4Model = ChessBotModel
usage_example.py CHANGED
@@ -10,25 +10,33 @@ This model can be used without installing any external packages except:
10
 
11
  import torch
12
  import sys
13
- sys.path.append("./") # Add the model directory to path
 
 
 
 
14
  from modeling_chessbot import ChessBotModel, ChessBotConfig
15
 
16
  # Load the model
17
  config = ChessBotConfig()
18
- model = ChessBotModel.from_pretrained("./")
19
-
20
- # Alternative: You can also try AutoModel (may require additional setup)
21
- # from transformers import AutoModel
22
- # model = AutoModel.from_pretrained("./", trust_remote_code=True)
23
 
24
  # Example usage
25
  fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  model = model.to(device)
28
 
29
- # Get the best move
30
- move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
31
- print(f"Best move: {move}")
 
 
 
 
 
 
 
 
32
 
33
  # Get move probabilities
34
  probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)
 
10
 
11
  import torch
12
  import sys
13
+ import os
14
+
15
+ # Get the directory of this script (the model directory)
16
+ model_dir = os.path.dirname(os.path.abspath(__file__))
17
+ sys.path.append(model_dir) # Add the model directory to path
18
  from modeling_chessbot import ChessBotModel, ChessBotConfig
19
 
20
  # Load the model
21
  config = ChessBotConfig()
22
+ model = ChessBotModel.from_pretrained(model_dir)
 
 
 
 
23
 
24
  # Example usage
25
  fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  model = model.to(device)
28
 
29
+ # Get the best move using policy
30
+ policy_move = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device)
31
+ print(f"Policy-based move: {policy_move}")
32
+
33
+ # Get the best move using value analysis
34
+ value_move = model.get_best_move_value(fen, T=0.1, device=device)
35
+ print(f"Value-based move: {value_move}")
36
+
37
+ # Get position evaluation
38
+ position_value = model.get_position_value(fen, device=device)
39
+ print(f"Position value [black_win, draw, white_win]: {position_value}")
40
 
41
  # Get move probabilities
42
  probs = model.get_move_from_fen_no_thinking(fen, T=0.1, device=device, return_probs=True)