|
import os |
|
import sqlite3 |
|
import threading |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
import chess |
|
import chess.engine |
|
import numpy as np |
|
from intervaltree import IntervalTree, Interval |
|
from tqdm import tqdm |
|
|
|
# Guards the shared written-puzzle bookkeeping that insert_puzzle updates.
lock = threading.Lock()
|
|
|
|
|
def create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating, std_dev, max_puzzles, step=10):
    """Build per-rating-bucket puzzle quotas that follow a normal distribution.

    ``max_puzzles`` ratings are drawn from a normal distribution with the
    given mean and standard deviation and cast to integers. The samples are
    then counted into half-open buckets ``[start, start + step)`` for every
    ``start`` in ``range(lowest_rating, highest_rating, step)``.

    Samples that fall outside every bucket are discarded, so the quotas may
    add up to fewer than ``max_puzzles``.

    Args:
        lowest_rating: Lower bound (inclusive) of the first bucket.
        highest_rating: Upper bound used to stop generating bucket starts.
        mean_rating: Mean of the sampled normal distribution.
        std_dev: Standard deviation of the sampled normal distribution.
        max_puzzles: Number of samples to draw.
        step: Width of each bucket.

    Returns:
        An ``IntervalTree`` holding one interval per non-empty bucket; the
        interval's data is the bucket's quota as a plain positive ``int``.
        Buckets that received no samples are left out, so a point query for
        such a rating returns an empty set instead of a zero-quota interval.
    """
    sampled_ratings = np.random.normal(mean_rating, std_dev, max_puzzles).astype(int)

    tree = IntervalTree()
    for start in range(lowest_rating, highest_rating, step):
        end = start + step
        in_bucket = (sampled_ratings >= start) & (sampled_ratings < end)
        # Store a builtin int rather than a numpy scalar so later arithmetic
        # and comparisons on the quota behave like ordinary Python numbers.
        quota = int(np.count_nonzero(in_bucket))
        if quota > 0:
            tree[start:end] = quota
    return tree
