File size: 6,728 Bytes
2311079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import os
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import chess
import chess.engine
import numpy as np
from intervaltree import IntervalTree, Interval
from tqdm import tqdm
# Guards the shared written-puzzle-id set and progress-bar update in insert_puzzle().
lock = threading.Lock()
def create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating, std_dev, max_puzzles, step=10,
                                           rng=None):
    """Build a per-rating-bucket quota tree shaped like a normal distribution.

    ``max_puzzles`` ratings are sampled from N(mean_rating, std_dev) and cast to
    ``int``.  Half-open buckets ``[start, start + step)`` are laid out starting
    at ``lowest_rating``; each bucket's interval data is the number of samples
    that fall inside it (which may be 0).  Buckets are generated up to and
    including the one containing ``highest_rating``, so every rating in the
    inclusive range ``[lowest_rating, highest_rating]`` hits a bucket.

    Args:
        lowest_rating: Start of the first bucket (inclusive).
        highest_rating: Highest rating that must be covered by some bucket.
        mean_rating: Mean of the sampled normal distribution.
        std_dev: Standard deviation of the sampled normal distribution.
        max_puzzles: Number of samples drawn (an upper bound on the total quota).
        step: Bucket width; must be positive.
        rng: Optional object providing ``normal(loc, scale, size)``, e.g.
            ``np.random.default_rng(seed)``, for reproducible quotas.  Defaults
            to NumPy's legacy global generator (``np.random``).

    Returns:
        IntervalTree whose intervals carry a plain ``int`` sample count.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if rng is None:
        rng = np.random
    normal_ratings = rng.normal(mean_rating, std_dev, max_puzzles).astype(int)
    tree = IntervalTree()
    # "+ 1" so a rating equal to highest_rating (accepted by the caller's
    # inclusive filter) still falls inside a bucket.
    for start in range(lowest_rating, highest_rating + 1, step):
        end = start + step
        # int(): store a plain Python int rather than a numpy scalar.
        count = int(((normal_ratings >= start) & (normal_ratings < end)).sum())
        tree[start:end] = count
    return tree
def view_db_content(db_path):
    """Print every row of every table in the SQLite database at ``db_path``.

    Each table is introduced by a ``Table: <name>`` line, followed by its rows
    printed as tuples and a blank separator.  The connection is closed even if
    a query fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [name for (name,) in cursor.fetchall()]
        for table_name in table_names:
            print(f"Table: {table_name}")
            # Quote the identifier (doubling embedded quotes) so names with
            # spaces, SQL keywords or quotes do not break the statement.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"SELECT * FROM {quoted};")
            for row in cursor.fetchall():
                print(row)
            print("\n")
    finally:
        conn.close()
def validate_move(engine_path, fen, moves, *, skip_engine=True, think_time=2.0, min_win_chance=0.9):
    """Check that a puzzle's solution move is the engine's choice and clearly winning.

    ``moves`` is the puzzle's list of UCI moves.  ``moves[0]`` is played from
    ``fen`` to reach the position to solve, and ``moves[1]`` is the expected
    reply.  The puzzle is accepted when the engine's best move equals
    ``moves[1]`` and the engine's evaluation converts to a win probability for
    the side to move above ``min_win_chance``.

    Engine validation is currently bypassed.  With ``skip_engine=True`` (the
    default, matching the previous hard-coded ``return True``) every puzzle is
    accepted without starting the engine.  Pass ``skip_engine=False`` to run
    the real check.

    Args:
        engine_path: Path to a UCI engine executable.
        fen: Starting position of the puzzle.
        moves: Puzzle moves in UCI notation.
        skip_engine: When true, accept unconditionally without the engine.
        think_time: Seconds of engine search per puzzle.
        min_win_chance: Win probability that must be exceeded.

    Returns:
        True if the puzzle is accepted, False otherwise (including malformed
        FEN/moves or an engine result without a principal variation).
    """
    if skip_engine:
        return True
    if len(moves) < 2:
        return False
    try:
        board = chess.Board(fen)
        board.push_uci(moves[0])
    except ValueError:
        # Malformed FEN or illegal/invalid setup move: cannot be validated.
        return False
    engine = chess.engine.SimpleEngine.popen_uci(engine_path)
    try:
        engine.configure({"Skill Level": 20})
        # A single search yields both the best move (head of the PV) and its
        # evaluation, so the two cannot disagree as two separate searches could.
        info = engine.analyse(board, chess.engine.Limit(time=think_time))
    finally:
        # Always terminate the engine process, even if the search fails.
        engine.quit()
    pv = info.get("pv")
    score = info["score"].relative.score(mate_score=10000) if "score" in info else None
    if not pv or score is None:
        return False
    # Logistic conversion of a centipawn score to a win probability; this one
    # expression covers both signs of ``score``.
    winning_chance = 1 / (1 + 10 ** (-score / 400))
    return pv[0].uci() == moves[1] and winning_chance > min_win_chance
