Jean-baptiste Aujogue committed
Commit 7aeb2a0 · unverified · 2 parent(s): e90314c ad6d9bc

Base cube and moves with sparse tensor multiplications
.gitignore CHANGED
@@ -1,194 +1,12 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
  .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py,cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # UV
98
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- #uv.lock
102
-
103
- # poetry
104
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
- # This is especially recommended for binary packages to ensure reproducibility, and is more
106
- # commonly ignored for libraries.
107
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
- #poetry.lock
109
-
110
- # pdm
111
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
- #pdm.lock
113
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
- # in version control.
115
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
- .pdm.toml
117
- .pdm-python
118
- .pdm-build/
119
-
120
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
- __pypackages__/
122
-
123
- # Celery stuff
124
- celerybeat-schedule
125
- celerybeat.pid
126
-
127
- # SageMath parsed files
128
- *.sage.py
129
-
130
- # Environments
131
  .env
132
  .venv
133
- env/
134
- venv/
135
- ENV/
136
- env.bak/
137
- venv.bak/
138
-
139
- # Spyder project settings
140
- .spyderproject
141
- .spyproject
142
-
143
- # Rope project settings
144
- .ropeproject
145
-
146
- # mkdocs documentation
147
- /site
148
-
149
- # mypy
150
- .mypy_cache/
151
- .dmypy.json
152
- dmypy.json
153
-
154
- # Pyre type checker
155
- .pyre/
156
-
157
- # pytype static type analyzer
158
- .pytype/
159
-
160
- # Cython debug symbols
161
- cython_debug/
162
-
163
- # PyCharm
164
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
- # and can be added to the global gitignore or merged into this file. For a more nuclear
167
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
- #.idea/
169
-
170
- # Abstra
171
- # Abstra is an AI-powered process automation framework.
172
- # Ignore directories containing user credentials, local state, and settings.
173
- # Learn more at https://abstra.io/docs
174
- .abstra/
175
-
176
- # Visual Studio Code
177
- # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
178
- # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179
- # and can be added to the global gitignore or merged into this file. However, if you prefer,
180
- # you could uncomment the following to ignore the enitre vscode folder
181
- # .vscode/
182
-
183
- # Ruff stuff:
184
- .ruff_cache/
185
 
186
- # PyPI configuration file
187
- .pypirc
188
 
189
- # Cursor
190
- # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191
- # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192
- # refer to https://docs.cursor.com/context/ignore-files
193
- .cursorignore
194
- .cursorindexingignore
1
  .coverage
2
  .env
3
  .venv
4
+ .mypy_cache
5
+ .vscode
6
+ .ruff_cache
7
+ .ipynb_checkpoints
8
+ .pytest_cache
9
 
10
+ *egg-info
 
11
 
12
+ __pycache__
.pre-commit-config.yaml ADDED
@@ -0,0 +1,30 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ - id: check-added-large-files
6
+ args: [--maxkb=5000]
7
+ - id: detect-private-key
8
+
9
+ - repo: https://github.com/astral-sh/uv-pre-commit
10
+ rev: 0.7.13
11
+ hooks:
12
+ - id: uv-lock
13
+
14
+ - repo: https://github.com/astral-sh/ruff-pre-commit
15
+ rev: v0.12.0
16
+ hooks:
17
+ - id: ruff
18
+ args: [--fix]
19
+ - id: ruff-format
20
+
21
+ - repo: https://github.com/pre-commit/mirrors-mypy
22
+ rev: v1.16.1
23
+ hooks:
24
+ - id: mypy
25
+ exclude: '^(?!src).*'
26
+
27
+ - repo: https://github.com/kynan/nbstripout
28
+ rev: 0.8.1
29
+ hooks:
30
+ - id: nbstripout
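
Note: as a hedged usage sketch (not part of this diff), the hooks configured above would typically be exercised locally with the standard pre-commit CLI, installed here through the dev dependency group:

```shell
# install the git hooks declared in .pre-commit-config.yaml
uv run pre-commit install
# run every hook once against the whole repository
uv run pre-commit run --all-files
```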
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.11
README.md CHANGED
@@ -1 +1,79 @@
1
- # Rubik-Tensor
1
+ # Rubik-Tensor
2
+
3
+
4
+ ## Setup
5
+
6
+ This project uses `uv 0.7` as its environment and dependency manager, and `python 3.11` as the core interpreter. Install the project with:
7
+
8
+ ```shell
9
+ uv venv
10
+ # activate the environment (e.g. source .venv/bin/activate)
11
+ uv sync
12
+ pre-commit install
13
+ ```
14
+
15
+ ## Usage
16
+
17
+ ### Create a cube
18
+
19
+ ```python
20
+ from rubik.cube import Cube
21
+
22
+ cube = Cube(colors=['U', 'L', 'C', 'R', 'B', 'D'], size=3)
23
+
24
+ # display the cube state and history of moves
25
+ print(cube)
26
+ # UUU
27
+ # UUU
28
+ # UUU
29
+ # LLL CCC RRR BBB
30
+ # LLL CCC RRR BBB
31
+ # LLL CCC RRR BBB
32
+ # DDD
33
+ # DDD
34
+ # DDD
35
+
36
+ print(cube.history)
37
+ # []
38
+ ```
39
+
40
+ ### Perform basic moves
41
+
42
+ ```python
43
+ # shuffle the cube using 1000 random moves (random shuffling resets the history)
44
+ cube.shuffle(num_moves=1000, seed=0)
45
+
46
+ # rotate it in some way
47
+ cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
48
+ ```
49
+
50
+ ## Roadmap
51
+
52
+ #### Fully tensorized Rubik's Cube model
53
+
54
+ - ☑️ Tensorized states.
55
+ - ☑️ Tensorized actions.
56
+ - ☑️ Interface for performing actions.
57
+
58
+ #### Movement explorer
59
+
60
+ - ⬜ Explore changes resulting from a sequence of moves.
61
+ - ⬜ Find the shortest sequences of moves satisfying given input constraints.
62
+
63
+ #### Visualization interface
64
+
65
+ #### Base solvers following rule-based policies
66
+
67
+
68
+
69
+
70
+ ## Related projects
71
+
72
+ Open-source projects related to Rubik's Cube, sorted by number of stars:
73
+ - [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
74
+ - [pglass/cube](https://github.com/pglass/cube)
75
+ - [dwalton76/rubiks-cube-NxNxN-solver](https://github.com/dwalton76/rubiks-cube-NxNxN-solver)
76
+ - [bellerb/RubiksCube_Solver](https://github.com/bellerb/RubiksCube_Solver)
77
+ - [trincaog/magiccube](https://github.com/trincaog/magiccube)
78
+ - [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
79
+ - [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
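
Note: a minimal sketch (not part of the committed README) of the `compute_changes` helper introduced in `src/rubik/cube.py`, assuming the package is installed as described in the Setup section:

```python
from rubik.cube import Cube

cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)

# compute_changes combines a move sequence into a single permutation and
# returns a dict mapping each facet position to the position whose color
# it receives (see tests/unit/test_cube.py for the exact convention)
changes = cube.compute_changes("X2 X1i Y1i Z1i Y0 Z0i")
moved = [i for i, j in changes.items() if i != j]
print(f"{len(moved)} of {6 * cube.size ** 2} facets change position")
```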
notebooks/dev.ipynb ADDED
@@ -0,0 +1,120 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%load_ext autoreload\n",
11
+ "%autoreload 2"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "1",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "import copy\n",
22
+ "\n",
23
+ "import torch"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "2",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "from rubik.cube import Cube\n",
34
+ "from rubik.action import build_actions_tensor"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "3",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "size = 3\n",
45
+ "\n",
46
+ "actions = build_actions_tensor(size)\n",
47
+ "\n",
48
+ "cube = Cube([\"U\", \"L\", \"C\", \"R\", \"B\", \"D\"], size=size)\n",
49
+ "print(cube)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "4",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "cubis = copy.deepcopy(cube)\n",
60
+ "cubis.shuffle(2000, seed=0)\n",
61
+ "print(cubis)\n",
62
+ "print(cubis.history)"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "5",
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "cubis.compute_changes(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \")"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "6",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "cubis = copy.deepcopy(cube)\n",
83
+ "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \" * 1000)\n",
84
+ "print(cubis)\n",
85
+ "print(cubis.history)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "7",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "(actions[0, 2, 0].type(torch.float32) @ actions[0, 1, 1].type(torch.float32)).type(torch.int8)"
96
+ ]
97
+ }
98
+ ],
99
+ "metadata": {
100
+ "kernelspec": {
101
+ "display_name": "Rubik-Tensor",
102
+ "language": "python",
103
+ "name": "python3"
104
+ },
105
+ "language_info": {
106
+ "codemirror_mode": {
107
+ "name": "ipython",
108
+ "version": 3
109
+ },
110
+ "file_extension": ".py",
111
+ "mimetype": "text/x-python",
112
+ "name": "python",
113
+ "nbconvert_exporter": "python",
114
+ "pygments_lexer": "ipython3",
115
+ "version": "3.11.13"
116
+ }
117
+ },
118
+ "nbformat": 4,
119
+ "nbformat_minor": 5
120
+ }
pyproject.toml ADDED
@@ -0,0 +1,38 @@
1
+ [project]
2
+ name = "rubik-tensor"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11,<3.12"
7
+ dependencies = [
8
+ "cupy-cuda12x>=13.4.1",
9
+ "fire>=0.7.0",
10
+ "loguru>=0.7.3",
11
+ "torch>=2.7.1",
12
+ ]
13
+
14
+ [dependency-groups]
15
+ dev = [
16
+ "jupyter>=1.1.1",
17
+ "mypy>=1.16.1",
18
+ "pre-commit>=4.2.0",
19
+ "pytest>=8.4.1",
20
+ "pytest-cov>=6.2.1",
21
+ "ruff>=0.12.0",
22
+ ]
23
+
24
+ [tool.uv]
25
+ package = true
26
+
27
+ [tool.uv.sources]
28
+ torch = { index = "torch-cu126" }
29
+
30
+ [[tool.uv.index]]
31
+ name = "torch-cu126"
32
+ url = "https://download.pytorch.org/whl/cu126"
33
+
34
+ [tool.ruff]
35
+ line-length = 120
36
+
37
+ [tool.pytest.ini_options]
38
+ addopts = "--cov src"
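
Note: with `addopts = "--cov src"`, pytest collects coverage over `src/` by default; a typical invocation (an assumption, not shown in this diff) would be:

```shell
uv sync
uv run pytest
```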
src/rubik/__init__.py ADDED
File without changes
src/rubik/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ # from fire import Fire
2
+
3
+
4
+ # Fire({"hello": hello_world})
src/rubik/action.py ADDED
@@ -0,0 +1,158 @@
1
+ import re
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor
8
+
9
+
10
+ POS_ROTATIONS = torch.stack(
11
+ [
12
+ # rot about X: Z -> Y
13
+ torch.tensor(
14
+ [
15
+ [1, 0, 0, 0],
16
+ [0, 1, 0, 0],
17
+ [0, 0, 0, 1],
18
+ [0, 0, -1, 0],
19
+ ],
20
+ dtype=torch.int16,
21
+ ),
22
+ # rot about Y: X -> Z
23
+ torch.tensor(
24
+ [
25
+ [1, 0, 0, 0],
26
+ [0, 0, 0, -1],
27
+ [0, 0, 1, 0],
28
+ [0, 1, 0, 0],
29
+ ],
30
+ dtype=torch.int16,
31
+ ),
32
+ # rot about Z: Y -> X
33
+ torch.tensor(
34
+ [
35
+ [1, 0, 0, 0],
36
+ [0, 0, 1, 0],
37
+ [0, -1, 0, 0],
38
+ [0, 0, 0, 1],
39
+ ],
40
+ dtype=torch.int16,
41
+ ),
42
+ ]
43
+ )
44
+
45
+ POS_SHIFTS = torch.tensor(
46
+ [
47
+ [0, 0, 0, 1],
48
+ [0, 1, 0, 0],
49
+ [0, 0, 1, 0],
50
+ ],
51
+ dtype=torch.int16,
52
+ )
53
+
54
+
55
+ # rotation about X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
56
+ # rotation about Y axis: 0 (Up) -> 1 (Left) -> 5 (Down) -> 3 (Right) -> 0 (Up)
57
+ # rotation about Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
58
+ FACE_ROTATIONS = torch.stack(
59
+ [
60
+ build_permutation_matrix(size=6, perm="0254"),
61
+ build_permutation_matrix(size=6, perm="0153"),
62
+ build_permutation_matrix(size=6, perm="1234"),
63
+ ]
64
+ )
65
+
66
+
67
+ def build_actions_tensor(size: int) -> torch.Tensor:
68
+ """
69
+ Build the 5D tensor carrying all rotations of a cube as matrix multiplications.
70
+ """
71
+ return torch.stack(
72
+ [
73
+ build_action_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
74
+ for axis in range(3)
75
+ for slice in range(size)
76
+ for inverse in range(2)
77
+ ],
78
+ dim=0,
79
+ ).sum(dim=0, dtype=torch.int16)
80
+
81
+
82
+ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
83
+ """
84
+ Compute the sparse permutation tensor whose effect on a position-frozen color vector
85
+ is the rotation along the specified axis, within the specified slice and the specified
86
+ orientation.
87
+ """
88
+ tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
89
+ length = 6 * (size**2)
90
+
91
+ # extract faces impacted by the move
92
+ indices = tensor.indices().to(dtype=torch.int16) # size = (4, length)
93
+ changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
94
+ extract = indices[:, changes] # size = (4, n)
95
+
96
+ # apply coordinate rotation
97
+ rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
98
+ offsets = (POS_SHIFTS[axis] * (size - 1)).repeat(extract.shape[-1], 1).transpose(0, 1) # size = (4, n)
99
+ rotated = rotated + offsets # size = (4, n)
100
+
101
+ # apply face rotation
102
+ rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
103
+
104
+ # from this point on, convert rotation into a position-based permutation of colors
105
+ (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
106
+ inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
107
+ outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
108
+
109
+ # compute position-based permutation of colors equivalent to rotation converting inputs into outputs
110
+ local_to_total = dict(enumerate(changes.tolist()))
111
+ total_to_local = {ind: i for i, ind in local_to_total.items()}
112
+
113
+ local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
114
+ total_perm = {
115
+ i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
116
+ }
117
+
118
+ # convert permutation dict into sparse tensor
119
+ perm_indices = torch.tensor(
120
+ [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
121
+ dtype=torch.int16,
122
+ )
123
+ perm_values = torch.tensor([1] * length, dtype=torch.int16)
124
+ perm_size = (3, size, 2, length, length)
125
+ return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
126
+
127
+
128
+ def parse_action_str(move: str) -> tuple[int, int, int]:
129
+ """
130
+ Convert the name of an action into a triple (axis, slice, inverse).
131
+ Examples:
132
+ 'X1' -> (0, 1, 0)
133
+ 'X2i' -> (0, 2, 1)
134
+ """
135
+ axis = "XYZ".index(move[0])
136
+ slice = int(re.findall(r"^\d+", move[1:])[0])
137
+ inverse = int(len(move) > (1 + len(str(slice))))
138
+ return (axis, slice, inverse)
139
+
140
+
141
+ def parse_actions_str(moves: str) -> list[tuple[int, int, int]]:
142
+ """
143
+ Convert a sequence of actions in a string into a list of triples (axis, slice, inverse).
144
+ Examples:
145
+ 'X1 X2i' -> [(0, 1, 0), (0, 2, 1)]
146
+ """
147
+ return [parse_action_str(move) for move in moves.strip().split()]
148
+
149
+
150
+ def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
151
+ """
152
+ Generate a string containing moves that are randomly sampled.
153
+ """
154
+ rng = np.random.default_rng(seed=seed)
155
+ axes = rng.choice(["X", "Y", "Z"], size=num_moves)
156
+ slices = rng.choice([str(i) for i in range(size)], size=num_moves)
157
+ orients = rng.choice(["", "i"], size=num_moves)
158
+ return " ".join("".join(move) for move in zip(axes, slices, orients))
src/rubik/cube.py ADDED
@@ -0,0 +1,101 @@
1
+ from functools import reduce
2
+ from loguru import logger
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
8
+ from rubik.display import stringify
9
+ from rubik.tensor_utils import build_cube_tensor
10
+
11
+
12
+ class Cube:
13
+ """
14
+ A 4D tensor filled with colors. Dimensions have the following interpretation:
15
+ - Face (from 0 to 5, with 0 = "Up", 1 = "Left", 2 = "Front", 3 = "Right", 4 = "Back", 5 = "Down").
16
+ - X coordinate (from 0 to self.size - 1, from Left to Right).
17
+ - Y coordinate (from 0 to self.size - 1, from Back to Front).
18
+ - Z coordinate (from 0 to self.size - 1, from Down to Up).
19
+
20
+ Colors filling each tensor cell are from 0 to 6, 0 being the "dark" color,
21
+ the rest according to the order given in the "colors" attribute.
22
+ """
23
+
24
+ def __init__(self, colors: list[str], size: int):
25
+ """
26
+ Create Cube from a given list of 6 colors and size.
27
+ Example:
28
+ cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
29
+ """
30
+ tensor = build_cube_tensor(colors, size)
31
+ self.coordinates = tensor.indices().transpose(0, 1).to(torch.int16)
32
+ self.state = F.one_hot(tensor.values().long(), num_classes=7).to(torch.int16)
33
+ self.actions = build_actions_tensor(size)
34
+ self.history: list[list[int]] = []
35
+ self.colors = colors
36
+ self.size = size
37
+
38
+ def to(self, device: str | torch.device) -> "Cube":
39
+ device = torch.device(device)
40
+ dtype = (
41
+ self.state.dtype
42
+ if self.state.device == device
43
+ else torch.int16
44
+ if device == torch.device("cpu")
45
+ else torch.float32
46
+ )
47
+ self.coordinates = self.coordinates.to(device=device, dtype=dtype)
48
+ self.state = self.state.to(device=device, dtype=dtype)
49
+ self.actions = self.actions.to(device=device, dtype=dtype)
50
+ logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
51
+ return self
52
+
53
+ def reset_history(self) -> None:
54
+ """
55
+ Reset internal history of moves.
56
+ """
57
+ self.history = []
58
+ return
59
+
60
+ def shuffle(self, num_moves: int, seed: int = 0) -> None:
61
+ """
62
+ Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
63
+ """
64
+ moves = sample_actions_str(num_moves, self.size, seed=seed)
65
+ self.rotate(moves)
66
+ self.reset_history()
67
+ return
68
+
69
+ def rotate(self, moves: str) -> None:
70
+ """
71
+ Apply a sequence of moves (defined as plain string) to the cube.
72
+ """
73
+ actions = parse_actions_str(moves)
74
+ for action in actions:
75
+ self.rotate_once(*action)
76
+ return
77
+
78
+ def rotate_once(self, axis: int, slice: int, inverse: int) -> None:
79
+ """
80
+ Apply a move (defined as 3 coordinates) to the cube.
81
+ """
82
+ action = self.actions[axis, slice, inverse]
83
+ self.state = action @ self.state
84
+ self.history.append([axis, slice, inverse])
85
+ return
86
+
87
+ def compute_changes(self, moves: str) -> dict[int, int]:
88
+ """
89
+ Combine a sequence of moves and return the resulting changes.
90
+ """
91
+ actions = parse_actions_str(moves)
92
+ tensors = [self.actions[*action].to(torch.float32) for action in actions]
93
+ result = reduce(lambda A, B: B @ A, tensors).to(torch.int16).coalesce()
94
+ return dict(result.indices().transpose(0, 1).tolist())
95
+
96
+ def __str__(self):
97
+ """
98
+ Compute a string representation of a cube.
99
+ """
100
+ state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
101
+ return stringify(state, self.colors, self.size)
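
Note: a small sketch (an assumption, not part of the diff) of the round-trip behaviour implied by the inverse moves, using only the `Cube` API added above:

```python
import torch

from rubik.cube import Cube

cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
start = cube.state.clone()

# a move followed by its inverse should restore the one-hot state exactly,
# while both moves remain recorded in the history
cube.rotate("X1 X1i")
assert torch.equal(cube.state, start)
assert cube.history == [[0, 1, 0], [0, 1, 1]]
```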
src/rubik/display.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch
2
+
3
+
4
+ def stringify(state: torch.Tensor, colors: list[str], size: int) -> str:
5
+ """
6
+ Compute a string representation of a cube.
7
+ """
8
+ colors = pad_colors(colors)
9
+ faces = state.reshape(6, size, size).transpose(1, 2)
10
+ faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
11
+ space = " " * max(len(c) for c in colors) * size
12
+ l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
13
+ l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(size))
14
+ l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
15
+ return "\n".join([l1, l2, l3])
16
+
17
+
18
+ def pad_colors(colors: list[str]) -> list[str]:
19
+ """
20
+ Pad color names to strings of equal length.
21
+ """
22
+ max_len = max(len(c) for c in colors)
23
+ return [c + " " * (max_len - len(c)) for c in colors]
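
Note: `pad_colors` is what keeps multi-character color names aligned in `stringify`; a quick illustration (hypothetical input, not from the diff):

```python
from rubik.display import pad_colors

# every name is right-padded with spaces to the longest length
print(pad_colors(["U", "Left", "C"]))  # ['U   ', 'Left', 'C   ']
```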
src/rubik/tensor_utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+
4
+ def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
5
+ """
6
+ Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
7
+ """
8
+ assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
9
+ assert isinstance(size, int) and size > 1, f"Expected integer size greater than 1, got {size}"
10
+
11
+ # build dense tensor filled with colors
12
+ n = size - 1
13
+ tensor = torch.zeros([6, size, size, size], dtype=torch.int16)
14
+ tensor[0, :, :, n] = 1 # up
15
+ tensor[1, 0, :, :] = 2 # left
16
+ tensor[2, :, n, :] = 3 # front
17
+ tensor[3, n, :, :] = 4 # right
18
+ tensor[4, :, 0, :] = 5 # back
19
+ tensor[5, :, :, 0] = 6 # down
20
+ return tensor.to_sparse()
21
+
22
+
23
+ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
24
+ """
25
+ Convert a permutation string into a sparse 2D matrix.
26
+ """
27
+ perm_list = [int(p) for p in (perm + perm[0])]
28
+ perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
+ indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
30
+ values = torch.tensor([1] * size, dtype=torch.int16)
31
+ return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16).coalesce()
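
Note: a short sketch of `build_permutation_matrix`, which `action.py` uses to encode face cycles such as `"0254"` (rotation about X: Up -> Front -> Down -> Back -> Up):

```python
from rubik.tensor_utils import build_permutation_matrix

# the cycle "0254" sends face 0 -> 2, 2 -> 5, 5 -> 4 and 4 -> 0;
# reading the sparse indices back gives the full mapping on 6 faces
matrix = build_permutation_matrix(size=6, perm="0254")
mapping = dict(matrix.indices().transpose(0, 1).tolist())
print(mapping)  # {0: 2, 1: 1, 2: 5, 3: 3, 4: 0, 5: 4}
```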
tests/unit/test_action.py ADDED
@@ -0,0 +1,179 @@
1
+ import pytest
2
+ from typing import Iterable
3
+
4
+ import torch
5
+
6
+ from rubik.action import (
7
+ POS_ROTATIONS,
8
+ POS_SHIFTS,
9
+ FACE_ROTATIONS,
10
+ build_actions_tensor,
11
+ build_action_tensor,
12
+ parse_action_str,
13
+ parse_actions_str,
14
+ sample_actions_str,
15
+ )
16
+
17
+
18
+ def test_position_rotation_shape():
19
+ """
20
+ Test that POS_ROTATIONS has expected shape.
21
+ """
22
+ expected = (3, 4, 4)
23
+ observed = POS_ROTATIONS.shape
24
+ assert expected == observed, f"Position rotation tensor expected shape '{expected}', got '{observed}' instead"
25
+
26
+
27
+ @pytest.mark.parametrize(
28
+ "axis, input, expected",
29
+ [
30
+ (0, (1, 1, 0, 0), (1, 1, 0, 0)), # X -> X
31
+ (0, (1, 0, 1, 0), (1, 0, 0, -1)), # Y -> -Z
32
+ (0, (1, 0, 0, 1), (1, 0, 1, 0)), # Z -> Y
33
+ (1, (1, 1, 0, 0), (1, 0, 0, 1)), # X -> Z
34
+ (1, (1, 0, 1, 0), (1, 0, 1, 0)), # Y -> Y
35
+ (1, (1, 0, 0, 1), (1, -1, 0, 0)), # Z -> -X
36
+ (2, (1, 1, 0, 0), (1, 0, -1, 0)), # X -> -Y
37
+ (2, (1, 0, 1, 0), (1, 1, 0, 0)), # Y -> X
38
+ (2, (1, 0, 0, 1), (1, 0, 0, 1)), # Z -> Z
39
+ ],
40
+ )
41
+ def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
42
+ """
43
+ Test that POS_ROTATIONS behaves as expected.
44
+ """
45
+ out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype)
46
+ exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
47
+ assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
48
+
49
+
50
+ @pytest.mark.parametrize(
51
+ "axis, size, input, expected",
52
+ [
53
+ (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
54
+ (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
55
+ (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
56
+ ],
57
+ )
58
+ def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
59
+ """
60
+ Test that POS_SHIFTS behaves as expected.
61
+ """
62
+ rot = POS_ROTATIONS[axis] @ (torch.tensor(input, dtype=POS_ROTATIONS.dtype) * (size - 1))
63
+ out = rot + (POS_SHIFTS[axis] * (size - 1))
64
+ exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
65
+ assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
66
+
67
+
68
+ def test_face_rotation_shape():
69
+ """
70
+ Test that FACE_ROTATIONS has expected shape.
71
+ """
72
+ expected = (3, 6, 6)
73
+ observed = FACE_ROTATIONS.shape
74
+ assert expected == observed, f"Face rotation tensor expected shape '{expected}', got '{observed}' instead"
75
+
76
+
77
+ @pytest.mark.parametrize(
78
+ "axis, input, expected",
79
+ [
80
+ (0, (1, 0, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about X axis: 0 (Up) -> 2 (Front)
81
+ (1, (1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0)), # rotation about Y axis: 0 (Up) -> 1 (Left)
82
+ (2, (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about Z axis: 1 (Left) -> 2 (Front)
83
+ ],
84
+ )
85
+ def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
86
+ """
87
+ Test that FACE_ROTATIONS behaves as expected.
88
+ """
89
+ out = torch.tensor(input, dtype=FACE_ROTATIONS.dtype) @ FACE_ROTATIONS[axis]
90
+ exp = torch.tensor(expected, dtype=FACE_ROTATIONS.dtype)
91
+ assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"
92
+
93
+
94
+ @pytest.mark.parametrize("size", [2, 3, 5, 20])
95
+ def test_build_actions_tensor_shape(size: int):
96
+ """
97
+ Test that "build_actions_tensor" output has expected shape.
98
+ """
99
+ expected = (3, size, 2, 6 * (size**2), 6 * (size**2))
100
+ observed = build_actions_tensor(size).shape
101
+ assert expected == observed, (
102
+ f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
103
+ )
104
+
105
+
106
+ @pytest.mark.parametrize(
107
+ "size, axis, slice, inverse",
108
+ [
109
+ (2, 2, 1, 0),
110
+ (3, 0, 1, 1),
111
+ (5, 1, 4, 0),
112
+ ],
113
+ )
114
+ def test_build_action_tensor_shape(size: int, axis: int, slice: int, inverse: int):
115
+ """
116
+ Test that "build_actions_tensor" output has expected shape.
117
+ """
118
+ expected = (3, size, 2, 6 * (size**2), 6 * (size**2))
119
+ observed = build_action_tensor(size, axis, slice, inverse).shape
120
+ assert expected == observed, (
121
+ f"'build_action_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
122
+ )
123
+
124
+
125
+ @pytest.mark.parametrize(
126
+ "move, expected",
127
+ [
128
+ ["X1", (0, 1, 0)],
129
+ ["X25i", (0, 25, 1)],
130
+ ["Y0", (1, 0, 0)],
131
+ ["Y5i", (1, 5, 1)],
132
+ ["Z30", (2, 30, 0)],
133
+ ["Z512ijk", (2, 512, 1)],
134
+ ],
135
+ )
136
+ def test_parse_action_str(move: str, expected: tuple[int, int, int]):
137
+ """
138
+ Test that "parse_action_str" behaves as expected.
139
+ """
140
+ observed = parse_action_str(move)
141
+ assert expected == observed, (
142
+ f"'parse_action_str' output is incorrect: expected '{expected}', got '{observed}' instead"
143
+ )
144
+
145
+
146
+ @pytest.mark.parametrize(
147
+ "moves, expected",
148
+ [
149
+ [" X1 Y0 X25i Z512ijk Z30 Y5i ", [(0, 1, 0), (1, 0, 0), (0, 25, 1), (2, 512, 1), (2, 30, 0), (1, 5, 1)]],
150
+ ],
151
+ )
152
+ def test_parse_actions_str(moves: str, expected: list[tuple[int, int, int]]):
153
+ """
154
+ Test that "parse_action_str" behaves as expected.
155
+ """
156
+ observed = parse_actions_str(moves)
157
+ assert expected == observed, (
158
+ f"'parse_actions_str' output is incorrect: expected '{expected}', got '{observed}' instead"
159
+ )
160
+
161
+
162
+ @pytest.mark.parametrize(
163
+ "num_moves, size, seed",
164
+ [
165
+ [1, 3, 0],
166
+ [1, 20, 42],
167
+ [256, 5, 21],
168
+ ],
169
+ )
170
+ def test_sample_actions_str(num_moves: int, size: int, seed: int):
171
+ """
172
+ Test that "sample_actions_str" is deterministic and outputs parsable content.
173
+ """
174
+ moves_1 = sample_actions_str(num_moves, size, seed)
175
+ moves_2 = sample_actions_str(num_moves, size, seed)
176
+ assert moves_1 == moves_2, f"'sample_actions_str' is non-deterministic: {moves_1} != {moves_2}"
177
+
178
+ parsed = parse_actions_str(moves_1)
179
+ assert len(parsed) == len(moves_1.split()), "'sample_actions_str' output cannot be parsed correctly"
tests/unit/test_cube.py ADDED
@@ -0,0 +1,127 @@
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ from rubik.cube import Cube
6
+
7
+
8
+ class TestCube:
9
+ """
10
+ A testing class for the Cube class.
11
+ """
12
+
13
+ @pytest.mark.parametrize(
14
+ "colors, size",
15
+ [
16
+ [["U", "L", "C", "R", "B", "D"], 3],
17
+ [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
18
+ [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
19
+ ],
20
+ )
21
+ def test__init__(self, colors: list[str], size: int):
22
+ """
23
+ Test that the __init__ method produces the expected attributes.
24
+ """
25
+ cube = Cube(colors, size)
26
+ assert cube.coordinates.shape == (6 * (size**2), 4), (
27
+ f"'coordinates' has incorrect shape {cube.coordinates.shape}"
28
+ )
29
+ assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
30
+ assert len(cube.history) == 0, "'history' field should be empty"
31
+
32
+ @pytest.mark.parametrize("device", ["cpu"])
33
+ def test_to(self, device: str | torch.device):
34
+ """
35
+ Test that the .to method behaves as expected.
36
+ """
37
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
38
+ cube_2 = cube.to(device)
39
+ assert torch.equal(cube.state, cube_2.state), "cube has different state after calling 'to' method"
40
+
41
+ def test_reset_history(self):
42
+ """
43
+ Test that the .reset_history method behaves as expected.
44
+ """
45
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
46
+ cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
47
+ cube.reset_history()
48
+ assert cube.history == [], "method 'reset_history' does not flush content"
49
+
50
+ @pytest.mark.parametrize("num_moves, seed", [[50, 42]])
51
+ def test_shuffle(self, num_moves: int, seed: int):
52
+ """
53
+ Test that the .shuffle method behaves as expected.
54
+ """
55
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
56
+ cube_state = cube.state.clone()
57
+ cube.shuffle(num_moves, seed)
58
+ assert cube.history == [], "method 'shuffle' does not flush content"
59
+ assert not torch.equal(cube_state, cube.state), "method 'shuffle' does not change state"
60
+
61
+ @pytest.mark.parametrize(
62
+ "moves",
63
+ [
64
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i",
65
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i" * 2,
66
+ ],
67
+ )
68
+ def test_rotate(self, moves: str):
69
+ """
70
+ Test that the .rotate method behaves as expected.
71
+ """
72
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
73
+ cube_state = cube.state.clone()
74
+ cube.rotate(moves)
75
+ assert cube.history != [], "method 'rotate' does not update history"
76
+ assert not torch.equal(cube_state, cube.state), "method 'rotate' does not change state"
77
+
78
+ @pytest.mark.parametrize(
79
+ "axis, slice, inverse",
80
+ [
81
+ [0, 2, 0],
82
+ [1, 1, 1],
83
+ [2, 0, 0],
84
+ ],
85
+ )
86
+ def test_rotate_once(self, axis: int, slice: int, inverse: int):
87
+ """
88
+ Test that the .rotate_once method behaves as expected.
89
+ """
90
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
91
+ cube_state = cube.state.clone()
92
+ cube.rotate_once(axis, slice, inverse)
93
+ assert cube.history == [[axis, slice, inverse]], "method 'rotate_once' does not update history"
94
+ assert not torch.equal(cube_state, cube.state), "method 'rotate_once' does not change state"
95
+
96
+ @pytest.mark.parametrize(
97
+ "moves",
98
+ [
99
+ "X2 X1i Y1i",
100
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i " * 2,
101
+ ],
102
+ )
103
+ def test_compute_changes(self, moves: str):
104
+ """
105
+ Test that the .compute_changes method behaves as expected.
106
+ """
107
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
108
+ facets = cube.state.argmax(dim=-1).to(torch.int16).tolist()
109
+ changes = cube.compute_changes(moves)
110
+
111
+ # apply changes induced by moves using the permutation dict returned by 'compute_changes'
112
+ expected = [facets[changes.get(i, i)] for i in range(len(facets))]
113
+
114
+ # apply changes induced by moves using the optimized 'rotate' method
115
+ cube.rotate(moves)
116
+ observed = cube.state.argmax(dim=-1).to(torch.int16).tolist()
117
+
118
+ # assert the two are identical
119
+ assert expected == observed, "method 'compute_changes' does not behave correctly"
120
+
121
+ def test__str__(self):
122
+ """
123
+ Test that the __str__ method behaves as expected.
124
+ """
125
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
126
+ repr = str(cube)
127
+ assert len(repr), "__str__ method returns an empty representation"
tests/unit/test_display.py ADDED
@@ -0,0 +1,42 @@
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ from rubik.cube import Cube
6
+ from rubik.display import stringify, pad_colors
7
+
8
+
9
+ @pytest.mark.parametrize(
10
+ "colors, size",
11
+ [
12
+ [["U", "L", "C", "R", "B", "D"], 3],
13
+ [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
14
+ [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
15
+ ],
16
+ )
17
+ def test_stringify(colors: list[str], size: int):
18
+ """
19
+ Test that stringify behaves as expected.
20
+ """
21
+ cube = Cube(colors=colors, size=size)
22
+ state = cube.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
23
+ repr = stringify(state, colors, size)
24
+ lens = {len(line) for line in repr.split("\n")}
25
+ assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
26
+
27
+
28
+ @pytest.mark.parametrize(
29
+ "colors",
30
+ [
31
+ ["U", "L", "C", "R", "B", "D"],
32
+ ["Up", "Left", "Center", "Right", "Back", "Down"],
33
+ ["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"],
34
+ ],
35
+ )
36
+ def test_pad_colors(colors: list[str]):
37
+ """
38
+ Test that pad_colors behaves as expected.
39
+ """
40
+ padded = pad_colors(colors)
41
+ lengths = {len(color) for color in padded}
42
+ assert len(lengths) == 1, f"'pad_colors' generates non-unique lengths: {lengths}"
tests/unit/test_tensor_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ from rubik.tensor_utils import build_cube_tensor, build_permutation_matrix
6
+
7
+
8
+ @pytest.mark.parametrize("size", [2, 3, 5, 20])
9
+ def test_build_cube_tensor(size: int):
10
+ """
11
+ Test that build_cube_tensor behaves as expected.
12
+ """
13
+ tensor = build_cube_tensor(colors=["U", "L", "C", "R", "B", "D"], size=size)
14
+ facets = tensor.to_dense().to(dtype=torch.int8) != 0
15
+ x_sums = facets.sum(dim=(0, 2, 3)).tolist()
16
+ y_sums = facets.sum(dim=(0, 1, 3)).tolist()
17
+ z_sums = facets.sum(dim=(0, 1, 2)).tolist()
18
+ expected = [(size**2) + (4 * size)] + [4 * size] * (size - 2) + [(size**2) + (4 * size)]
19
+ assert x_sums == expected, (
20
+ f"'build_cube_tensor' has incorrect sum along X axis: expected '{expected}', got '{x_sums}'"
21
+ )
22
+ assert y_sums == expected, (
23
+ f"'build_cube_tensor' has incorrect sum along Y axis: expected '{expected}', got '{y_sums}'"
24
+ )
25
+ assert z_sums == expected, (
26
+ f"'build_cube_tensor' has incorrect sum along Z axis: expected '{expected}', got '{z_sums}'"
27
+ )
28
+
29
+
30
+ @pytest.mark.parametrize("size, perm", [[2, "01"], [3, "210"], [6, "2345"]])
31
+ def test_build_permutation_matrix(size: int, perm: str):
32
+ """
33
+ Test that build_permutation_matrix behaves as expected.
34
+ """
35
+ matrix = build_permutation_matrix(size, perm)
36
+ mapping = dict(matrix.indices().transpose(0, 1).tolist())
37
+ for i, j in zip(perm, perm[1:] + perm[0]):
38
+ assert mapping[int(i)] == int(j), f"'build_permutation_matrix' outputs has wrong behavior: {perm}, {mapping}"
uv.lock ADDED
The diff for this file is too large to render. See raw diff