|
|
|
|
|
def view_db_content(db_path):
    """Print every table of a SQLite database followed by all of its rows.

    For each table listed in ``sqlite_master`` this prints a ``Table: <name>``
    header, one line per row (the row tuple), and then a blank separator.

    Args:
        db_path: Path of the SQLite database file to open.

    Raises:
        sqlite3.Error: Propagated from the underlying queries; the
            connection is closed before the exception leaves this function.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [name for (name,) in cursor.fetchall()]

        for table_name in table_names:
            print(f"Table: {table_name}")
            # Identifiers cannot be bound as SQL parameters, so quote the
            # name and double any embedded quote characters instead.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            for row in cursor.execute(f"SELECT * FROM {quoted_name};"):
                print(row)
            print("\n")
    finally:
        conn.close()
|
|
|
|
|
def validate_move(engine_path, fen, moves, *, run_engine=False):
    """Decide whether a puzzle's expected reply should be accepted.

    By default the engine check is skipped and every puzzle is accepted,
    which is what this function has always done for its callers. Pass
    ``run_engine=True`` to actually verify the puzzle with a UCI engine.

    When the engine is used, the first move of ``moves`` is played on the
    position given by ``fen``; the engine is then asked for its move and for
    an evaluation of the resulting position (two searches, each limited to
    2.0 seconds). The puzzle is accepted only if the engine's move equals
    ``moves[1]`` and the estimated winning chance for the side to move
    exceeds 0.9.

    Args:
        engine_path: Path to the UCI engine executable.
        fen: FEN string of the puzzle's starting position.
        moves: Puzzle moves in UCI notation; ``moves[0]`` is the setup move
            and ``moves[1]`` the expected reply.
        run_engine: Keyword-only. ``False`` (default) accepts without
            starting the engine; ``True`` performs the engine check.

    Returns:
        ``True`` if the puzzle is accepted, otherwise ``False``.
    """
    if not run_engine:
        return True
    if len(moves) < 2:
        # Nothing to compare the engine's move against.
        return False

    board = chess.Board(fen)
    board.push_uci(moves[0])
    limit = chess.engine.Limit(time=2.0)

    # The context manager shuts the engine process down even if a search raises.
    with chess.engine.SimpleEngine.popen_uci(engine_path) as engine:
        engine.configure({"Skill Level": 20})
        stockfish_move = engine.play(board, limit).move
        info = engine.analyse(board, limit)

    if stockfish_move is None:
        return False

    pov_score = info.get('score')
    if pov_score is None:
        return False
    score = pov_score.relative.score(mate_score=10000)
    if score is None or score <= 0:
        # A non-positive score maps to a winning chance of at most 0.5,
        # which can never clear the 0.9 threshold.
        return False

    # Logistic mapping from the centipawn-style score to a winning chance.
    winning_chance = 1 / (1 + 10 ** (-score / 400))
    return stockfish_move.uci() == moves[1] and winning_chance > 0.9
|
|
|
|
|
def create_table(cursor, table_name, headers):
    """Create the puzzle table if it does not exist yet.

    Every header becomes a column. The numeric puzzle statistics are typed
    ``INTEGER`` and all remaining columns ``TEXT``. A uniqueness constraint
    is placed on the ``PuzzleId`` column.

    Args:
        cursor: sqlite3 cursor used to execute the statement.
        table_name: Name of the table to create.
        headers: Column names, in the order the columns should appear.
    """
    integer_columns = ('Rating', 'RatingDeviation', 'Popularity', 'NbPlays')

    column_defs = []
    for header in headers:
        sql_type = 'INTEGER' if header in integer_columns else 'TEXT'
        column_defs.append(f'"{header}" {sql_type}')
    columns = ', '.join(column_defs)

    statement = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns}, UNIQUE("PuzzleId"))'
    cursor.execute(statement)
|
|
|
|
|
def insert_puzzle(cursor, table_name, headers, row, written_puzzle_ids, pbar: "tqdm", max_puzzles=None):
    """Insert one puzzle row unless it is a duplicate or the cap is reached.

    The ID is recorded and the progress bar advanced only after the INSERT
    succeeded, so a row rejected by the table's uniqueness constraint is not
    counted as written.

    Args:
        cursor: sqlite3 cursor on the output database.
        table_name: Destination table.
        headers: Column names in table order; also the keys read from ``row``.
        row: Mapping from header to value; must contain ``'PuzzleId'``.
        written_puzzle_ids: Set of puzzle IDs stored so far; updated in place.
        pbar: Progress bar, advanced by one for every stored puzzle.
        max_puzzles: Maximum number of stored puzzles. When omitted, a
            module-level ``max_puzzles`` is used if one exists (the behaviour
            existing callers rely on); if there is none, no cap is applied.
    """
    if max_puzzles is None:
        # Backward compatibility: older callers do not pass the cap and
        # expect it to come from the module-level ``max_puzzles``.
        max_puzzles = globals().get('max_puzzles')

    puzzle_id = row['PuzzleId']
    placeholders = ', '.join('?' for _ in headers)
    values = [row[header] for header in headers]

    with lock:
        if puzzle_id in written_puzzle_ids:
            return
        if max_puzzles is not None and len(written_puzzle_ids) >= max_puzzles:
            return
        try:
            cursor.execute(f'INSERT INTO "{table_name}" VALUES ({placeholders})', values)
        except sqlite3.IntegrityError:
            # The row is already in the table; do not count it again.
            return
        written_puzzle_ids.add(puzzle_id)
        pbar.update(1)
|
|
|
|
|
def process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                  incorrect_puzzle_ids, interval_tree):
    """Consume finished validation futures and store puzzles that fit a rating quota.

    Futures are handled in completion order. A puzzle whose validation result
    is falsy has its ID added to ``incorrect_puzzle_ids``. A validated puzzle
    is stored only if the bucket covering its rating still has quota left;
    storing it uses up one unit of that quota, and a bucket is removed from
    the tree once its quota reaches zero.

    Args:
        tasks: Futures carrying the puzzle row in a ``row_dict`` attribute.
        cursor_output: sqlite3 cursor on the output database.
        table_name: Destination table.
        headers: Column names of the destination table.
        written_puzzle_ids: Set of stored puzzle IDs; updated via insert_puzzle.
        pbar: Progress bar handed on to insert_puzzle.
        max_puzzles: Processing stops once this many puzzles are stored.
        incorrect_puzzle_ids: Set collecting IDs of puzzles that failed validation.
        interval_tree: IntervalTree whose interval data is the remaining quota
            for that rating bucket; modified in place.
    """
    for task in as_completed(tasks):
        if len(written_puzzle_ids) >= max_puzzles:
            break

        row_dict = task.row_dict
        if not task.result():
            incorrect_puzzle_ids.add(row_dict['PuzzleId'])
            continue

        if row_dict['PuzzleId'] in written_puzzle_ids:
            # Already stored; do not spend quota on a row that would be skipped.
            continue

        rating = int(row_dict['Rating'])
        buckets = interval_tree[rating]
        if not buckets:
            # No bucket covers this rating, or its quota is used up.
            continue

        bucket = buckets.pop()
        interval_tree.remove(bucket)
        if bucket.data <= 0:
            # A bucket without quota must not admit a puzzle; drop it.
            continue

        remaining = bucket.data - 1
        if remaining > 0:
            interval_tree.add(Interval(bucket.begin, bucket.end, remaining))
        insert_puzzle(cursor_output, table_name, headers, row_dict, written_puzzle_ids, pbar)
|
|
|
|
|
def validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_output_db_path, lowest_rating, highest_rating,
                             mean_rating, std_dev, max_puzzles, step=10):
    """Copy validated, popular puzzles into a fresh database following a rating distribution.

    Any existing output database is deleted first. Puzzles are read from the
    ``lichess_db_puzzle`` table ordered by popularity, rating and play count
    (all descending). A puzzle is a candidate when its rating lies in
    ``[lowest_rating, highest_rating]``, its popularity is at least 90, it
    has at least 1000 plays, its rating deviation is below 100 and it has at
    least two moves. Candidates are validated on a thread pool in batches and
    stored while their rating bucket still has quota, until ``max_puzzles``
    puzzles are written or the input is exhausted.

    Both database connections are closed even if processing fails; the
    output is committed only when processing completes.

    Args:
        sqlite_input_db_path: Path of the source puzzle database.
        engine_path: Path of the UCI engine handed to validate_move.
        sqlite_output_db_path: Path of the database to (re)create.
        lowest_rating: Minimum puzzle rating (inclusive).
        highest_rating: Maximum puzzle rating (inclusive).
        mean_rating: Mean of the target rating distribution.
        std_dev: Standard deviation of the target rating distribution.
        max_puzzles: Maximum number of puzzles to store.
        step: Width of the rating buckets.
    """
    table_name = "lichess_db_puzzle"
    min_popularity = 90
    min_plays = 1000
    max_rating_deviation = 100  # exclusive upper bound
    batch_size = 10

    if os.path.exists(sqlite_output_db_path):
        os.remove(sqlite_output_db_path)

    written_puzzle_ids, incorrect_puzzle_ids = set(), set()

    conn_input = sqlite3.connect(sqlite_input_db_path)
    try:
        conn_output = sqlite3.connect(sqlite_output_db_path)
        try:
            cursor_input = conn_input.cursor()
            cursor_output = conn_output.cursor()

            cursor_input.execute(
                "SELECT * FROM lichess_db_puzzle ORDER BY Popularity DESC, Rating DESC, NbPlays DESC")
            headers = [description[0] for description in cursor_input.description]
            create_table(cursor_output, table_name, headers)

            interval_tree = create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating,
                                                                   std_dev, max_puzzles, step)

            tasks = []
            with ThreadPoolExecutor() as executor, tqdm(total=max_puzzles) as pbar:
                for row in cursor_input:
                    if len(written_puzzle_ids) >= max_puzzles:
                        break

                    row_dict = dict(zip(headers, row))
                    rating = int(row_dict['Rating'])
                    popularity = int(row_dict['Popularity'])
                    nb_plays = int(row_dict['NbPlays'])
                    rating_deviation = int(row_dict['RatingDeviation'])

                    is_candidate = (lowest_rating <= rating <= highest_rating
                                    and popularity >= min_popularity
                                    and nb_plays >= min_plays
                                    and rating_deviation < max_rating_deviation)
                    if not is_candidate:
                        continue

                    fen = row_dict['FEN']
                    moves = row_dict['Moves'].split()
                    if len(moves) < 2:
                        print(f"Puzzle ID {row_dict['PuzzleId']}: Not enough moves to validate")
                        continue

                    if row_dict['PuzzleId'] in written_puzzle_ids:
                        continue

                    future = executor.submit(validate_move, engine_path, fen, moves)
                    # Carry the row with the future so the consumer can store it.
                    future.row_dict = row_dict
                    tasks.append(future)

                    if len(tasks) >= batch_size:
                        process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar,
                                      max_puzzles, incorrect_puzzle_ids, interval_tree)
                        tasks = []

                # Handle the final, possibly partial batch.
                process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                              incorrect_puzzle_ids, interval_tree)

            conn_output.commit()
        finally:
            conn_output.close()
    finally:
        conn_input.close()
|
|
|
|
|
# Default job configuration. These stay module-level names; insert_puzzle
# reads ``max_puzzles`` as a global.
engine_path = '/opt/homebrew/Cellar/stockfish/17.1/bin/stockfish'
sqlite_db_path = 'validated_puzzles.db'
sqlite_input_db_path = 'all_puzzles.db'
max_puzzles = 120000

if __name__ == "__main__":
    # Run the job only when executed as a script, not when imported.
    validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_db_path, 1000, 3500, 2200, 400, max_puzzles)
|
|
|
|