# File size: 6,728 Bytes
# 2311079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

import chess
import chess.engine
import numpy as np
from intervaltree import IntervalTree, Interval
from tqdm import tqdm

# Module-wide lock; insert_puzzle() holds it while it checks/updates the set of
# written puzzle ids, advances the progress bar and executes the INSERT.
lock = threading.Lock()


def create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating, std_dev, max_puzzles, step=10):
    """Build an IntervalTree of per-bucket quotas shaped like a normal distribution.

    ``max_puzzles`` values are sampled from a normal distribution with the given
    ``mean_rating`` and ``std_dev`` and converted to integers. The span starting
    at ``lowest_rating`` is then cut into half-open buckets ``[start, start + step)``
    (bucket starts stop before ``highest_rating``), and each bucket is stored in
    the tree with the number of samples that landed in it as its data.

    Samples outside the covered span are simply not counted, so the quotas may
    sum to less than ``max_puzzles``. The result is random: no seed is set here.

    Returns:
        The populated ``IntervalTree``.
    """
    sampled_ratings = np.random.normal(mean_rating, std_dev, max_puzzles).astype(int)

    quota_tree = IntervalTree()
    for bucket_start in range(lowest_rating, highest_rating, step):
        bucket_end = bucket_start + step
        in_bucket = (sampled_ratings >= bucket_start) & (sampled_ratings < bucket_end)
        # One interval per bucket; its data is the sample count for that bucket.
        quota_tree.addi(bucket_start, bucket_end, in_bucket.sum())

    return quota_tree


