Jean-baptiste Aujogue committed on

Gradio Interface (#2)

* add plotly and gradio
* edit readme
* simplify cube object
* simplify cube object
* edit facelet computation
* base demo with 3d plot
* remove legacy plot function
* rename demo as interface
* edit plot function
* remove cupy dependency
* all dependencies
* gradio interface with 3D plot
- README.md +28 -11
- notebooks/dev.ipynb +23 -5
- pyproject.toml +3 -1
- src/rubik/__main__.py +4 -2
- src/rubik/action.py +20 -14
- src/rubik/cube.py +58 -20
- src/rubik/display.py +0 -23
- src/rubik/interface/__init__.py +0 -0
- src/rubik/interface/app.py +66 -0
- src/rubik/interface/plot.py +155 -0
- src/rubik/{tensor_utils.py → state.py} +5 -6
- tests/unit/test_cube.py +28 -25
- tests/unit/test_display.py +0 -42
- tests/unit/{test_tensor_utils.py → test_state.py} +2 -2
- uv.lock +0 -0
README.md
CHANGED
@@ -14,14 +14,20 @@ pre-commit install
 
 ## Usage
 
-### 
+### Launch the web interface
+
+```shell
+python -m rubik interface
+```
+
+### Use the python API
 
 ```python
 from rubik.cube import Cube
 
-cube = Cube(
+cube = Cube(size=3)
 
-# display
+# display cube state
 print(cube)
 # UUU
 # UUU
@@ -33,17 +39,14 @@ print(cube)
 # DDD
 # DDD
 
+# display history of moves
 print(cube.history)
 # []
-```
-
-### Perform basic moves
 
-
-
-cube.shuffle(num_moves=1000, seed=0)
+# scramble the cube using 1000 random moves (this resets the history)
+cube.scramble(num_moves=1000, seed=0)
 
-# rotate it in some way
+# rotate it in some way (this gets appended to history)
 cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
 ```
 
@@ -66,8 +69,9 @@ cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
 
 
 
-
+## References
 
+### Python implementations & rule-based solvers
 
 Open-source projects related to Rubik's Cube, sorted by number of stars:
 - [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
@@ -77,3 +81,16 @@ Open-source projects related to Rubik's Cube, sorted by number of stars:
 - [trincaog/magiccube](https://github.com/trincaog/magiccube)
 - [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
 - [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
+
+### Machine Learning based solver models
+
+- 2025, _CayleyPy Cube_, [Github](https://github.com/k1242/cayleypy-cube), [Paper](https://arxiv.org/html/2502.13266v1)
+
+- 2025, _Solving A Rubik’s Cube with Supervised Learning – Intuitively and Exhaustively Explained_, [Blog post](https://towardsdatascience.com/solving-a-rubiks-cube-with-supervised-learning-intuitively-and-exhaustively-explained-4f87b72ba1e2/)
+
+- 2024, _Solving Rubik's Cube Without Tricky Sampling_, [Paper](https://arxiv.org/abs/2411.19583).<br>
+  This involves training a scorer that estimates the number of moves transforming a given source state into a given target state, where the latter is not necessarily a solved cube. Data are synthetically generated by performing random moves and factorizing repeated identical moves.
+
+- 2023, _Curious Transformer_, [Github](https://github.com/tedtedtedtedtedted/Solve-Rubiks-Cube-Via-Transformer)
+
+- 2021, _Efficient Cube_, [Github](https://github.com/kyo-takano/efficientcube), [Paper](https://arxiv.org/abs/2106.03157)
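A note on the move strings used above ("X2", "X1i", ...): each token is an axis letter, a slice index, and an optional trailing "i" for the inverse direction. The sketch below is a hypothetical `parse_move` helper written only to illustrate that notation; it is not part of the package, which has its own `parse_action_str` in `src/rubik/action.py` (shown later in this commit) returning the same `(axis, slice, inverse)` triple.

```python
# Hypothetical helper, for illustration only.
AXES = "XYZ"

def parse_move(move: str) -> tuple[int, int, int]:
    axis = AXES.index(move[0])               # X -> 0, Y -> 1, Z -> 2
    inverse = int(move.endswith("i"))        # trailing "i" marks the inverse rotation
    slice_index = int(move[1:].rstrip("i"))  # remaining digits select the slice
    return axis, slice_index, inverse

print([parse_move(m) for m in "X2 X1i Y1i Z1i Y0 Z0i".split()])
# [(0, 2, 0), (0, 1, 1), (1, 1, 1), (2, 1, 1), (1, 0, 0), (2, 0, 1)]
```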
notebooks/dev.ipynb
CHANGED
@@ -31,7 +31,8 @@
     "outputs": [],
     "source": [
      "from rubik.cube import Cube\n",
-     "from rubik.action import build_actions_tensor"
+     "from rubik.action import build_actions_tensor\n",
+     "from rubik.interface.plot import CubeVisualizer"
     ]
    },
    {
@@ -45,7 +46,7 @@
      "\n",
      "actions = build_actions_tensor(size)\n",
      "\n",
-     "cube = Cube(
+     "cube = Cube(size)\n",
      "print(cube)"
     ]
    },
@@ -57,7 +58,8 @@
     "outputs": [],
     "source": [
      "cubis = copy.deepcopy(cube)\n",
-     "cubis.
+     "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \")\n",
+     "\n",
      "print(cubis)\n",
      "print(cubis.history)"
     ]
@@ -69,7 +71,10 @@
     "metadata": {},
     "outputs": [],
     "source": [
-     "cubis.
+     "visualizer = CubeVisualizer(size=cubis.size)\n",
+     "layout_args = {\"autosize\": False, \"width\": 600, \"height\": 600}\n",
+     "\n",
+     "visualizer(cubis.coordinates, cubis.state, cubis.size).update_layout(**layout_args).show()"
     ]
    },
    {
@@ -80,7 +85,7 @@
     "outputs": [],
     "source": [
      "cubis = copy.deepcopy(cube)\n",
-     "cubis.
+     "cubis.scramble(2000, seed=0)\n",
      "print(cubis)\n",
      "print(cubis.history)"
     ]
@@ -91,6 +96,19 @@
     "id": "7",
     "metadata": {},
     "outputs": [],
+    "source": [
+     "cubis = copy.deepcopy(cube)\n",
+     "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \" * 1000)\n",
+     "print(cubis)\n",
+     "print(cubis.history)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "8",
+    "metadata": {},
+    "outputs": [],
     "source": [
      "(actions[0, 2, 0].type(torch.float32) @ actions[0, 1, 1].type(torch.float32)).type(torch.int8)"
     ]
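For readers without Jupyter, the notebook cells above amount to the following plain script. This is a sketch that only reuses calls appearing in the notebook and assumes the `rubik` package and `torch` are importable.

```python
import copy

import torch

from rubik.cube import Cube
from rubik.action import build_actions_tensor
from rubik.interface.plot import CubeVisualizer

size = 3
actions = build_actions_tensor(size)
cube = Cube(size)
print(cube)

# rotate a deep copy and inspect its history
cubis = copy.deepcopy(cube)
cubis.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i ")
print(cubis)
print(cubis.history)

# render the rotated cube with plotly
visualizer = CubeVisualizer(size=cubis.size)
layout_args = {"autosize": False, "width": 600, "height": 600}
visualizer(cubis.coordinates, cubis.state, cubis.size).update_layout(**layout_args).show()

# compose two moves by multiplying their permutation matrices (cast to float for the sparse matmul)
composed = (actions[0, 2, 0].type(torch.float32) @ actions[0, 1, 1].type(torch.float32)).type(torch.int8)
```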
pyproject.toml
CHANGED
@@ -5,9 +5,11 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.11,<3.12"
 dependencies = [
-    "cupy-cuda12x>=13.4.1",
     "fire>=0.7.0",
+    "gradio>=5.38",
     "loguru>=0.7.3",
+    "plotly>=6.2.0",
+    "pydantic>=2.11.7",
     "torch>=2.7.1",
 ]
 
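A quick way to confirm the new dependencies resolved after syncing the environment (a sketch; it assumes the project was installed with `uv sync` or an equivalent pip install):

```python
import gradio
import plotly
import pydantic
import torch

print(gradio.__version__, plotly.__version__, pydantic.__version__, torch.__version__)
```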
src/rubik/__main__.py
CHANGED
@@ -1,4 +1,6 @@
-
+from fire import Fire
 
+from rubik.interface.app import app
 
-
+
+Fire({"interface": app})
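The entry point delegates argument parsing to python-fire: the dict keys become sub-commands, so `python -m rubik interface` calls `app()` and extra flags map onto its keyword arguments. A hedged sketch of the same dispatch driven programmatically (the `--default_size` flag is simply `app`'s keyword argument; nothing beyond the two files above is assumed):

```python
from fire import Fire

from rubik.interface.app import app

# Fire's `command` argument replays a CLI invocation; this is equivalent to
# `python -m rubik interface --default_size=4`. It launches the Gradio server
# and blocks until it is stopped.
Fire({"interface": app}, command=["interface", "--default_size=4"])
```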
src/rubik/action.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 
-from rubik.
+from rubik.state import build_permutation_matrix, build_cube_tensor
 
 
 POS_ROTATIONS = torch.stack(
@@ -17,7 +17,7 @@ POS_ROTATIONS = torch.stack(
             [0, 0, 0, 1],
             [0, 0, -1, 0],
         ],
-        dtype=torch.
+        dtype=torch.int32,
     ),
     # rot about Y: X -> Z
     torch.tensor(
@@ -27,7 +27,7 @@ POS_ROTATIONS = torch.stack(
             [0, 0, 1, 0],
             [0, 1, 0, 0],
         ],
-        dtype=torch.
+        dtype=torch.int32,
     ),
     # rot about Z: Y -> X
     torch.tensor(
@@ -37,7 +37,7 @@ POS_ROTATIONS = torch.stack(
             [0, -1, 0, 0],
             [0, 0, 0, 1],
         ],
-        dtype=torch.
+        dtype=torch.int32,
     ),
 ]
 )
@@ -48,7 +48,7 @@ POS_SHIFTS = torch.tensor(
         [0, 1, 0, 0],
         [0, 0, 1, 0],
     ],
-    dtype=torch.
+    dtype=torch.int32,
 )
 
 
@@ -76,7 +76,7 @@ def build_actions_tensor(size: int) -> torch.Tensor:
             for inverse in range(2)
         ],
         dim=0,
-    ).sum(dim=0, dtype=torch.
+    ).sum(dim=0, dtype=torch.int32)
 
 
 def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
@@ -85,11 +85,11 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
     is the rotation along the specified axis, within the specified slice and the specified
     orientation.
     """
-    tensor = build_cube_tensor(
+    tensor = build_cube_tensor(size).to(dtype=torch.int32)
     length = 6 * (size**2)
 
     # extract faces impacted by the move
-    indices = tensor.indices().to(dtype=torch.
+    indices = tensor.indices().to(dtype=torch.int32)  # size = (4, length)
     changes = (indices[axis + 1] == slice).nonzero().reshape(-1)  # size = (n,), n < length
     extract = indices[:, changes]  # size = (4, n)
 
@@ -99,7 +99,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
     rotated = rotated + offsets  # size = (4, n)
 
     # apply face rotation
-    rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.
+    rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int32) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
 
     # from this point on, convert rotation into a position-based permutation of colors
     (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
@@ -110,19 +110,25 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
     local_to_total = dict(enumerate(changes.tolist()))
     total_to_local = {ind: i for i, ind in local_to_total.items()}
 
-    local_perm = {i:
+    local_perm = {i: outputs.index(inputs[i]) for i in range(len(inputs))}
     total_perm = {
         i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
     }
 
     # convert permutation dict into sparse tensor
     perm_indices = torch.tensor(
-        [
-
+        [
+            [axis] * length,
+            [slice] * length,
+            [inverse] * length,
+            list(total_perm.keys()),
+            list(total_perm.values()),
+        ],
+        dtype=torch.int32,
     )
-    perm_values = torch.tensor([1] * length, dtype=torch.
+    perm_values = torch.tensor([1] * length, dtype=torch.int32)
     perm_size = (3, size, 2, length, length)
-    return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.
+    return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int32)
 
 
 def parse_action_str(move: str) -> tuple[int, int, int]:
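The core idea behind `build_action_tensor` is that a move is a permutation of facelet positions, stored as a 0/1 matrix so that applying the move is a matrix product with the one-hot state. A self-contained sketch of that mechanism, using toy sizes, dense tensors and a float cast for the product (none of this code is from the repository):

```python
import torch
import torch.nn.functional as F

# a toy "cube" with 4 facelets whose colors are 1, 2, 3, 4, one-hot encoded over 7 classes
state = F.one_hot(torch.tensor([1, 2, 3, 4]), num_classes=7)

# a permutation swapping positions 0 and 1: row i of the action matrix selects position perm[i]
perm = [1, 0, 2, 3]
action = F.one_hot(torch.tensor(perm), num_classes=4)

# applying the move is exactly `action @ state`, as in Cube.rotate_once
rotated = (action.float() @ state.float()).long()
print(rotated.argmax(dim=-1).tolist())  # [2, 1, 3, 4]: the first two facelets swapped colors
```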
src/rubik/cube.py
CHANGED
@@ -5,8 +5,7 @@ import torch
 import torch.nn.functional as F
 
 from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
-from rubik.
-from rubik.tensor_utils import build_cube_tensor
+from rubik.state import build_cube_tensor
 
 
 class Cube:
@@ -21,30 +20,65 @@ class Cube:
     the rest according to order given in "colors" attribute.
     """
 
-    def __init__(self,
+    def __init__(self, size: int):
         """
-        Create Cube
-        Example:
-            cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
+        Create Cube of a given size.
         """
-        tensor = build_cube_tensor(
-
-        self.
-        self.
-        self.
-        self.
-        self.
+        tensor = build_cube_tensor(size)
+
+        self.dtype = torch.int8 if size <= 6 else torch.int16 if size <= 73 else torch.int32
+        self.coordinates = tensor.indices().to(self.dtype)
+        self.state = F.one_hot(tensor.values().long(), num_classes=7).to(self.dtype)
+        self.actions = build_actions_tensor(size).to(self.dtype)
+        self._history: list[tuple[int, int, int]] = []
+        self._colors: list[str] = list("ULCRBD")
+        self._size: int = size
+
+    @property
+    def history(self) -> list[tuple[int, int, int]]:
+        return self._history
+
+    @property
+    def colors(self) -> list[str]:
+        return self._colors
+
+    @property
+    def size(self) -> int:
+        return self._size
+
+    @property
+    def facelets(self) -> list[list[list[str]]]:
+        """
+        Return the list of faces of the cube, each given by a list of rows,
+        each given by a list of facelets.
+        """
+        tensor = torch.sparse_coo_tensor(
+            indices=self.coordinates,
+            values=self.state.argmax(dim=-1),
+            size=(6, self.size, self.size, self.size),
+            dtype=self.dtype,
+        ).to_dense()
+
+        n = self.size - 1
+        faces = [
+            tensor[0, :, :, n].transpose(0, 1),  # up
+            tensor[1, 0, :, :].flip(1).transpose(0, 1),  # left
+            tensor[2, :, n, :].flip(1).transpose(0, 1),  # front
+            tensor[3, n, :, :].flip(0).flip(1).transpose(0, 1),  # right
+            tensor[4, :, 0, :].flip(0).flip(1).transpose(0, 1),  # back
+            tensor[5, :, :, 0].flip(1).transpose(0, 1),  # down
+        ]
+        return [[[self.colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
 
     def to(self, device: str | torch.device) -> "Cube":
         device = torch.device(device)
         dtype = (
             self.state.dtype
             if self.state.device == device
-            else
+            else self.dtype
             if device == torch.device("cpu")
             else torch.float32
         )
-        self.coordinates = self.coordinates.to(device=device, dtype=dtype)
         self.state = self.state.to(device=device, dtype=dtype)
         self.actions = self.actions.to(device=device, dtype=dtype)
         logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
@@ -54,10 +88,10 @@ class Cube:
         """
         Reset internal history of moves.
         """
-        self.
+        self._history = []
         return
 
-    def
+    def scramble(self, num_moves: int, seed: int = 0) -> None:
         """
         Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
         """
@@ -81,7 +115,7 @@ class Cube:
         """
         action = self.actions[axis, slice, inverse]
         self.state = action @ self.state
-        self.
+        self._history.append((axis, slice, inverse))
         return
 
     def compute_changes(self, moves: str) -> dict[int, int]:
@@ -97,5 +131,9 @@ class Cube:
         """
         Compute a string representation of a cube.
         """
-
-
+        space = " " * self.size
+        facelets = self.facelets
+        l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in facelets[0])
+        l2 = "\n".join(" ".join("".join(face[i]) for face in facelets[1:5]) for i in range(self.size))
+        l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in facelets[-1])
+        return "\n".join([l1, l2, l3])
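A short usage sketch for the reworked `Cube` API above (it assumes the package and its dependencies are installed; the shapes follow from the one-hot encoding shown in `__init__`):

```python
from rubik.cube import Cube

cube = Cube(size=3)
print(cube.size, cube.colors)  # 3 ['U', 'L', 'C', 'R', 'B', 'D']
print(cube.state.shape)        # one one-hot color per facelet: (54, 7) for a 3x3x3 cube

cube.scramble(num_moves=100, seed=0)  # random moves, history is reset
cube.rotate("X0 Y1 Z0i")              # parsed moves are appended to history
print(cube.history)

print(cube.facelets[0])  # the "up" face, as rows of color letters
print(cube)              # flat text rendering of all six faces
```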
src/rubik/display.py
DELETED
@@ -1,23 +0,0 @@
-import torch
-
-
-def stringify(state: torch.Tensor, colors: list[str], size: int) -> str:
-    """
-    Compute a string representation of a cube.
-    """
-    colors = pad_colors(colors)
-    faces = state.reshape(6, size, size).transpose(1, 2)
-    faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
-    space = " " * max(len(c) for c in colors) * size
-    l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
-    l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(size))
-    l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
-    return "\n".join([l1, l2, l3])
-
-
-def pad_colors(colors: list[str]) -> list[str]:
-    """
-    Pad color names to strings of equal length.
-    """
-    max_len = max(len(c) for c in colors)
-    return [c + " " * (max_len - len(c)) for c in colors]
src/rubik/interface/__init__.py
ADDED
File without changes
src/rubik/interface/app.py
ADDED
@@ -0,0 +1,66 @@
+import gradio as gr
+
+from plotly import graph_objects as go
+
+from rubik.cube import Cube
+from rubik.interface.plot import CubeVisualizer
+
+
+def app(default_size: int = 3):
+    """
+    Interface with the following features:
+    - create a cube of the specified size.
+    - ability to scramble it with a specified number of moves.
+    - ability to rotate it through a text field.
+    - display a cube upon creation or update.
+    """
+    cube = Cube(default_size)
+    cube_visualizer = CubeVisualizer(default_size)
+
+    def create_cube(size) -> None:
+        nonlocal cube
+        nonlocal cube_visualizer
+        cube = Cube(size)
+        cube_visualizer = CubeVisualizer(size)
+        return
+
+    def scramble_cube(num_moves: int) -> None:
+        nonlocal cube
+        cube.scramble(num_moves, seed=0)
+        return
+
+    def rotate_cube(moves: str) -> None:
+        nonlocal cube
+        cube.rotate(moves)
+        return
+
+    def display_cube() -> go.Figure:
+        nonlocal cube
+        nonlocal cube_visualizer
+        layout_args = {"autosize": False, "width": 600, "height": 600}
+        return cube_visualizer(cube.coordinates, cube.state, cube.size).update_layout(**layout_args)
+
+    with gr.Blocks(fill_height=True) as demo:
+        # structure
+        gr.Markdown("Rubik's Cube Interface")
+        with gr.Row():
+            with gr.Column(scale=15):
+                size = gr.Slider(1, 100, value=default_size, step=1, label="Select a size")
+                create_btn = gr.Button("Generate a Cube")
+
+                num_moves = gr.Slider(0, 10000, value=500, step=100, label="Select a number of steps for scrambling")
+                scramble_btn = gr.Button("Scramble the Cube")
+
+                moves = gr.Textbox(value="X0 Y1 Z0i", label="Define a sequence of moves")
+                rotate_btn = gr.Button("Rotate the Cube")
+
+            with gr.Column(scale=85):
+                plot = gr.Plot(display_cube(), container=False)
+
+        # interactions
+        create_btn.click(create_cube, inputs=size).success(display_cube, None, plot)
+        scramble_btn.click(scramble_cube, inputs=num_moves).success(display_cube, None, plot)
+        rotate_btn.click(rotate_cube, inputs=moves).success(display_cube, None, plot)
+
+    demo.launch()
+    return
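The Gradio app keeps a single `Cube` and `CubeVisualizer` in closure variables and redraws the plot after every button click. Launching it outside the CLI is a one-liner (a sketch, assuming gradio, plotly and torch are installed):

```python
from rubik.interface.app import app

# builds the Blocks layout, wires the buttons, and starts the local Gradio server
app(default_size=3)
```

Because the state lives in `nonlocal` variables of `app`, all connected browser sessions share the same cube; per-session state would normally go through `gr.State` instead.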
src/rubik/interface/plot.py
ADDED
@@ -0,0 +1,155 @@
+import copy
+import plotly.graph_objects as go
+
+import torch
+
+
+class CubeVisualizer:
+    """
+    Utility class for ploting a cube, with some layout ingredients precomputed at init.
+    """
+
+    def __init__(self, size: int):
+        self.vertices = self.build_vertices(size)
+        self.shifts = self.build_shifts()
+        self.x_coor = [v[1] for v in self.vertices]
+        self.y_coor = [v[0] for v in self.vertices]
+        self.z_coor = [v[2] for v in self.vertices]
+        self.fig = self.build_base_figure(size)
+
+    @property
+    def colors(self):
+        """
+        Assign colors for each face.
+        """
+        return {
+            0: "white",
+            1: "#3588cc",
+            2: "red",
+            3: "green",
+            4: "orange",
+            5: "yellow",
+        }
+
+    @staticmethod
+    def build_vertices(size: int):
+        face_vertices = [
+            [[x, y, size] for y in range(size + 1) for x in range(size + 1)],  # Up
+            [[0, y, z] for y in range(size + 1) for z in range(size + 1)],  # Left
+            [[x, size, z] for x in range(size + 1) for z in range(size + 1)],  # Front
+            [[size, y, z] for y in range(size + 1) for z in range(size + 1)],  # Right
+            [[x, 0, z] for x in range(size + 1) for z in range(size + 1)],  # Back
+            [[x, y, 0] for x in range(size + 1) for y in range(size + 1)],  # Down
+        ]
+        return [vertex for face in face_vertices for vertex in face]
+
+    @staticmethod
+    def build_shifts():
+        return [
+            [(0, 1, 0), (1, 0, 0), (1, 1, 0)],  # Up
+            [(0, 1, 0), (0, 0, 1), (0, 1, 1)],  # Left
+            [(1, 0, 0), (0, 0, 1), (1, 0, 1)],  # Front
+            [(0, 1, 0), (0, 0, 1), (0, 1, 1)],  # Right
+            [(1, 0, 0), (0, 0, 1), (1, 0, 1)],  # Back
+            [(0, 1, 0), (1, 0, 0), (1, 1, 0)],  # Down
+        ]
+
+    @staticmethod
+    def build_base_figure(size):
+        """
+        Create base figure for the cube, containing everything but the cube facelets.
+        """
+        fig = go.Figure()
+
+        # add black lines to the cube
+        lines_seq = [[0, size, size, 0, 0], [0, 0, size, size, 0]]
+        lines_args = {"mode": "lines", "line": {"width": 5, "color": "black"}, "hoverinfo": "none"}
+        for i in range(size + 1):
+            fig.add_trace(go.Scatter3d(x=[i] * 5, y=lines_seq[1], z=lines_seq[0], **lines_args))
+            fig.add_trace(go.Scatter3d(x=lines_seq[1], y=[i] * 5, z=lines_seq[0], **lines_args))
+            fig.add_trace(go.Scatter3d(x=lines_seq[0], y=lines_seq[1], z=[i] * 5, **lines_args))
+
+        # add text along each axis
+        fig.add_trace(
+            go.Scatter3d(
+                x=[size / 2, size / 2, size + 1.5 + size * 0.5],
+                y=[size / 2, -1.5 - size * 0.5, size / 2],
+                z=[size + 1 + size * 0.5, size / 2, size / 2],
+                mode="text",
+                text=["UP", "LEFT", "FRONT"],
+                textposition="middle center",
+                textfont={"size": 20},
+                hoverinfo="none",
+            )
+        )
+
+        # remove legend, background, grid, ticks, etc.
+        scene_axis = {
+            "showgrid": False,
+            "zeroline": False,
+            "showticklabels": False,
+            "showbackground": False,
+            "title_text": "",
+            "showspikes": False,
+        }
+        return fig.update_layout(
+            showlegend=False,
+            autosize=True,
+            scene={
+                "xaxis": scene_axis,
+                "yaxis": scene_axis,
+                "zaxis": scene_axis,
+                "camera": {"eye": {"x": 0.8, "y": -1.2, "z": 0.7}},
+            },
+        )
+
+    def __call__(self, coordinates: torch.Tensor, state: torch.Tensor, size: int) -> go.Figure:
+        """
+        Generates a 3D plot of a cube given its coordinates, state and size.
+        """
+        # set the color of each facelet, face after face
+        face_state = (state.argmax(dim=-1) - 1).reshape(6, -1).tolist()
+        face_colors = [[self.colors[f] for f in face] for face in face_state]
+
+        face_coordinates = [coordinates[1:, (coordinates[0] == i)].transpose(0, 1).tolist() for i in range(6)]
+
+        # for each facelet of a face, draw 2 complementary triangles covering it
+        i_coor = []
+        j_coor = []
+        k_coor = []
+        facecolor = []
+        for i in range(6):
+            face = face_coordinates[i]
+            if i == 0:
+                face = [[x, y, size] for x, y, z in face]
+            if i == 2:
+                face = [[x, size, z] for x, y, z in face]
+            if i == 3:
+                face = [[size, y, z] for x, y, z in face]
+
+            # add first triangle of each facelet
+            i_coor += [self.vertices.index(p) for p in face]
+            j_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][0])]) for p in face]
+            k_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][1])]) for p in face]
+
+            # add second triangle of each facelet
+            i_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][2])]) for p in face]
+            j_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][0])]) for p in face]
+            k_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][1])]) for p in face]
+
+            facecolor += face_colors[i] * 2
+
+        fig = copy.deepcopy(self.fig)
+        return fig.add_trace(
+            go.Mesh3d(
+                x=self.x_coor,
+                y=self.y_coor,
+                z=self.z_coor,
+                i=i_coor,
+                j=j_coor,
+                k=k_coor,
+                facecolor=facecolor,
+                opacity=1,
+                hoverinfo="none",
+            )
+        )
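`CubeVisualizer` precomputes the grid of vertex coordinates once and, on each call, emits a single `go.Mesh3d` trace in which every facelet is covered by two triangles indexed into that vertex list. A standalone rendering sketch (assumes plotly, torch and the package are installed):

```python
from rubik.cube import Cube
from rubik.interface.plot import CubeVisualizer

cube = Cube(size=3)
cube.scramble(num_moves=50, seed=0)

visualizer = CubeVisualizer(size=cube.size)
fig = visualizer(cube.coordinates, cube.state, cube.size)
fig.write_html("cube.html")  # or fig.show() in a notebook
```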
src/rubik/{tensor_utils.py → state.py}
RENAMED
@@ -1,16 +1,15 @@
 import torch
 
 
-def build_cube_tensor(
+def build_cube_tensor(size: int) -> torch.Tensor:
     """
     Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
     """
-    assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
     assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
 
     # build dense tensor filled with colors
     n = size - 1
-    tensor = torch.zeros([6, size, size, size], dtype=torch.
+    tensor = torch.zeros([6, size, size, size], dtype=torch.int32)
     tensor[0, :, :, n] = 1  # up
     tensor[1, 0, :, :] = 2  # left
     tensor[2, :, n, :] = 3  # front
@@ -26,6 +25,6 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
     """
     perm_list = [int(p) for p in (perm + perm[0])]
     perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
-    indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.
-    values = torch.tensor([1] * size, dtype=torch.
-    return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.
+    indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int32)
+    values = torch.tensor([1] * size, dtype=torch.int32)
+    return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int32).coalesce()
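The helpers above encode the cube geometry rather than a particular coloring: `build_cube_tensor` marks which cells of each face block carry a face label, and `build_permutation_matrix` turns a digit-cycle string into a sparse 0/1 matrix. A small inspection sketch (assumes torch is installed):

```python
from rubik.state import build_cube_tensor, build_permutation_matrix

cube_tensor = build_cube_tensor(3)
print(cube_tensor.shape)                       # torch.Size([6, 3, 3, 3]): face index, then x, y, z
print(cube_tensor.to_dense().count_nonzero())  # 54: each face marks size**2 cells of its own block

perm = build_permutation_matrix(4, "012")      # the 3-cycle 0 -> 1 -> 2 -> 0, position 3 left fixed
print(perm.to_dense())
```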
tests/unit/test_cube.py
CHANGED
@@ -10,23 +10,16 @@ class TestCube:
     A testing class for the Cube class.
     """
 
-    @pytest.mark.parametrize(
-
-        [
-            [["U", "L", "C", "R", "B", "D"], 3],
-            [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
-            [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
-        ],
-    )
-    def test__init__(self, colors: list[str], size: int):
+    @pytest.mark.parametrize("size", [3, 5, 10, 25])
+    def test__init__(self, size: int):
         """
         Test that the __init__ method produce expected attributes.
         """
-        cube = Cube(
-        assert cube.coordinates.shape == (6 * (size**2), 4), (
-            f"'coordinates' has incorrect shape {cube.coordinates.shape}"
-        )
+        cube = Cube(size)
         assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
+        assert cube.actions.shape == (3, size, 2, cube.state.shape[0], cube.state.shape[0]), (
+            f"'actions' has incorrect shape {cube.actions.shape}"
+        )
         assert len(cube.history) == 0, "'history' field should be empty"
 
     @pytest.mark.parametrize("device", ["cpu"])
@@ -34,7 +27,7 @@ class TestCube:
         """
         Test that the .to method behaves as expected.
         """
-        cube = Cube(
+        cube = Cube(3)
         cube_2 = cube.to(device)
         assert torch.equal(cube.state, cube_2.state), "cube has different state after calling 'to' method"
 
@@ -42,7 +35,7 @@ class TestCube:
         """
         Test that the .reset_history method behaves as expected.
        """
-        cube = Cube(
+        cube = Cube(3)
         cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
         cube.reset_history()
         assert cube.history == [], "method 'reset_history' does not flush content"
@@ -52,9 +45,9 @@ class TestCube:
         """
         Test that the .shuffle method behaves as expected.
         """
-        cube = Cube(
+        cube = Cube(3)
         cube_state = cube.state.clone()
-        cube.
+        cube.scramble(num_moves, seed)
         assert cube.history == [], "method 'shuffle' does not flush content"
         assert not torch.equal(cube_state, cube.state), "method 'shuffle' does not change state"
 
@@ -69,7 +62,7 @@ class TestCube:
         """
         Test that the .rotate method behaves as expected.
         """
-        cube = Cube(
+        cube = Cube(3)
         cube_state = cube.state.clone()
         cube.rotate(moves)
         assert cube.history != [], "method 'rotate' does not update history"
@@ -87,10 +80,10 @@ class TestCube:
         """
         Test that the .rotate_once method behaves as expected.
         """
-        cube = Cube(
+        cube = Cube(3)
         cube_state = cube.state.clone()
         cube.rotate_once(axis, slice, inverse)
-        assert cube.history == [
+        assert cube.history == [(axis, slice, inverse)], "method 'rotate_once' does not update history"
         assert not torch.equal(cube_state, cube.state), "method 'rotate_once' does not change state"
 
     @pytest.mark.parametrize(
@@ -104,8 +97,8 @@ class TestCube:
         """
         Test that the .compute_changes method behaves as expected.
         """
-        cube = Cube(
-        facets = cube.state.argmax(dim=-1).to(
+        cube = Cube(3)
+        facets = cube.state.argmax(dim=-1).to(cube.dtype).tolist()
         changes = cube.compute_changes(moves)
 
         # apply changes induced by moves using the permutation dict returned by 'compute_changes'
@@ -113,15 +106,25 @@ class TestCube:
 
         # apply changes induced by moves using the optimized 'rotate' method
         cube.rotate(moves)
-        observed = cube.state.argmax(dim=-1).to(
+        observed = cube.state.argmax(dim=-1).to(cube.dtype).tolist()
 
         # assert the tow are identical
         assert expected == observed, "method 'compute_changes' does not behave correctly: "
 
-    def
+    def test__str__len(self):
         """
         Test that the __str__ method behaves as expected.
         """
-        cube = Cube(
+        cube = Cube(3)
         repr = str(cube)
         assert len(repr), "__str__ method returns an empty representation"
+
+    @pytest.mark.parametrize("size", [3, 5, 8, 10])
+    def test__str__content(self, size: int):
+        """
+        Test that stringify behaves as expected.
+        """
+        cube = Cube(size=size)
+        repr = str(cube)
+        lens = {len(line) for line in repr.split("\n")}
+        assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
tests/unit/test_display.py
DELETED
@@ -1,42 +0,0 @@
-import pytest
-
-import torch
-
-from rubik.cube import Cube
-from rubik.display import stringify, pad_colors
-
-
-@pytest.mark.parametrize(
-    "colors, size",
-    [
-        [["U", "L", "C", "R", "B", "D"], 3],
-        [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
-        [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
-    ],
-)
-def test_stringify(colors: list[str], size: int):
-    """
-    Test that stringify behaves as expected.
-    """
-    cube = Cube(colors=colors, size=size)
-    state = cube.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
-    repr = stringify(state, colors, size)
-    lens = {len(line) for line in repr.split("\n")}
-    assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
-
-
-@pytest.mark.parametrize(
-    "colors",
-    [
-        ["U", "L", "C", "R", "B", "D"],
-        ["Up", "Left", "Center", "Right", "Back", "Down"],
-        ["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"],
-    ],
-)
-def test_pad_colors(colors: list[str]):
-    """
-    Test that pad_colors behaves as expected.
-    """
-    padded = pad_colors(colors)
-    lengths = {len(color) for color in padded}
-    assert len(lengths) == 1, f"'pad_colors' generates non-unique lengths: {lengths}"
tests/unit/{test_tensor_utils.py → test_state.py}
RENAMED
@@ -2,7 +2,7 @@ import pytest
 
 import torch
 
-from rubik.
+from rubik.state import build_cube_tensor, build_permutation_matrix
 
 
 @pytest.mark.parametrize("size", [2, 3, 5, 20])
@@ -10,7 +10,7 @@ def test_build_cube_tensor(size: int):
     """
     Test that build_cube_tensor behaves as expected.
     """
-    tensor = build_cube_tensor(
+    tensor = build_cube_tensor(size)
     facets = tensor.to_dense().to(dtype=torch.int8) != 0
     x_sums = facets.sum(dim=(0, 2, 3)).tolist()
    y_sums = facets.sum(dim=(0, 1, 3)).tolist()
uv.lock
CHANGED
The diff for this file is too large to render. See raw diff.