# Chesstour / main.py
# NOTE(review): the following lines are Hugging Face file-viewer metadata that
# leaked into the source during scraping; preserved here as comments so the
# file remains valid Python.
#   Uploaded by: JosephCatrambone
#   Sync to 9c3c1fc83e7a704152fd11ac255e7df9a0a959ca (02f6666), 2.82 kB
import os
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import data
from model import ChessModel
def train():
    """Train the ChessModel autoencoder on the Lichess puzzle dataset.

    Jointly optimizes three objectives per batch: board reconstruction
    (cross-entropy), puzzle popularity (L1), and engine evaluation (L1).
    Prints per-epoch average losses and writes a checkpoint
    ``checkpoints/epoch_<n>.pth`` after every epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ChessModel(256).to(torch.float32).to(device)
    opt = torch.optim.Adam(model.parameters())
    # NOTE(review): targets passed to CrossEntropyLoss are float board vectors
    # (soft targets), which requires a reasonably recent PyTorch — confirm.
    reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device)
    popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device)
    evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device)
    data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=65536), batch_size=64, num_workers=1)  # 1 to avoid threading madness.

    # Ensure the checkpoint directory exists up front; torch.save does not
    # create intermediate directories and would raise FileNotFoundError.
    os.makedirs("checkpoints", exist_ok=True)

    num_epochs = 100
    for epoch in range(num_epochs):
        model.train()
        total_reconstruction_loss = 0.0
        total_popularity_loss = 0.0
        total_evaluation_loss = 0.0
        total_batch_loss = 0.0
        num_batches = 0
        # total= gives tqdm a proper progress bar length.
        for batch_idx, (board_vec, popularity, evaluation) in tqdm(enumerate(data_loader), total=len(data_loader)):
            board_vec = board_vec.to(torch.float32).to(device)  # [batch_size x 903]
            popularity = popularity.to(torch.float32).to(device).unsqueeze(1)  # enforce [batch_size, 1]
            evaluation = evaluation.to(torch.float32).to(device).unsqueeze(1)
            _embedding, predicted_popularity, predicted_evaluation, predicted_board_vec = model(board_vec)
            reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
            popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
            evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
            total_loss = reconstruction_loss + popularity_loss + evaluation_loss
            opt.zero_grad()
            total_loss.backward()
            opt.step()
            total_reconstruction_loss += reconstruction_loss.cpu().item()
            total_popularity_loss += popularity_loss.cpu().item()
            total_evaluation_loss += evaluation_loss.cpu().item()
            total_batch_loss += total_loss.cpu().item()
            num_batches += 1
        # Guard against an empty loader to avoid ZeroDivisionError.
        denom = max(num_batches, 1)
        print(f"Average reconstruction loss: {total_reconstruction_loss/denom}")
        print(f"Average popularity loss: {total_popularity_loss/denom}")
        print(f"Average evaluation loss: {total_evaluation_loss/denom}")
        print(f"Average batch loss: {total_batch_loss/denom}")
        # Save the state_dict rather than pickling the whole model object:
        # whole-model pickles break if the ChessModel class moves or changes.
        torch.save(model.state_dict(), f"checkpoints/epoch_{epoch}.pth")
def infer(fen):
    """Run inference for a single position.

    Args:
        fen: presumably a FEN string describing the board — TODO confirm
            against the eventual implementation.

    Not yet implemented (stub).
    """
    pass
def test():
    """Evaluate the model on held-out data. Not yet implemented (stub)."""
    pass
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python {sys.argv[0]} --train|infer")
elif sys.argv[1] == "--train":
train()
elif sys.argv[2] == "--infer":
infer(sys.argv[3])