def view_db_content(db_path):
    """Print every table of the SQLite database at ``db_path`` with all its rows.

    For each table listed in ``sqlite_master`` this prints a ``Table: <name>``
    header, one line per row (the row tuple), and a blank separator.

    Args:
        db_path: Path passed to ``sqlite3.connect``.

    Raises:
        sqlite3.Error: If the database cannot be read. The connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # Materialise the names first: the cursor is reused for the row queries.
        table_names = [name for (name,) in cursor.fetchall()]

        for table_name in table_names:
            print(f"Table: {table_name}")
            # Identifiers cannot be bound as parameters, so quote the name
            # (doubling embedded quotes) to cope with spaces and keywords.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"SELECT * FROM {quoted_name};")
            for row in cursor.fetchall():
                print(row)
            print("\n")
    finally:
        conn.close()


def validate_move(engine_path, fen, moves, skip_validation=True):
    """Check a puzzle's expected reply against a UCI engine.

    By default the check is skipped and ``True`` is returned immediately. This
    matches the previous behaviour, where an unconditional early return made
    the engine code unreachable. Pass ``skip_validation=False`` to actually
    run the engine.

    When validation runs, ``moves[0]`` is played on the position given by
    ``fen``; the engine is then asked for its move and for an evaluation of
    the resulting position (2 seconds each).

    Args:
        engine_path: Path of the UCI engine executable.
        fen: FEN string of the starting position.
        moves: UCI move strings; at least two are needed when validating.
        skip_validation: If true (default), do not start the engine.

    Returns:
        ``True`` when skipped. Otherwise ``True`` only if the engine's move
        equals ``moves[1]`` and the estimated winning chance for the side to
        move exceeds 0.9; ``False`` if the engine returns no move or no score.
    """
    if skip_validation:
        return True

    engine = chess.engine.SimpleEngine.popen_uci(engine_path)
    try:
        engine.configure({"Skill Level": 20})

        board = chess.Board(fen)
        board.push_uci(moves[0])
        result = engine.play(board, chess.engine.Limit(time=2.0))
        info = engine.analyse(board, chess.engine.Limit(time=2.0))
    finally:
        # Always shut the engine process down, even if a call above raised.
        engine.quit()

    stockfish_move = result.move
    score = info['score'].relative.score(mate_score=10000)
    if stockfish_move is None or score is None:
        # Nothing to compare against; the old code would have raised a
        # TypeError on ``None > 0.9`` here.
        return False

    # Logistic mapping of the centipawn score to a winning chance. The two
    # sign-dependent formulas used before are algebraically the same expression.
    winning_chance = 1 / (1 + 10 ** (-score / 400))
    return stockfish_move.uci() == moves[1] and winning_chance > 0.9


def create_table(cursor, table_name, headers):
    """Create ``table_name`` (if it does not exist) with one column per header.

    The columns ``Rating``, ``RatingDeviation``, ``Popularity`` and ``NbPlays``
    are declared INTEGER; every other column is TEXT. A UNIQUE constraint is
    placed on ``PuzzleId``.

    Args:
        cursor: SQLite cursor used to execute the statement.
        table_name: Name of the table to create.
        headers: Column names, in table order.
    """
    integer_columns = ('Rating', 'RatingDeviation', 'Popularity', 'NbPlays')

    column_defs = []
    for header in headers:
        sql_type = 'INTEGER' if header in integer_columns else 'TEXT'
        column_defs.append(f'"{header}" {sql_type}')
    columns = ', '.join(column_defs)

    cursor.execute(f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns}, UNIQUE("PuzzleId"))')


def insert_puzzle(cursor, table_name, headers, row, written_puzzle_ids, pbar: tqdm):
    """Insert one puzzle row into ``table_name`` unless it is a duplicate or the cap is hit.

    Everything runs under the module-level ``lock``. The row is skipped when
    its ``PuzzleId`` is already in ``written_puzzle_ids`` or when that set has
    reached the module-level ``max_puzzles``. Otherwise the id is recorded and
    the progress bar advanced *before* the INSERT is attempted; an
    ``sqlite3.IntegrityError`` from the INSERT is ignored.

    Args:
        cursor: Cursor on the output database.
        table_name: Destination table.
        headers: Column names; values are taken from ``row`` in this order.
        row: Mapping from column name to value; must contain ``PuzzleId``.
        written_puzzle_ids: Set of ids already handled; updated in place.
        pbar: Progress bar, advanced by one per accepted puzzle.
    """
    with lock:
        puzzle_id = row['PuzzleId']
        if puzzle_id in written_puzzle_ids:
            return
        if len(written_puzzle_ids) >= max_puzzles:
            return

        written_puzzle_ids.add(puzzle_id)
        pbar.update(1)

        marks = ', '.join('?' for _ in headers)
        values = [row[header] for header in headers]
        try:
            cursor.execute(f'INSERT INTO "{table_name}" VALUES ({marks})', values)
        except sqlite3.IntegrityError:
            # Best effort: a uniqueness clash is deliberately ignored.
            pass


def process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                  incorrect_puzzle_ids, interval_tree):
    """Consume finished validation futures and store the puzzles that still fit a quota.

    Each future is expected to carry the puzzle's row as a ``row_dict``
    attribute (the caller attaches it after ``submit``). Futures are handled
    in completion order:

    * a falsy result adds the puzzle id to ``incorrect_puzzle_ids``;
    * a truthy result looks up the rating bucket in ``interval_tree``. If the
      bucket has quota left, the puzzle is passed to ``insert_puzzle`` and the
      bucket's quota is decremented. A bucket whose quota reaches zero is
      dropped from the tree.

    Processing stops as soon as ``written_puzzle_ids`` holds ``max_puzzles``
    entries. An exception raised inside a task is re-raised by ``result()``.

    Args:
        tasks: Futures to wait on.
        cursor_output: Cursor on the output database.
        table_name: Destination table.
        headers: Column names of the destination table.
        written_puzzle_ids: Set of ids already written; updated by ``insert_puzzle``.
        pbar: Progress bar forwarded to ``insert_puzzle``.
        max_puzzles: Upper bound on the number of written puzzles.
        incorrect_puzzle_ids: Set collecting ids that failed validation.
        interval_tree: Tree of rating buckets whose data is the remaining quota;
            modified in place.
    """
    for task in as_completed(tasks):
        row_dict = task.row_dict
        if len(written_puzzle_ids) >= max_puzzles:
            break

        if not task.result():
            incorrect_puzzle_ids.add(row_dict['PuzzleId'])
            continue

        rating = int(row_dict['Rating'])
        intervals = interval_tree[rating]
        if not intervals:
            # Rating falls outside every bucket (or its bucket is exhausted).
            continue

        interval = intervals.pop()
        interval_tree.remove(interval)
        if interval.data <= 0:
            # Empty bucket: drop it and skip the puzzle.
            continue

        # Fix: decide on the quota *before* decrementing. Previously the puzzle
        # was only stored if the decremented quota was still positive, so a
        # bucket with quota N accepted N - 1 puzzles and quota-1 buckets none.
        remaining = interval.data - 1
        if remaining > 0:
            interval_tree.add(Interval(interval.begin, interval.end, remaining))
        insert_puzzle(cursor_output, table_name, headers, row_dict, written_puzzle_ids, pbar)


def validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_output_db_path, lowest_rating, highest_rating,
                             mean_rating, std_dev, max_puzzles, step=10, min_popularity=90, min_plays=1000,
                             max_rating_deviation=100, batch_size=10):
    """Select puzzles from the input database and write them to a fresh output database.

    Any existing file at ``sqlite_output_db_path`` is deleted first. Rows of
    ``lichess_db_puzzle`` are read ordered by Popularity, Rating and NbPlays
    (all descending), filtered, validated with ``validate_move`` on a thread
    pool, and handed to ``process_tasks`` in batches. ``process_tasks`` stores
    them subject to per-rating-bucket quotas produced by
    ``create_interval_tree_with_distribution``. The output is committed once
    at the end.

    Args:
        sqlite_input_db_path: SQLite file containing ``lichess_db_puzzle``.
        engine_path: UCI engine path forwarded to ``validate_move``.
        sqlite_output_db_path: SQLite file to (re)create.
        lowest_rating: Minimum accepted rating (inclusive).
        highest_rating: Maximum accepted rating (inclusive in the row filter).
        mean_rating: Mean of the target rating distribution.
        std_dev: Standard deviation of the target rating distribution.
        max_puzzles: Target number of puzzles. NOTE: ``insert_puzzle`` checks
            the module-level ``max_puzzles`` rather than this argument.
        step: Width of a rating bucket.
        min_popularity: Minimum accepted ``Popularity`` (inclusive).
        min_plays: Minimum accepted ``NbPlays`` (inclusive).
        max_rating_deviation: ``RatingDeviation`` must be strictly below this.
        batch_size: Number of pending futures that triggers ``process_tasks``.

    Raises:
        sqlite3.Error: On database failures. Both connections are closed
            before the error propagates; nothing is committed in that case.
    """
    if os.path.exists(sqlite_output_db_path):
        os.remove(sqlite_output_db_path)

    # incorrect_puzzle_ids is filled by process_tasks but not read afterwards.
    written_puzzle_ids, incorrect_puzzle_ids = set(), set()
    table_name = "lichess_db_puzzle"

    conn_input = sqlite3.connect(sqlite_input_db_path)
    try:
        conn_output = sqlite3.connect(sqlite_output_db_path)
        try:
            cursor_input = conn_input.cursor()
            cursor_output = conn_output.cursor()

            cursor_input.execute(
                "SELECT * FROM lichess_db_puzzle ORDER BY Popularity DESC, Rating DESC, NbPlays DESC")
            headers = [description[0] for description in cursor_input.description]
            create_table(cursor_output, table_name, headers)

            interval_tree = create_interval_tree_with_distribution(lowest_rating, highest_rating, mean_rating,
                                                                   std_dev, max_puzzles, step)

            tasks = []
            with ThreadPoolExecutor() as executor, tqdm(total=max_puzzles) as pbar:
                for row in cursor_input:
                    if len(written_puzzle_ids) >= max_puzzles:
                        break
                    row_dict = dict(zip(headers, row))
                    rating = int(row_dict['Rating'])
                    popularity = int(row_dict['Popularity'])
                    nb_plays = int(row_dict['NbPlays'])
                    rating_deviation = int(row_dict['RatingDeviation'])

                    if not (lowest_rating <= rating <= highest_rating and
                            popularity >= min_popularity and
                            nb_plays >= min_plays and
                            rating_deviation < max_rating_deviation):
                        continue

                    fen = row_dict['FEN']
                    moves = row_dict['Moves'].split()
                    if len(moves) < 2:
                        print(f"Puzzle ID {row_dict['PuzzleId']}: Not enough moves to validate")
                        continue

                    if row_dict['PuzzleId'] in written_puzzle_ids:
                        continue

                    future = executor.submit(validate_move, engine_path, fen, moves)
                    # process_tasks reads the row back from this attribute.
                    future.row_dict = row_dict
                    tasks.append(future)

                    if len(tasks) >= batch_size:
                        process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar,
                                      max_puzzles, incorrect_puzzle_ids, interval_tree)
                        tasks = []

                # Flush whatever is left after the last full batch.
                process_tasks(tasks, cursor_output, table_name, headers, written_puzzle_ids, pbar, max_puzzles,
                              incorrect_puzzle_ids, interval_tree)

            conn_output.commit()
        finally:
            conn_output.close()
    finally:
        conn_input.close()


# Script configuration. These names stay at module level on purpose:
# insert_puzzle() reads the global ``max_puzzles``.
engine_path = '/opt/homebrew/Cellar/stockfish/17.1/bin/stockfish'
sqlite_db_path = 'validated_puzzles.db'
sqlite_input_db_path = 'all_puzzles.db'
max_puzzles = 120000

# Only run the pipeline when executed as a script: it deletes and rebuilds the
# output database, which must not happen as a side effect of an import.
if __name__ == "__main__":
    validate_and_store_moves(sqlite_input_db_path, engine_path, sqlite_db_path, 1000, 3500, 2200, 400, max_puzzles)
    # view_db_content(sqlite_db_path)