def create_table(cursor, table_name, headers, unique_column="PuzzleId"):
    """Create ``table_name`` (if absent) with one column per entry in ``headers``.

    Columns are TEXT, except the numeric puzzle fields Rating,
    RatingDeviation, Popularity and NbPlays, which are INTEGER.  When
    ``unique_column`` is one of the headers, a UNIQUE constraint is placed on
    it so duplicate inserts raise ``sqlite3.IntegrityError``.

    Raises:
        ValueError: If ``headers`` is empty (SQLite requires at least one column).
    """
    if not headers:
        raise ValueError("cannot create a table with no columns")

    def quote(identifier):
        # Double embedded quotes so arbitrary names are valid SQL identifiers.
        return '"' + str(identifier).replace('"', '""') + '"'

    integer_columns = {'Rating', 'RatingDeviation', 'Popularity', 'NbPlays'}
    column_defs = [f"{quote(header)} {'INTEGER' if header in integer_columns else 'TEXT'}" for header in headers]
    # Only constrain a column that actually exists; otherwise SQLite rejects the DDL.
    if unique_column is not None and unique_column in headers:
        column_defs.append(f"UNIQUE({quote(unique_column)})")
    cursor.execute(f"CREATE TABLE IF NOT EXISTS {quote(table_name)} ({', '.join(column_defs)})")
def insert_puzzle(cursor, table_name, headers, row, written_puzzle_ids, pbar: "tqdm", max_puzzles=None):
    """Insert one puzzle row unless it is a duplicate or the cap is reached.

    The duplicate check, cap check, INSERT and bookkeeping run under the
    module-level ``lock``.  ``written_puzzle_ids`` gains the id and ``pbar``
    advances only when the INSERT actually succeeds.

    Args:
        cursor: SQLite cursor for the output database.
        table_name: Destination table.
        headers: Column names, in table order; each is looked up in ``row``.
        row: Mapping from column name to value; must contain ``'PuzzleId'``.
        written_puzzle_ids: Set of ids already written (updated in place).
        pbar: Progress bar advanced by one per written row.
        max_puzzles: Maximum number of ids in ``written_puzzle_ids``.  When
            omitted, the module-level ``max_puzzles`` is used if defined,
            otherwise there is no cap.

    Returns:
        True if the row was written, False if it was skipped.
    """
    if max_puzzles is None:
        # Backward compatibility: this function used to read the module-level
        # ``max_puzzles`` constant directly.
        max_puzzles = globals().get("max_puzzles")
    puzzle_id = row['PuzzleId']
    placeholders = ', '.join('?' for _ in headers)
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'
    with lock:
        if puzzle_id in written_puzzle_ids:
            return False
        if max_puzzles is not None and len(written_puzzle_ids) >= max_puzzles:
            return False
        try:
            cursor.execute(f'INSERT INTO {quoted_table} VALUES ({placeholders})', [row[header] for header in headers])
        except sqlite3.IntegrityError:
            # Row already present (UNIQUE constraint): do not count it as written.
            return False
        written_puzzle_ids.add(puzzle_id)
        pbar.update(1)
        return True
def process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                  incorrect_puzzle_ids, interval_tree):
    """Consume finished validation futures and store accepted puzzles within quota.

    Each future in ``tasks`` carries its puzzle row as ``future.row_dict`` and
    resolves to the validation verdict.  Handling per outcome:

    * A puzzle that fails validation, or whose validation raised, is added to
      ``incorrect_puzzle_ids``.
    * A validated puzzle is written only if its rating lies in an
      ``interval_tree`` bucket with remaining quota (``interval.data > 0``).
      That bucket's quota is decremented, and the bucket is removed once it
      hits zero.
    * Once ``written_puzzle_ids`` holds ``max_puzzles`` ids, futures that have
      not started are cancelled and processing stops.
    """
    for task in as_completed(tasks):
        if len(written_puzzle_ids) >= max_puzzles:
            # Cap reached: cancel work whose result would be discarded anyway.
            for pending in tasks:
                pending.cancel()
            break
        row_dict = task.row_dict
        try:
            is_valid = task.result()
        except Exception as exc:
            # One crashed validation (e.g. engine failure) must not abort the
            # whole run; report it and treat the puzzle as incorrect.
            pbar.write(f"Puzzle ID {row_dict['PuzzleId']}: validation error ({exc!r})")
            is_valid = False
        if not is_valid:
            incorrect_puzzle_ids.add(row_dict['PuzzleId'])
            continue
        rating = int(row_dict['Rating'])
        # Only buckets with quota left may admit a puzzle; zero-count buckets
        # from the distribution must not let a puzzle through.
        open_buckets = [interval for interval in interval_tree[rating] if interval.data > 0]
        if not open_buckets:
            continue
        bucket = open_buckets[0]
        interval_tree.remove(bucket)
        if bucket.data > 1:
            interval_tree.add(Interval(bucket.begin, bucket.end, bucket.data - 1))
        insert_puzzle(cursor_output, table_name, headers, row_dict, written_puzzle_ids, pbar)
def validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_output_db_path, lowest_rating, highest_rating,
                             mean_rating, std_dev, max_puzzles, step=10, *, min_popularity=90, min_plays=1000,
                             max_rating_deviation=100, batch_size=10):
    """Validate puzzles from the input DB and write a rating-balanced subset to a fresh output DB.

    Rows of ``lichess_db_puzzle`` are read in ``Popularity DESC, Rating DESC,
    NbPlays DESC`` order.  A row is kept only if all of these hold:

    * ``lowest_rating <= Rating <= highest_rating``;
    * ``Popularity >= min_popularity``;
    * ``NbPlays >= min_plays``;
    * ``RatingDeviation < max_rating_deviation``;
    * it has at least two moves.

    Kept rows are validated with ``validate_move`` on a thread pool, in
    batches of ``batch_size``.  ``process_tasks`` stores accepted rows, subject
    to per-rating quotas from ``create_interval_tree_with_distribution``,
    until ``max_puzzles`` rows have been written.

    Any existing file at ``sqlite_output_db_path`` is deleted first.  Both
    connections are closed even if processing fails.  Output is committed
    only when processing completes.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if os.path.exists(sqlite_output_db_path):
        os.remove(sqlite_output_db_path)
    written_puzzle_ids, incorrect_puzzle_ids = set(), set()
    table_name = "lichess_db_puzzle"
    conn_input = sqlite3.connect(sqlite_input_db_path)
    conn_output = None
    try:
        conn_output = sqlite3.connect(sqlite_output_db_path)
        cursor_input = conn_input.cursor()
        cursor_output = conn_output.cursor()
        cursor_input.execute("SELECT * FROM lichess_db_puzzle ORDER BY Popularity DESC, Rating DESC, NbPlays DESC")
        headers = [description[0] for description in cursor_input.description]
        create_table(cursor_output, table_name, headers)
        interval_tree = create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating, std_dev,
                                                               max_puzzles, step)
        tasks = []
        with ThreadPoolExecutor() as executor, tqdm(total=max_puzzles) as pbar:
            for row in cursor_input:
                if len(written_puzzle_ids) >= max_puzzles:
                    break
                row_dict = dict(zip(headers, row))
                rating = int(row_dict['Rating'])
                popularity = int(row_dict['Popularity'])
                nb_plays = int(row_dict['NbPlays'])
                rating_deviation = int(row_dict['RatingDeviation'])
                if not (lowest_rating <= rating <= highest_rating
                        and popularity >= min_popularity
                        and nb_plays >= min_plays
                        and rating_deviation < max_rating_deviation):
                    continue
                moves = row_dict['Moves'].split()
                if len(moves) < 2:
                    # pbar.write instead of print keeps the progress bar intact.
                    pbar.write(f"Puzzle ID {row_dict['PuzzleId']}: Not enough moves to validate")
                    continue
                if row_dict['PuzzleId'] in written_puzzle_ids:
                    continue
                future = executor.submit(validate_move, engine_path, row_dict['FEN'], moves)
                # process_tasks reads the row back from the future.
                future.row_dict = row_dict
                tasks.append(future)
                if len(tasks) >= batch_size:
                    process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                                  incorrect_puzzle_ids, interval_tree)
                    tasks = []
            # Drain the final, possibly partial, batch.
            process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                          incorrect_puzzle_ids, interval_tree)
        conn_output.commit()
    finally:
        if conn_output is not None:
            conn_output.close()
        conn_input.close()
# Set STOCKFISH_PATH to override the default Homebrew engine location.
engine_path = os.environ.get('STOCKFISH_PATH', '/opt/homebrew/Cellar/stockfish/17.1/bin/stockfish')
sqlite_db_path = 'validated_puzzles.db'
sqlite_input_db_path = 'all_puzzles.db'
# Kept at module level: insert_puzzle() reads this name as a global cap.
max_puzzles = 120000

if __name__ == "__main__":
    # Ratings 1000-3500 inclusive, with per-bucket quotas sampled from N(2200, 400).
    validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_db_path, 1000, 3500, 2200, 400, max_puzzles)
    # Uncomment to dump the resulting database to stdout:
    # view_db_content(sqlite_db_path)
|