Spaces:
Sleeping
Sleeping
Jean-baptiste Aujogue
commited on
Faster moves (#3)
Browse files* edit readme
* edit readme
* add docstring
* replace matmul by torch gathering for permutations of facelets
- README.md +4 -5
- notebooks/dev.ipynb +4 -4
- src/rubik/action.py +21 -37
- src/rubik/cube.py +11 -18
- src/rubik/interface/plot.py +2 -1
- src/rubik/state.py +4 -4
- tests/unit/test_action.py +6 -6
- tests/unit/test_cube.py +8 -9
README.md
CHANGED
@@ -7,7 +7,7 @@ This project uses `uv 0.7` as environment & dependency manager, and `python 3.11
|
|
7 |
|
8 |
```shell
|
9 |
uv venv
|
10 |
-
|
11 |
uv sync
|
12 |
pre-commit install
|
13 |
```
|
@@ -67,13 +67,11 @@ cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
|
|
67 |
|
68 |
#### Base solvers following rule-based policies
|
69 |
|
70 |
-
|
71 |
-
|
72 |
## References
|
73 |
|
74 |
-
###
|
75 |
|
76 |
-
Open-source projects related to Rubik's Cube
|
77 |
- [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
|
78 |
- [pglass/cube](https://github.com/pglass/cube)
|
79 |
- [dwalton76/rubiks-cube-NxNxN-solver](https://github.com/dwalton76/rubiks-cube-NxNxN-solver)
|
@@ -81,6 +79,7 @@ Open-source projects related to Rubik's Cube, sorted by number of stars:
|
|
81 |
- [trincaog/magiccube](https://github.com/trincaog/magiccube)
|
82 |
- [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
|
83 |
- [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
|
|
|
84 |
|
85 |
### Machine Learning based solver models
|
86 |
|
|
|
7 |
|
8 |
```shell
|
9 |
uv venv
|
10 |
+
(Activate env)
|
11 |
uv sync
|
12 |
pre-commit install
|
13 |
```
|
|
|
67 |
|
68 |
#### Base solvers following rule-based policies
|
69 |
|
|
|
|
|
70 |
## References
|
71 |
|
72 |
+
### Implementations & rule-based solvers
|
73 |
|
74 |
+
Open-source projects related to Rubik's Cube:
|
75 |
- [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
|
76 |
- [pglass/cube](https://github.com/pglass/cube)
|
77 |
- [dwalton76/rubiks-cube-NxNxN-solver](https://github.com/dwalton76/rubiks-cube-NxNxN-solver)
|
|
|
79 |
- [trincaog/magiccube](https://github.com/trincaog/magiccube)
|
80 |
- [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
|
81 |
- [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
|
82 |
+
- [wata-orz/santa2023_permutation_puzzle](https://github.com/wata-orz/santa2023_permutation_puzzle/tree/main)
|
83 |
|
84 |
### Machine Learning based solver models
|
85 |
|
notebooks/dev.ipynb
CHANGED
@@ -44,8 +44,6 @@
|
|
44 |
"source": [
|
45 |
"size = 3\n",
|
46 |
"\n",
|
47 |
-
"actions = build_actions_tensor(size)\n",
|
48 |
-
"\n",
|
49 |
"cube = Cube(size)\n",
|
50 |
"print(cube)"
|
51 |
]
|
@@ -85,7 +83,7 @@
|
|
85 |
"outputs": [],
|
86 |
"source": [
|
87 |
"cubis = copy.deepcopy(cube)\n",
|
88 |
-
"cubis.scramble(
|
89 |
"print(cubis)\n",
|
90 |
"print(cubis.history)"
|
91 |
]
|
@@ -110,7 +108,9 @@
|
|
110 |
"metadata": {},
|
111 |
"outputs": [],
|
112 |
"source": [
|
113 |
-
"
|
|
|
|
|
114 |
]
|
115 |
}
|
116 |
],
|
|
|
44 |
"source": [
|
45 |
"size = 3\n",
|
46 |
"\n",
|
|
|
|
|
47 |
"cube = Cube(size)\n",
|
48 |
"print(cube)"
|
49 |
]
|
|
|
83 |
"outputs": [],
|
84 |
"source": [
|
85 |
"cubis = copy.deepcopy(cube)\n",
|
86 |
+
"cubis.scramble(20000, seed=0)\n",
|
87 |
"print(cubis)\n",
|
88 |
"print(cubis.history)"
|
89 |
]
|
|
|
108 |
"metadata": {},
|
109 |
"outputs": [],
|
110 |
"source": [
|
111 |
+
"actions = build_actions_tensor(size)\n",
|
112 |
+
"\n",
|
113 |
+
"torch.gather(actions[0, 2, 0], 0, actions[0, 1, 1])"
|
114 |
]
|
115 |
}
|
116 |
],
|
src/rubik/action.py
CHANGED
@@ -17,7 +17,7 @@ POS_ROTATIONS = torch.stack(
|
|
17 |
[0, 0, 0, 1],
|
18 |
[0, 0, -1, 0],
|
19 |
],
|
20 |
-
dtype=torch.
|
21 |
),
|
22 |
# rot about Y: X -> Z
|
23 |
torch.tensor(
|
@@ -27,7 +27,7 @@ POS_ROTATIONS = torch.stack(
|
|
27 |
[0, 0, 1, 0],
|
28 |
[0, 1, 0, 0],
|
29 |
],
|
30 |
-
dtype=torch.
|
31 |
),
|
32 |
# rot about Z: Y -> X
|
33 |
torch.tensor(
|
@@ -37,7 +37,7 @@ POS_ROTATIONS = torch.stack(
|
|
37 |
[0, -1, 0, 0],
|
38 |
[0, 0, 0, 1],
|
39 |
],
|
40 |
-
dtype=torch.
|
41 |
),
|
42 |
]
|
43 |
)
|
@@ -48,7 +48,7 @@ POS_SHIFTS = torch.tensor(
|
|
48 |
[0, 1, 0, 0],
|
49 |
[0, 0, 1, 0],
|
50 |
],
|
51 |
-
dtype=torch.
|
52 |
)
|
53 |
|
54 |
|
@@ -66,30 +66,30 @@ FACE_ROTATIONS = torch.stack(
|
|
66 |
|
67 |
def build_actions_tensor(size: int) -> torch.Tensor:
|
68 |
"""
|
69 |
-
Built the
|
70 |
"""
|
71 |
-
return torch.
|
72 |
[
|
73 |
-
|
|
|
|
|
|
|
74 |
for axis in range(3)
|
75 |
-
for slice in range(size)
|
76 |
-
for inverse in range(2)
|
77 |
],
|
78 |
-
|
79 |
-
)
|
80 |
|
81 |
|
82 |
-
def
|
83 |
"""
|
84 |
-
Compute the
|
85 |
-
|
86 |
-
orientation.
|
87 |
"""
|
88 |
-
tensor = build_cube_tensor(size).to(dtype=torch.
|
89 |
length = 6 * (size**2)
|
90 |
|
91 |
# extract faces impacted by the move
|
92 |
-
indices = tensor.indices().to(dtype=torch.
|
93 |
changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
94 |
extract = indices[:, changes] # size = (4, n)
|
95 |
|
@@ -99,7 +99,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
|
|
99 |
rotated = rotated + offsets # size = (4, n)
|
100 |
|
101 |
# apply face rotation
|
102 |
-
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.
|
103 |
|
104 |
# from this point on, convert rotation into a position-based permutation of colors
|
105 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
@@ -107,28 +107,12 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
|
|
107 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
108 |
|
109 |
# compute position-based permutation of colors equivalent to rotation converting inputs into outputs
|
|
|
110 |
local_to_total = dict(enumerate(changes.tolist()))
|
111 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
|
116 |
-
}
|
117 |
-
|
118 |
-
# convert permutation dict into sparse tensor
|
119 |
-
perm_indices = torch.tensor(
|
120 |
-
[
|
121 |
-
[axis] * length,
|
122 |
-
[slice] * length,
|
123 |
-
[inverse] * length,
|
124 |
-
list(total_perm.keys()),
|
125 |
-
list(total_perm.values()),
|
126 |
-
],
|
127 |
-
dtype=torch.int32,
|
128 |
-
)
|
129 |
-
perm_values = torch.tensor([1] * length, dtype=torch.int32)
|
130 |
-
perm_size = (3, size, 2, length, length)
|
131 |
-
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int32)
|
132 |
|
133 |
|
134 |
def parse_action_str(move: str) -> tuple[int, int, int]:
|
|
|
17 |
[0, 0, 0, 1],
|
18 |
[0, 0, -1, 0],
|
19 |
],
|
20 |
+
dtype=torch.int64,
|
21 |
),
|
22 |
# rot about Y: X -> Z
|
23 |
torch.tensor(
|
|
|
27 |
[0, 0, 1, 0],
|
28 |
[0, 1, 0, 0],
|
29 |
],
|
30 |
+
dtype=torch.int64,
|
31 |
),
|
32 |
# rot about Z: Y -> X
|
33 |
torch.tensor(
|
|
|
37 |
[0, -1, 0, 0],
|
38 |
[0, 0, 0, 1],
|
39 |
],
|
40 |
+
dtype=torch.int64,
|
41 |
),
|
42 |
]
|
43 |
)
|
|
|
48 |
[0, 1, 0, 0],
|
49 |
[0, 0, 1, 0],
|
50 |
],
|
51 |
+
dtype=torch.int64,
|
52 |
)
|
53 |
|
54 |
|
|
|
66 |
|
67 |
def build_actions_tensor(size: int) -> torch.Tensor:
|
68 |
"""
|
69 |
+
Built the 4D tensor carrying all rotations of a cube as index permutation.
|
70 |
"""
|
71 |
+
return torch.tensor(
|
72 |
[
|
73 |
+
[
|
74 |
+
[build_action_permutation(size=size, axis=axis, slice=slice, inverse=inverse) for inverse in range(2)]
|
75 |
+
for slice in range(size)
|
76 |
+
]
|
77 |
for axis in range(3)
|
|
|
|
|
78 |
],
|
79 |
+
dtype=torch.int64,
|
80 |
+
)
|
81 |
|
82 |
|
83 |
+
def build_action_permutation(size: int, axis: int, slice: int, inverse: int) -> list[int]:
|
84 |
"""
|
85 |
+
Compute the permutation list whose effect on a position-frozen color vector is the rotation
|
86 |
+
along the specified axis, within the specified slice and the specified orientation.
|
|
|
87 |
"""
|
88 |
+
tensor = build_cube_tensor(size).to(dtype=torch.int64)
|
89 |
length = 6 * (size**2)
|
90 |
|
91 |
# extract faces impacted by the move
|
92 |
+
indices = tensor.indices().to(dtype=torch.int64) # size = (4, length)
|
93 |
changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
94 |
extract = indices[:, changes] # size = (4, n)
|
95 |
|
|
|
99 |
rotated = rotated + offsets # size = (4, n)
|
100 |
|
101 |
# apply face rotation
|
102 |
+
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int64) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
|
103 |
|
104 |
# from this point on, convert rotation into a position-based permutation of colors
|
105 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
|
|
107 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
108 |
|
109 |
# compute position-based permutation of colors equivalent to rotation converting inputs into outputs
|
110 |
+
local_perm = {i: outputs.index(inputs[i]) for i in range(len(inputs))}
|
111 |
local_to_total = dict(enumerate(changes.tolist()))
|
112 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
113 |
|
114 |
+
# return permutation on total list of facelet positions
|
115 |
+
return [(i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
|
118 |
def parse_action_str(move: str) -> tuple[int, int, int]:
|
src/rubik/cube.py
CHANGED
@@ -2,7 +2,6 @@ from functools import reduce
|
|
2 |
from loguru import logger
|
3 |
|
4 |
import torch
|
5 |
-
import torch.nn.functional as F
|
6 |
|
7 |
from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
|
8 |
from rubik.state import build_cube_tensor
|
@@ -26,10 +25,11 @@ class Cube:
|
|
26 |
"""
|
27 |
tensor = build_cube_tensor(size)
|
28 |
|
29 |
-
self.dtype = torch.
|
30 |
-
self.coordinates = tensor.indices()
|
31 |
-
self.state =
|
32 |
-
self.actions = build_actions_tensor(size)
|
|
|
33 |
self._history: list[tuple[int, int, int]] = []
|
34 |
self._colors: list[str] = list("ULCRBD")
|
35 |
self._size: int = size
|
@@ -54,7 +54,7 @@ class Cube:
|
|
54 |
"""
|
55 |
tensor = torch.sparse_coo_tensor(
|
56 |
indices=self.coordinates,
|
57 |
-
values=self.state
|
58 |
size=(6, self.size, self.size, self.size),
|
59 |
dtype=self.dtype,
|
60 |
).to_dense()
|
@@ -72,13 +72,7 @@ class Cube:
|
|
72 |
|
73 |
def to(self, device: str | torch.device) -> "Cube":
|
74 |
device = torch.device(device)
|
75 |
-
dtype = (
|
76 |
-
self.state.dtype
|
77 |
-
if self.state.device == device
|
78 |
-
else self.dtype
|
79 |
-
if device == torch.device("cpu")
|
80 |
-
else torch.float32
|
81 |
-
)
|
82 |
self.state = self.state.to(device=device, dtype=dtype)
|
83 |
self.actions = self.actions.to(device=device, dtype=dtype)
|
84 |
logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
|
@@ -114,18 +108,17 @@ class Cube:
|
|
114 |
Apply a move (defined as 3 coordinates) to the cube.
|
115 |
"""
|
116 |
action = self.actions[axis, slice, inverse]
|
117 |
-
self.state =
|
118 |
self._history.append((axis, slice, inverse))
|
119 |
return
|
120 |
|
121 |
-
def
|
122 |
"""
|
123 |
combine a sequence of moves and return the resulting changes.
|
124 |
"""
|
125 |
actions = parse_actions_str(moves)
|
126 |
-
tensors = [self.actions[*action]
|
127 |
-
|
128 |
-
return dict(result.indices().transpose(0, 1).tolist())
|
129 |
|
130 |
def __str__(self):
|
131 |
"""
|
|
|
2 |
from loguru import logger
|
3 |
|
4 |
import torch
|
|
|
5 |
|
6 |
from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
|
7 |
from rubik.state import build_cube_tensor
|
|
|
25 |
"""
|
26 |
tensor = build_cube_tensor(size)
|
27 |
|
28 |
+
self.dtype = torch.int64
|
29 |
+
self.coordinates = tensor.indices()
|
30 |
+
self.state = tensor.values()
|
31 |
+
self.actions = build_actions_tensor(size)
|
32 |
+
# internal-only attributes
|
33 |
self._history: list[tuple[int, int, int]] = []
|
34 |
self._colors: list[str] = list("ULCRBD")
|
35 |
self._size: int = size
|
|
|
54 |
"""
|
55 |
tensor = torch.sparse_coo_tensor(
|
56 |
indices=self.coordinates,
|
57 |
+
values=self.state,
|
58 |
size=(6, self.size, self.size, self.size),
|
59 |
dtype=self.dtype,
|
60 |
).to_dense()
|
|
|
72 |
|
73 |
def to(self, device: str | torch.device) -> "Cube":
|
74 |
device = torch.device(device)
|
75 |
+
dtype = self.dtype if device == torch.device("cpu") else torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
self.state = self.state.to(device=device, dtype=dtype)
|
77 |
self.actions = self.actions.to(device=device, dtype=dtype)
|
78 |
logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
|
|
|
108 |
Apply a move (defined as 3 coordinates) to the cube.
|
109 |
"""
|
110 |
action = self.actions[axis, slice, inverse]
|
111 |
+
self.state = torch.gather(self.state, 0, action)
|
112 |
self._history.append((axis, slice, inverse))
|
113 |
return
|
114 |
|
115 |
+
def compose_moves(self, moves: str) -> torch.Tensor:
|
116 |
"""
|
117 |
combine a sequence of moves and return the resulting changes.
|
118 |
"""
|
119 |
actions = parse_actions_str(moves)
|
120 |
+
tensors = [self.actions[*action] for action in actions]
|
121 |
+
return reduce(lambda A, B: torch.gather(A, 0, B), tensors)
|
|
|
122 |
|
123 |
def __str__(self):
|
124 |
"""
|
src/rubik/interface/plot.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
class CubeVisualizer:
|
8 |
"""
|
9 |
Utility class for ploting a cube, with some layout ingredients precomputed at init.
|
|
|
10 |
"""
|
11 |
|
12 |
def __init__(self, size: int):
|
@@ -108,7 +109,7 @@ class CubeVisualizer:
|
|
108 |
Generates a 3D plot of a cube given its coordinates, state and size.
|
109 |
"""
|
110 |
# set the color of each facelet, face after face
|
111 |
-
face_state = (state
|
112 |
face_colors = [[self.colors[f] for f in face] for face in face_state]
|
113 |
|
114 |
face_coordinates = [coordinates[1:, (coordinates[0] == i)].transpose(0, 1).tolist() for i in range(6)]
|
|
|
7 |
class CubeVisualizer:
|
8 |
"""
|
9 |
Utility class for ploting a cube, with some layout ingredients precomputed at init.
|
10 |
+
Greatly inspired from https://www.kaggle.com/code/edomingo/nxn-rubik-s-cube-3d-interactive-viz-plotly/notebook.
|
11 |
"""
|
12 |
|
13 |
def __init__(self, size: int):
|
|
|
109 |
Generates a 3D plot of a cube given its coordinates, state and size.
|
110 |
"""
|
111 |
# set the color of each facelet, face after face
|
112 |
+
face_state = (state - 1).reshape(6, -1).tolist()
|
113 |
face_colors = [[self.colors[f] for f in face] for face in face_state]
|
114 |
|
115 |
face_coordinates = [coordinates[1:, (coordinates[0] == i)].transpose(0, 1).tolist() for i in range(6)]
|
src/rubik/state.py
CHANGED
@@ -9,7 +9,7 @@ def build_cube_tensor(size: int) -> torch.Tensor:
|
|
9 |
|
10 |
# build dense tensor filled with colors
|
11 |
n = size - 1
|
12 |
-
tensor = torch.zeros([6, size, size, size], dtype=torch.
|
13 |
tensor[0, :, :, n] = 1 # up
|
14 |
tensor[1, 0, :, :] = 2 # left
|
15 |
tensor[2, :, n, :] = 3 # front
|
@@ -25,6 +25,6 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
|
|
25 |
"""
|
26 |
perm_list = [int(p) for p in (perm + perm[0])]
|
27 |
perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
|
28 |
-
indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.
|
29 |
-
values = torch.tensor([1] * size, dtype=torch.
|
30 |
-
return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.
|
|
|
9 |
|
10 |
# build dense tensor filled with colors
|
11 |
n = size - 1
|
12 |
+
tensor = torch.zeros([6, size, size, size], dtype=torch.int64)
|
13 |
tensor[0, :, :, n] = 1 # up
|
14 |
tensor[1, 0, :, :] = 2 # left
|
15 |
tensor[2, :, n, :] = 3 # front
|
|
|
25 |
"""
|
26 |
perm_list = [int(p) for p in (perm + perm[0])]
|
27 |
perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
|
28 |
+
indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int64)
|
29 |
+
values = torch.tensor([1] * size, dtype=torch.int64)
|
30 |
+
return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int64).coalesce()
|
tests/unit/test_action.py
CHANGED
@@ -8,7 +8,7 @@ from rubik.action import (
|
|
8 |
POS_SHIFTS,
|
9 |
FACE_ROTATIONS,
|
10 |
build_actions_tensor,
|
11 |
-
|
12 |
parse_action_str,
|
13 |
parse_actions_str,
|
14 |
sample_actions_str,
|
@@ -96,7 +96,7 @@ def test_build_actions_tensor_shape(size: int):
|
|
96 |
"""
|
97 |
Test that "build_actions_tensor" output has expected shape.
|
98 |
"""
|
99 |
-
expected = (3, size, 2, 6 * (size**2)
|
100 |
observed = build_actions_tensor(size).shape
|
101 |
assert expected == observed, (
|
102 |
f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
|
@@ -111,14 +111,14 @@ def test_build_actions_tensor_shape(size: int):
|
|
111 |
(5, 1, 4, 0),
|
112 |
],
|
113 |
)
|
114 |
-
def
|
115 |
"""
|
116 |
Test that "build_actions_tensor" output has expected shape.
|
117 |
"""
|
118 |
-
expected =
|
119 |
-
observed =
|
120 |
assert expected == observed, (
|
121 |
-
f"'build_action_tensor' output has incorrect
|
122 |
)
|
123 |
|
124 |
|
|
|
8 |
POS_SHIFTS,
|
9 |
FACE_ROTATIONS,
|
10 |
build_actions_tensor,
|
11 |
+
build_action_permutation,
|
12 |
parse_action_str,
|
13 |
parse_actions_str,
|
14 |
sample_actions_str,
|
|
|
96 |
"""
|
97 |
Test that "build_actions_tensor" output has expected shape.
|
98 |
"""
|
99 |
+
expected = (3, size, 2, 6 * (size**2))
|
100 |
observed = build_actions_tensor(size).shape
|
101 |
assert expected == observed, (
|
102 |
f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
|
|
|
111 |
(5, 1, 4, 0),
|
112 |
],
|
113 |
)
|
114 |
+
def test_build_action_permutation(size: int, axis: int, slice: int, inverse: int):
|
115 |
"""
|
116 |
Test that "build_actions_tensor" output has expected shape.
|
117 |
"""
|
118 |
+
expected = 6 * (size**2)
|
119 |
+
observed = len(build_action_permutation(size, axis, slice, inverse))
|
120 |
assert expected == observed, (
|
121 |
+
f"'build_action_tensor' output has incorrect length: expected length '{expected}', got '{observed}'"
|
122 |
)
|
123 |
|
124 |
|
tests/unit/test_cube.py
CHANGED
@@ -16,8 +16,8 @@ class TestCube:
|
|
16 |
Test that the __init__ method produce expected attributes.
|
17 |
"""
|
18 |
cube = Cube(size)
|
19 |
-
assert cube.state.shape == (6 * (size**2),
|
20 |
-
assert cube.actions.shape == (3, size, 2, cube.state.shape[0]
|
21 |
f"'actions' has incorrect shape {cube.actions.shape}"
|
22 |
)
|
23 |
assert len(cube.history) == 0, "'history' field should be empty"
|
@@ -93,23 +93,22 @@ class TestCube:
|
|
93 |
"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i " * 2,
|
94 |
],
|
95 |
)
|
96 |
-
def
|
97 |
"""
|
98 |
-
Test that the .
|
99 |
"""
|
100 |
cube = Cube(3)
|
101 |
-
facets = cube.state.argmax(dim=-1).to(cube.dtype).tolist()
|
102 |
-
changes = cube.compute_changes(moves)
|
103 |
|
104 |
# apply changes induced by moves using the permutation dict returned by 'compute_changes'
|
105 |
-
|
|
|
106 |
|
107 |
# apply changes induced by moves using the optimized 'rotate' method
|
108 |
cube.rotate(moves)
|
109 |
-
observed = cube.state
|
110 |
|
111 |
# assert the tow are identical
|
112 |
-
assert expected
|
113 |
|
114 |
def test__str__len(self):
|
115 |
"""
|
|
|
16 |
Test that the __init__ method produce expected attributes.
|
17 |
"""
|
18 |
cube = Cube(size)
|
19 |
+
assert cube.state.shape == (6 * (size**2),), f"'state' has incorrect shape {cube.state.shape}"
|
20 |
+
assert cube.actions.shape == (3, size, 2, cube.state.shape[0]), (
|
21 |
f"'actions' has incorrect shape {cube.actions.shape}"
|
22 |
)
|
23 |
assert len(cube.history) == 0, "'history' field should be empty"
|
|
|
93 |
"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i " * 2,
|
94 |
],
|
95 |
)
|
96 |
+
def test_compose_moves(self, moves: str):
|
97 |
"""
|
98 |
+
Test that the .compose_moves method behaves as expected.
|
99 |
"""
|
100 |
cube = Cube(3)
|
|
|
|
|
101 |
|
102 |
# apply changes induced by moves using the permutation dict returned by 'compute_changes'
|
103 |
+
changes = cube.compose_moves(moves)
|
104 |
+
expected = torch.gather(cube.state.clone(), 0, changes)
|
105 |
|
106 |
# apply changes induced by moves using the optimized 'rotate' method
|
107 |
cube.rotate(moves)
|
108 |
+
observed = cube.state
|
109 |
|
110 |
# assert the tow are identical
|
111 |
+
assert torch.equal(expected, observed), "method 'compute_changes' does not behave correctly: "
|
112 |
|
113 |
def test__str__len(self):
|
114 |
"""
|