Jean-baptiste Aujogue commited on
Commit
f12b6ac
·
unverified ·
1 Parent(s): 7aeb2a0

Gradio Interface (#2)

Browse files

* add plotly and gradio

* edit readme

* simplify cube object

* simplify cube object

* edit facelect computation

* base demo with 3d plot

* remove legacy plot function

* rename demo as interface

* edit plot function

* remove cupy dependency

* all dependencies

* gradio interface with 3D plot

README.md CHANGED
@@ -14,14 +14,20 @@ pre-commit install
14
 
15
  ## Usage
16
 
17
- ### Create a cube
 
 
 
 
 
 
18
 
19
  ```python
20
  from rubik.cube import Cube
21
 
22
- cube = Cube(colors=['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
23
 
24
- # display the cube state and history of moves
25
  print(cube)
26
  # UUU
27
  # UUU
@@ -33,17 +39,14 @@ print(cube)
33
  # DDD
34
  # DDD
35
 
 
36
  print(cube.history)
37
  # []
38
- ```
39
-
40
- ### Perform basic moves
41
 
42
- ```python
43
- # shuffle the cube using 1000 random moves (random shuffling resets the history)
44
- cube.shuffle(num_moves=1000, seed=0)
45
 
46
- # rotate it in some way
47
  cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
48
  ```
49
 
@@ -66,8 +69,9 @@ cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
66
 
67
 
68
 
 
69
 
70
- ## Related projects
71
 
72
  Open-source projects related to Rubik's Cube, sorted by number of stars:
73
  - [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
@@ -77,3 +81,16 @@ Open-source projects related to Rubik's Cube, sorted by number of stars:
77
  - [trincaog/magiccube](https://github.com/trincaog/magiccube)
78
  - [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
79
  - [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  ## Usage
16
 
17
+ ### Launch the web interface
18
+
19
+ ```shell
20
+ python -m rubik interface
21
+ ```
22
+
23
+ ### Use the python API
24
 
25
  ```python
26
  from rubik.cube import Cube
27
 
28
+ cube = Cube(size=3)
29
 
30
+ # display cube state
31
  print(cube)
32
  # UUU
33
  # UUU
 
39
  # DDD
40
  # DDD
41
 
42
+ # display history of moves
43
  print(cube.history)
44
  # []
 
 
 
45
 
46
+ # scramble the cube using 1000 random moves (this resets the history)
47
+ cube.scramble(num_moves=1000, seed=0)
 
48
 
49
+ # rotate it in some way (this gets appended to history)
50
  cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
51
  ```
52
 
 
69
 
70
 
71
 
72
+ ## References
73
 
74
+ ### Python implementations & rule-based solvers
75
 
76
  Open-source projects related to Rubik's Cube, sorted by number of stars:
77
  - [adrianliaw/PyCuber](https://github.com/adrianliaw/PyCuber)
 
81
  - [trincaog/magiccube](https://github.com/trincaog/magiccube)
82
  - [charlstown/rubiks-cube-solver](https://github.com/charlstown/rubiks-cube-solver)
83
  - [staetyk/NxNxN-Cubes](https://github.com/staetyk/NxNxN-Cubes)
84
+
85
+ ### Machine Learning based solver models
86
+
87
+ - 2025, _CayleyPy Cube_, [Github](https://github.com/k1242/cayleypy-cube), [Paper](https://arxiv.org/html/2502.13266v1)
88
+
89
+ - 2025, _Solving A Rubik’s Cube with Supervised Learning – Intuitively and Exhaustively Explained_, [Blog post](https://towardsdatascience.com/solving-a-rubiks-cube-with-supervised-learning-intuitively-and-exhaustively-explained-4f87b72ba1e2/)
90
+
91
+ - 2024, _Solving Rubik's Cube Without Tricky Sampling_, [Paper](https://arxiv.org/abs/2411.19583).<br>
92
+ This involves training a scorer, that estimates the number of moves transforming a given source state into a given target state, where the latter is not necessarily a solved cube. Data are synthetically generated performing random moves and factorizing repeated identical moves.
93
+
94
+ - 2023, _Curious Transformer_, [Github](https://github.com/tedtedtedtedtedted/Solve-Rubiks-Cube-Via-Transformer)
95
+
96
+ - 2021, _Efficient Cube_, [Github](https://github.com/kyo-takano/efficientcube), [Paper](https://arxiv.org/abs/2106.03157)
notebooks/dev.ipynb CHANGED
@@ -31,7 +31,8 @@
31
  "outputs": [],
32
  "source": [
33
  "from rubik.cube import Cube\n",
34
- "from rubik.action import build_actions_tensor"
 
35
  ]
36
  },
37
  {
@@ -45,7 +46,7 @@
45
  "\n",
46
  "actions = build_actions_tensor(size)\n",
47
  "\n",
48
- "cube = Cube([\"U\", \"L\", \"C\", \"R\", \"B\", \"D\"], size=size)\n",
49
  "print(cube)"
50
  ]
51
  },
@@ -57,7 +58,8 @@
57
  "outputs": [],
58
  "source": [
59
  "cubis = copy.deepcopy(cube)\n",
60
- "cubis.shuffle(2000, seed=0)\n",
 
61
  "print(cubis)\n",
62
  "print(cubis.history)"
63
  ]
@@ -69,7 +71,10 @@
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
72
- "cubis.compute_changes(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \")"
 
 
 
73
  ]
74
  },
75
  {
@@ -80,7 +85,7 @@
80
  "outputs": [],
81
  "source": [
82
  "cubis = copy.deepcopy(cube)\n",
83
- "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \" * 1000)\n",
84
  "print(cubis)\n",
85
  "print(cubis.history)"
86
  ]
@@ -91,6 +96,19 @@
91
  "id": "7",
92
  "metadata": {},
93
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  "source": [
95
  "(actions[0, 2, 0].type(torch.float32) @ actions[0, 1, 1].type(torch.float32)).type(torch.int8)"
96
  ]
 
31
  "outputs": [],
32
  "source": [
33
  "from rubik.cube import Cube\n",
34
+ "from rubik.action import build_actions_tensor\n",
35
+ "from rubik.interface.plot import CubeVisualizer"
36
  ]
37
  },
38
  {
 
46
  "\n",
47
  "actions = build_actions_tensor(size)\n",
48
  "\n",
49
+ "cube = Cube(size)\n",
50
  "print(cube)"
51
  ]
52
  },
 
58
  "outputs": [],
59
  "source": [
60
  "cubis = copy.deepcopy(cube)\n",
61
+ "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \")\n",
62
+ "\n",
63
  "print(cubis)\n",
64
  "print(cubis.history)"
65
  ]
 
71
  "metadata": {},
72
  "outputs": [],
73
  "source": [
74
+ "visualizer = CubeVisualizer(size=cubis.size)\n",
75
+ "layout_args = {\"autosize\": False, \"width\": 600, \"height\": 600}\n",
76
+ "\n",
77
+ "visualizer(cubis.coordinates, cubis.state, cubis.size).update_layout(**layout_args).show()"
78
  ]
79
  },
80
  {
 
85
  "outputs": [],
86
  "source": [
87
  "cubis = copy.deepcopy(cube)\n",
88
+ "cubis.scramble(2000, seed=0)\n",
89
  "print(cubis)\n",
90
  "print(cubis.history)"
91
  ]
 
96
  "id": "7",
97
  "metadata": {},
98
  "outputs": [],
99
+ "source": [
100
+ "cubis = copy.deepcopy(cube)\n",
101
+ "cubis.rotate(\"X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i \" * 1000)\n",
102
+ "print(cubis)\n",
103
+ "print(cubis.history)"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "8",
110
+ "metadata": {},
111
+ "outputs": [],
112
  "source": [
113
  "(actions[0, 2, 0].type(torch.float32) @ actions[0, 1, 1].type(torch.float32)).type(torch.int8)"
114
  ]
pyproject.toml CHANGED
@@ -5,9 +5,11 @@ description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.11,<3.12"
7
  dependencies = [
8
- "cupy-cuda12x>=13.4.1",
9
  "fire>=0.7.0",
 
10
  "loguru>=0.7.3",
 
 
11
  "torch>=2.7.1",
12
  ]
13
 
 
5
  readme = "README.md"
6
  requires-python = ">=3.11,<3.12"
7
  dependencies = [
 
8
  "fire>=0.7.0",
9
+ "gradio>=5.38",
10
  "loguru>=0.7.3",
11
+ "plotly>=6.2.0",
12
+ "pydantic>=2.11.7",
13
  "torch>=2.7.1",
14
  ]
15
 
src/rubik/__main__.py CHANGED
@@ -1,4 +1,6 @@
1
- # from fire import Fire
2
 
 
3
 
4
- # Fire({"hello": hello_world})
 
 
1
+ from fire import Fire
2
 
3
+ from rubik.interface.app import app
4
 
5
+
6
+ Fire({"interface": app})
src/rubik/action.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import torch
5
  import torch.nn.functional as F
6
 
7
- from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor
8
 
9
 
10
  POS_ROTATIONS = torch.stack(
@@ -17,7 +17,7 @@ POS_ROTATIONS = torch.stack(
17
  [0, 0, 0, 1],
18
  [0, 0, -1, 0],
19
  ],
20
- dtype=torch.int16,
21
  ),
22
  # rot about Y: X -> Z
23
  torch.tensor(
@@ -27,7 +27,7 @@ POS_ROTATIONS = torch.stack(
27
  [0, 0, 1, 0],
28
  [0, 1, 0, 0],
29
  ],
30
- dtype=torch.int16,
31
  ),
32
  # rot about Z: Y -> X
33
  torch.tensor(
@@ -37,7 +37,7 @@ POS_ROTATIONS = torch.stack(
37
  [0, -1, 0, 0],
38
  [0, 0, 0, 1],
39
  ],
40
- dtype=torch.int16,
41
  ),
42
  ]
43
  )
@@ -48,7 +48,7 @@ POS_SHIFTS = torch.tensor(
48
  [0, 1, 0, 0],
49
  [0, 0, 1, 0],
50
  ],
51
- dtype=torch.int16,
52
  )
53
 
54
 
@@ -76,7 +76,7 @@ def build_actions_tensor(size: int) -> torch.Tensor:
76
  for inverse in range(2)
77
  ],
78
  dim=0,
79
- ).sum(dim=0, dtype=torch.int16)
80
 
81
 
82
  def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
@@ -85,11 +85,11 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
85
  is the rotation along the specified axis, within the specified slice and the specified
86
  orientation.
87
  """
88
- tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
89
  length = 6 * (size**2)
90
 
91
  # extract faces impacted by the move
92
- indices = tensor.indices().to(dtype=torch.int16) # size = (4, length)
93
  changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
94
  extract = indices[:, changes] # size = (4, n)
95
 
@@ -99,7 +99,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
99
  rotated = rotated + offsets # size = (4, n)
100
 
101
  # apply face rotation
102
- rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
103
 
104
  # from this point on, convert rotation into a position-based permutation of colors
105
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
@@ -110,19 +110,25 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
110
  local_to_total = dict(enumerate(changes.tolist()))
111
  total_to_local = {ind: i for i, ind in local_to_total.items()}
112
 
113
- local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
114
  total_perm = {
115
  i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
116
  }
117
 
118
  # convert permutation dict into sparse tensor
119
  perm_indices = torch.tensor(
120
- [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
121
- dtype=torch.int16,
 
 
 
 
 
 
122
  )
123
- perm_values = torch.tensor([1] * length, dtype=torch.int16)
124
  perm_size = (3, size, 2, length, length)
125
- return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
126
 
127
 
128
  def parse_action_str(move: str) -> tuple[int, int, int]:
 
4
  import torch
5
  import torch.nn.functional as F
6
 
7
+ from rubik.state import build_permutation_matrix, build_cube_tensor
8
 
9
 
10
  POS_ROTATIONS = torch.stack(
 
17
  [0, 0, 0, 1],
18
  [0, 0, -1, 0],
19
  ],
20
+ dtype=torch.int32,
21
  ),
22
  # rot about Y: X -> Z
23
  torch.tensor(
 
27
  [0, 0, 1, 0],
28
  [0, 1, 0, 0],
29
  ],
30
+ dtype=torch.int32,
31
  ),
32
  # rot about Z: Y -> X
33
  torch.tensor(
 
37
  [0, -1, 0, 0],
38
  [0, 0, 0, 1],
39
  ],
40
+ dtype=torch.int32,
41
  ),
42
  ]
43
  )
 
48
  [0, 1, 0, 0],
49
  [0, 0, 1, 0],
50
  ],
51
+ dtype=torch.int32,
52
  )
53
 
54
 
 
76
  for inverse in range(2)
77
  ],
78
  dim=0,
79
+ ).sum(dim=0, dtype=torch.int32)
80
 
81
 
82
  def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
 
85
  is the rotation along the specified axis, within the specified slice and the specified
86
  orientation.
87
  """
88
+ tensor = build_cube_tensor(size).to(dtype=torch.int32)
89
  length = 6 * (size**2)
90
 
91
  # extract faces impacted by the move
92
+ indices = tensor.indices().to(dtype=torch.int32) # size = (4, length)
93
  changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
94
  extract = indices[:, changes] # size = (4, n)
95
 
 
99
  rotated = rotated + offsets # size = (4, n)
100
 
101
  # apply face rotation
102
+ rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int32) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
103
 
104
  # from this point on, convert rotation into a position-based permutation of colors
105
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
 
110
  local_to_total = dict(enumerate(changes.tolist()))
111
  total_to_local = {ind: i for i, ind in local_to_total.items()}
112
 
113
+ local_perm = {i: outputs.index(inputs[i]) for i in range(len(inputs))}
114
  total_perm = {
115
  i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
116
  }
117
 
118
  # convert permutation dict into sparse tensor
119
  perm_indices = torch.tensor(
120
+ [
121
+ [axis] * length,
122
+ [slice] * length,
123
+ [inverse] * length,
124
+ list(total_perm.keys()),
125
+ list(total_perm.values()),
126
+ ],
127
+ dtype=torch.int32,
128
  )
129
+ perm_values = torch.tensor([1] * length, dtype=torch.int32)
130
  perm_size = (3, size, 2, length, length)
131
+ return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int32)
132
 
133
 
134
  def parse_action_str(move: str) -> tuple[int, int, int]:
src/rubik/cube.py CHANGED
@@ -5,8 +5,7 @@ import torch
5
  import torch.nn.functional as F
6
 
7
  from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
8
- from rubik.display import stringify
9
- from rubik.tensor_utils import build_cube_tensor
10
 
11
 
12
  class Cube:
@@ -21,30 +20,65 @@ class Cube:
21
  the rest according to order given in "colors" attribute.
22
  """
23
 
24
- def __init__(self, colors: list[str], size: int):
25
  """
26
- Create Cube from a given list of 6 colors and size.
27
- Example:
28
- cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
29
  """
30
- tensor = build_cube_tensor(colors, size)
31
- self.coordinates = tensor.indices().transpose(0, 1).to(torch.int16)
32
- self.state = F.one_hot(tensor.values().long(), num_classes=7).to(torch.int16)
33
- self.actions = build_actions_tensor(size)
34
- self.history: list[list[int]] = []
35
- self.colors = colors
36
- self.size = size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def to(self, device: str | torch.device) -> "Cube":
39
  device = torch.device(device)
40
  dtype = (
41
  self.state.dtype
42
  if self.state.device == device
43
- else torch.int16
44
  if device == torch.device("cpu")
45
  else torch.float32
46
  )
47
- self.coordinates = self.coordinates.to(device=device, dtype=dtype)
48
  self.state = self.state.to(device=device, dtype=dtype)
49
  self.actions = self.actions.to(device=device, dtype=dtype)
50
  logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
@@ -54,10 +88,10 @@ class Cube:
54
  """
55
  Reset internal history of moves.
56
  """
57
- self.history = []
58
  return
59
 
60
- def shuffle(self, num_moves: int, seed: int = 0) -> None:
61
  """
62
  Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
63
  """
@@ -81,7 +115,7 @@ class Cube:
81
  """
82
  action = self.actions[axis, slice, inverse]
83
  self.state = action @ self.state
84
- self.history.append([axis, slice, inverse])
85
  return
86
 
87
  def compute_changes(self, moves: str) -> dict[int, int]:
@@ -97,5 +131,9 @@ class Cube:
97
  """
98
  Compute a string representation of a cube.
99
  """
100
- state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
101
- return stringify(state, self.colors, self.size)
 
 
 
 
 
5
  import torch.nn.functional as F
6
 
7
  from rubik.action import build_actions_tensor, parse_actions_str, sample_actions_str
8
+ from rubik.state import build_cube_tensor
 
9
 
10
 
11
  class Cube:
 
20
  the rest according to order given in "colors" attribute.
21
  """
22
 
23
+ def __init__(self, size: int):
24
  """
25
+ Create Cube of a given size.
 
 
26
  """
27
+ tensor = build_cube_tensor(size)
28
+
29
+ self.dtype = torch.int8 if size <= 6 else torch.int16 if size <= 73 else torch.int32
30
+ self.coordinates = tensor.indices().to(self.dtype)
31
+ self.state = F.one_hot(tensor.values().long(), num_classes=7).to(self.dtype)
32
+ self.actions = build_actions_tensor(size).to(self.dtype)
33
+ self._history: list[tuple[int, int, int]] = []
34
+ self._colors: list[str] = list("ULCRBD")
35
+ self._size: int = size
36
+
37
+ @property
38
+ def history(self) -> list[tuple[int, int, int]]:
39
+ return self._history
40
+
41
+ @property
42
+ def colors(self) -> list[str]:
43
+ return self._colors
44
+
45
+ @property
46
+ def size(self) -> int:
47
+ return self._size
48
+
49
+ @property
50
+ def facelets(self) -> list[list[list[str]]]:
51
+ """
52
+ Return the list of faces of the cube, each given by a list of rows,
53
+ each given by a list of facelets.
54
+ """
55
+ tensor = torch.sparse_coo_tensor(
56
+ indices=self.coordinates,
57
+ values=self.state.argmax(dim=-1),
58
+ size=(6, self.size, self.size, self.size),
59
+ dtype=self.dtype,
60
+ ).to_dense()
61
+
62
+ n = self.size - 1
63
+ faces = [
64
+ tensor[0, :, :, n].transpose(0, 1), # up
65
+ tensor[1, 0, :, :].flip(1).transpose(0, 1), # left
66
+ tensor[2, :, n, :].flip(1).transpose(0, 1), # front
67
+ tensor[3, n, :, :].flip(0).flip(1).transpose(0, 1), # right
68
+ tensor[4, :, 0, :].flip(0).flip(1).transpose(0, 1), # back
69
+ tensor[5, :, :, 0].flip(1).transpose(0, 1), # down
70
+ ]
71
+ return [[[self.colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
72
 
73
  def to(self, device: str | torch.device) -> "Cube":
74
  device = torch.device(device)
75
  dtype = (
76
  self.state.dtype
77
  if self.state.device == device
78
+ else self.dtype
79
  if device == torch.device("cpu")
80
  else torch.float32
81
  )
 
82
  self.state = self.state.to(device=device, dtype=dtype)
83
  self.actions = self.actions.to(device=device, dtype=dtype)
84
  logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
 
88
  """
89
  Reset internal history of moves.
90
  """
91
+ self._history = []
92
  return
93
 
94
+ def scramble(self, num_moves: int, seed: int = 0) -> None:
95
  """
96
  Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
97
  """
 
115
  """
116
  action = self.actions[axis, slice, inverse]
117
  self.state = action @ self.state
118
+ self._history.append((axis, slice, inverse))
119
  return
120
 
121
  def compute_changes(self, moves: str) -> dict[int, int]:
 
131
  """
132
  Compute a string representation of a cube.
133
  """
134
+ space = " " * self.size
135
+ facelets = self.facelets
136
+ l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in facelets[0])
137
+ l2 = "\n".join(" ".join("".join(face[i]) for face in facelets[1:5]) for i in range(self.size))
138
+ l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in facelets[-1])
139
+ return "\n".join([l1, l2, l3])
src/rubik/display.py DELETED
@@ -1,23 +0,0 @@
1
- import torch
2
-
3
-
4
- def stringify(state: torch.Tensor, colors: list[str], size: int) -> str:
5
- """
6
- Compute a string representation of a cube.
7
- """
8
- colors = pad_colors(colors)
9
- faces = state.reshape(6, size, size).transpose(1, 2)
10
- faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
11
- space = " " * max(len(c) for c in colors) * size
12
- l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
13
- l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(size))
14
- l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
15
- return "\n".join([l1, l2, l3])
16
-
17
-
18
- def pad_colors(colors: list[str]) -> list[str]:
19
- """
20
- Pad color names to strings of equal length.
21
- """
22
- max_len = max(len(c) for c in colors)
23
- return [c + " " * (max_len - len(c)) for c in colors]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rubik/interface/__init__.py ADDED
File without changes
src/rubik/interface/app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from plotly import graph_objects as go
4
+
5
+ from rubik.cube import Cube
6
+ from rubik.interface.plot import CubeVisualizer
7
+
8
+
9
+ def app(default_size: int = 3):
10
+ """
11
+ Interface with the following features:
12
+ - create a cube of the specified size.
13
+ - ability to scramble it with a specified number of moves.
14
+ - ability to rotate it through a text field.
15
+ - display a cube upon creation or update.
16
+ """
17
+ cube = Cube(default_size)
18
+ cube_visualizer = CubeVisualizer(default_size)
19
+
20
+ def create_cube(size) -> None:
21
+ nonlocal cube
22
+ nonlocal cube_visualizer
23
+ cube = Cube(size)
24
+ cube_visualizer = CubeVisualizer(size)
25
+ return
26
+
27
+ def scramble_cube(num_moves: int) -> None:
28
+ nonlocal cube
29
+ cube.scramble(num_moves, seed=0)
30
+ return
31
+
32
+ def rotate_cube(moves: str) -> None:
33
+ nonlocal cube
34
+ cube.rotate(moves)
35
+ return
36
+
37
+ def display_cube() -> go.Figure:
38
+ nonlocal cube
39
+ nonlocal cube_visualizer
40
+ layout_args = {"autosize": False, "width": 600, "height": 600}
41
+ return cube_visualizer(cube.coordinates, cube.state, cube.size).update_layout(**layout_args)
42
+
43
+ with gr.Blocks(fill_height=True) as demo:
44
+ # structure
45
+ gr.Markdown("Rubik's Cube Interface")
46
+ with gr.Row():
47
+ with gr.Column(scale=15):
48
+ size = gr.Slider(1, 100, value=default_size, step=1, label="Select a size")
49
+ create_btn = gr.Button("Generate a Cube")
50
+
51
+ num_moves = gr.Slider(0, 10000, value=500, step=100, label="Select a number of steps for scrambling")
52
+ scramble_btn = gr.Button("Scramble the Cube")
53
+
54
+ moves = gr.Textbox(value="X0 Y1 Z0i", label="Define a sequence of moves")
55
+ rotate_btn = gr.Button("Rotate the Cube")
56
+
57
+ with gr.Column(scale=85):
58
+ plot = gr.Plot(display_cube(), container=False)
59
+
60
+ # interactions
61
+ create_btn.click(create_cube, inputs=size).success(display_cube, None, plot)
62
+ scramble_btn.click(scramble_cube, inputs=num_moves).success(display_cube, None, plot)
63
+ rotate_btn.click(rotate_cube, inputs=moves).success(display_cube, None, plot)
64
+
65
+ demo.launch()
66
+ return
src/rubik/interface/plot.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import plotly.graph_objects as go
3
+
4
+ import torch
5
+
6
+
7
+ class CubeVisualizer:
8
+ """
9
+ Utility class for ploting a cube, with some layout ingredients precomputed at init.
10
+ """
11
+
12
+ def __init__(self, size: int):
13
+ self.vertices = self.build_vertices(size)
14
+ self.shifts = self.build_shifts()
15
+ self.x_coor = [v[1] for v in self.vertices]
16
+ self.y_coor = [v[0] for v in self.vertices]
17
+ self.z_coor = [v[2] for v in self.vertices]
18
+ self.fig = self.build_base_figure(size)
19
+
20
+ @property
21
+ def colors(self):
22
+ """
23
+ Assign colors for each face.
24
+ """
25
+ return {
26
+ 0: "white",
27
+ 1: "#3588cc",
28
+ 2: "red",
29
+ 3: "green",
30
+ 4: "orange",
31
+ 5: "yellow",
32
+ }
33
+
34
+ @staticmethod
35
+ def build_vertices(size: int):
36
+ face_vertices = [
37
+ [[x, y, size] for y in range(size + 1) for x in range(size + 1)], # Up
38
+ [[0, y, z] for y in range(size + 1) for z in range(size + 1)], # Left
39
+ [[x, size, z] for x in range(size + 1) for z in range(size + 1)], # Front
40
+ [[size, y, z] for y in range(size + 1) for z in range(size + 1)], # Right
41
+ [[x, 0, z] for x in range(size + 1) for z in range(size + 1)], # Back
42
+ [[x, y, 0] for x in range(size + 1) for y in range(size + 1)], # Down
43
+ ]
44
+ return [vertex for face in face_vertices for vertex in face]
45
+
46
+ @staticmethod
47
+ def build_shifts():
48
+ return [
49
+ [(0, 1, 0), (1, 0, 0), (1, 1, 0)], # Up
50
+ [(0, 1, 0), (0, 0, 1), (0, 1, 1)], # Left
51
+ [(1, 0, 0), (0, 0, 1), (1, 0, 1)], # Front
52
+ [(0, 1, 0), (0, 0, 1), (0, 1, 1)], # Right
53
+ [(1, 0, 0), (0, 0, 1), (1, 0, 1)], # Back
54
+ [(0, 1, 0), (1, 0, 0), (1, 1, 0)], # Down
55
+ ]
56
+
57
+ @staticmethod
58
+ def build_base_figure(size):
59
+ """
60
+ Create base figure for the cube, containing everything but the cube facelets.
61
+ """
62
+ fig = go.Figure()
63
+
64
+ # add black lines to the cube
65
+ lines_seq = [[0, size, size, 0, 0], [0, 0, size, size, 0]]
66
+ lines_args = {"mode": "lines", "line": {"width": 5, "color": "black"}, "hoverinfo": "none"}
67
+ for i in range(size + 1):
68
+ fig.add_trace(go.Scatter3d(x=[i] * 5, y=lines_seq[1], z=lines_seq[0], **lines_args))
69
+ fig.add_trace(go.Scatter3d(x=lines_seq[1], y=[i] * 5, z=lines_seq[0], **lines_args))
70
+ fig.add_trace(go.Scatter3d(x=lines_seq[0], y=lines_seq[1], z=[i] * 5, **lines_args))
71
+
72
+ # add text along each axis
73
+ fig.add_trace(
74
+ go.Scatter3d(
75
+ x=[size / 2, size / 2, size + 1.5 + size * 0.5],
76
+ y=[size / 2, -1.5 - size * 0.5, size / 2],
77
+ z=[size + 1 + size * 0.5, size / 2, size / 2],
78
+ mode="text",
79
+ text=["UP", "LEFT", "FRONT"],
80
+ textposition="middle center",
81
+ textfont={"size": 20},
82
+ hoverinfo="none",
83
+ )
84
+ )
85
+
86
+ # remove legend, background, grid, ticks, etc.
87
+ scene_axis = {
88
+ "showgrid": False,
89
+ "zeroline": False,
90
+ "showticklabels": False,
91
+ "showbackground": False,
92
+ "title_text": "",
93
+ "showspikes": False,
94
+ }
95
+ return fig.update_layout(
96
+ showlegend=False,
97
+ autosize=True,
98
+ scene={
99
+ "xaxis": scene_axis,
100
+ "yaxis": scene_axis,
101
+ "zaxis": scene_axis,
102
+ "camera": {"eye": {"x": 0.8, "y": -1.2, "z": 0.7}},
103
+ },
104
+ )
105
+
106
+ def __call__(self, coordinates: torch.Tensor, state: torch.Tensor, size: int) -> go.Figure:
107
+ """
108
+ Generates a 3D plot of a cube given its coordinates, state and size.
109
+ """
110
+ # set the color of each facelet, face after face
111
+ face_state = (state.argmax(dim=-1) - 1).reshape(6, -1).tolist()
112
+ face_colors = [[self.colors[f] for f in face] for face in face_state]
113
+
114
+ face_coordinates = [coordinates[1:, (coordinates[0] == i)].transpose(0, 1).tolist() for i in range(6)]
115
+
116
+ # for each facelet of a face, draw 2 complementary triangles covering it
117
+ i_coor = []
118
+ j_coor = []
119
+ k_coor = []
120
+ facecolor = []
121
+ for i in range(6):
122
+ face = face_coordinates[i]
123
+ if i == 0:
124
+ face = [[x, y, size] for x, y, z in face]
125
+ if i == 2:
126
+ face = [[x, size, z] for x, y, z in face]
127
+ if i == 3:
128
+ face = [[size, y, z] for x, y, z in face]
129
+
130
+ # add first triangle of each facelet
131
+ i_coor += [self.vertices.index(p) for p in face]
132
+ j_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][0])]) for p in face]
133
+ k_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][1])]) for p in face]
134
+
135
+ # add second triangle of each facelet
136
+ i_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][2])]) for p in face]
137
+ j_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][0])]) for p in face]
138
+ k_coor += [self.vertices.index([c + s for c, s in zip(p, self.shifts[i][1])]) for p in face]
139
+
140
+ facecolor += face_colors[i] * 2
141
+
142
+ fig = copy.deepcopy(self.fig)
143
+ return fig.add_trace(
144
+ go.Mesh3d(
145
+ x=self.x_coor,
146
+ y=self.y_coor,
147
+ z=self.z_coor,
148
+ i=i_coor,
149
+ j=j_coor,
150
+ k=k_coor,
151
+ facecolor=facecolor,
152
+ opacity=1,
153
+ hoverinfo="none",
154
+ )
155
+ )
src/rubik/{tensor_utils.py → state.py} RENAMED
@@ -1,16 +1,15 @@
1
  import torch
2
 
3
 
4
- def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
5
  """
6
  Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
7
  """
8
- assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
9
  assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
10
 
11
  # build dense tensor filled with colors
12
  n = size - 1
13
- tensor = torch.zeros([6, size, size, size], dtype=torch.int16)
14
  tensor[0, :, :, n] = 1 # up
15
  tensor[1, 0, :, :] = 2 # left
16
  tensor[2, :, n, :] = 3 # front
@@ -26,6 +25,6 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
26
  """
27
  perm_list = [int(p) for p in (perm + perm[0])]
28
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
- indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
30
- values = torch.tensor([1] * size, dtype=torch.int16)
31
- return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16).coalesce()
 
1
  import torch
2
 
3
 
4
+ def build_cube_tensor(size: int) -> torch.Tensor:
5
  """
6
  Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
7
  """
 
8
  assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
9
 
10
  # build dense tensor filled with colors
11
  n = size - 1
12
+ tensor = torch.zeros([6, size, size, size], dtype=torch.int32)
13
  tensor[0, :, :, n] = 1 # up
14
  tensor[1, 0, :, :] = 2 # left
15
  tensor[2, :, n, :] = 3 # front
 
25
  """
26
  perm_list = [int(p) for p in (perm + perm[0])]
27
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
28
+ indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int32)
29
+ values = torch.tensor([1] * size, dtype=torch.int32)
30
+ return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int32).coalesce()
tests/unit/test_cube.py CHANGED
@@ -10,23 +10,16 @@ class TestCube:
10
  A testing class for the Cube class.
11
  """
12
 
13
- @pytest.mark.parametrize(
14
- "colors, size",
15
- [
16
- [["U", "L", "C", "R", "B", "D"], 3],
17
- [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
18
- [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
19
- ],
20
- )
21
- def test__init__(self, colors: list[str], size: int):
22
  """
23
  Test that the __init__ method produce expected attributes.
24
  """
25
- cube = Cube(colors, size)
26
- assert cube.coordinates.shape == (6 * (size**2), 4), (
27
- f"'coordinates' has incorrect shape {cube.coordinates.shape}"
28
- )
29
  assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
 
 
 
30
  assert len(cube.history) == 0, "'history' field should be empty"
31
 
32
  @pytest.mark.parametrize("device", ["cpu"])
@@ -34,7 +27,7 @@ class TestCube:
34
  """
35
  Test that the .to method behaves as expected.
36
  """
37
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
38
  cube_2 = cube.to(device)
39
  assert torch.equal(cube.state, cube_2.state), "cube has different state after calling 'to' method"
40
 
@@ -42,7 +35,7 @@ class TestCube:
42
  """
43
  Test that the .reset_history method behaves as expected.
44
  """
45
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
46
  cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
47
  cube.reset_history()
48
  assert cube.history == [], "method 'reset_history' does not flush content"
@@ -52,9 +45,9 @@ class TestCube:
52
  """
53
  Test that the .shuffle method behaves as expected.
54
  """
55
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
56
  cube_state = cube.state.clone()
57
- cube.shuffle(num_moves, seed)
58
  assert cube.history == [], "method 'shuffle' does not flush content"
59
  assert not torch.equal(cube_state, cube.state), "method 'shuffle' does not change state"
60
 
@@ -69,7 +62,7 @@ class TestCube:
69
  """
70
  Test that the .rotate method behaves as expected.
71
  """
72
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
73
  cube_state = cube.state.clone()
74
  cube.rotate(moves)
75
  assert cube.history != [], "method 'rotate' does not update history"
@@ -87,10 +80,10 @@ class TestCube:
87
  """
88
  Test that the .rotate_once method behaves as expected.
89
  """
90
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
91
  cube_state = cube.state.clone()
92
  cube.rotate_once(axis, slice, inverse)
93
- assert cube.history == [[axis, slice, inverse]], "method 'rotate_once' does not update history"
94
  assert not torch.equal(cube_state, cube.state), "method 'rotate_once' does not change state"
95
 
96
  @pytest.mark.parametrize(
@@ -104,8 +97,8 @@ class TestCube:
104
  """
105
  Test that the .compute_changes method behaves as expected.
106
  """
107
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
108
- facets = cube.state.argmax(dim=-1).to(torch.int16).tolist()
109
  changes = cube.compute_changes(moves)
110
 
111
  # apply changes induced by moves using the permutation dict returned by 'compute_changes'
@@ -113,15 +106,25 @@ class TestCube:
113
 
114
  # apply changes induced by moves using the optimized 'rotate' method
115
  cube.rotate(moves)
116
- observed = cube.state.argmax(dim=-1).to(torch.int16).tolist()
117
 
118
  # assert the tow are identical
119
  assert expected == observed, "method 'compute_changes' does not behave correctly: "
120
 
121
- def test__str__(self):
122
  """
123
  Test that the __str__ method behaves as expected.
124
  """
125
- cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
126
  repr = str(cube)
127
  assert len(repr), "__str__ method returns an empty representation"
 
 
 
 
 
 
 
 
 
 
 
10
  A testing class for the Cube class.
11
  """
12
 
13
+ @pytest.mark.parametrize("size", [3, 5, 10, 25])
14
+ def test__init__(self, size: int):
 
 
 
 
 
 
 
15
  """
16
  Test that the __init__ method produce expected attributes.
17
  """
18
+ cube = Cube(size)
 
 
 
19
  assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
20
+ assert cube.actions.shape == (3, size, 2, cube.state.shape[0], cube.state.shape[0]), (
21
+ f"'actions' has incorrect shape {cube.actions.shape}"
22
+ )
23
  assert len(cube.history) == 0, "'history' field should be empty"
24
 
25
  @pytest.mark.parametrize("device", ["cpu"])
 
27
  """
28
  Test that the .to method behaves as expected.
29
  """
30
+ cube = Cube(3)
31
  cube_2 = cube.to(device)
32
  assert torch.equal(cube.state, cube_2.state), "cube has different state after calling 'to' method"
33
 
 
35
  """
36
  Test that the .reset_history method behaves as expected.
37
  """
38
+ cube = Cube(3)
39
  cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
40
  cube.reset_history()
41
  assert cube.history == [], "method 'reset_history' does not flush content"
 
45
  """
46
  Test that the .shuffle method behaves as expected.
47
  """
48
+ cube = Cube(3)
49
  cube_state = cube.state.clone()
50
+ cube.scramble(num_moves, seed)
51
  assert cube.history == [], "method 'shuffle' does not flush content"
52
  assert not torch.equal(cube_state, cube.state), "method 'shuffle' does not change state"
53
 
 
62
  """
63
  Test that the .rotate method behaves as expected.
64
  """
65
+ cube = Cube(3)
66
  cube_state = cube.state.clone()
67
  cube.rotate(moves)
68
  assert cube.history != [], "method 'rotate' does not update history"
 
80
  """
81
  Test that the .rotate_once method behaves as expected.
82
  """
83
+ cube = Cube(3)
84
  cube_state = cube.state.clone()
85
  cube.rotate_once(axis, slice, inverse)
86
+ assert cube.history == [(axis, slice, inverse)], "method 'rotate_once' does not update history"
87
  assert not torch.equal(cube_state, cube.state), "method 'rotate_once' does not change state"
88
 
89
  @pytest.mark.parametrize(
 
97
  """
98
  Test that the .compute_changes method behaves as expected.
99
  """
100
+ cube = Cube(3)
101
+ facets = cube.state.argmax(dim=-1).to(cube.dtype).tolist()
102
  changes = cube.compute_changes(moves)
103
 
104
  # apply changes induced by moves using the permutation dict returned by 'compute_changes'
 
106
 
107
  # apply changes induced by moves using the optimized 'rotate' method
108
  cube.rotate(moves)
109
+ observed = cube.state.argmax(dim=-1).to(cube.dtype).tolist()
110
 
111
  # assert the tow are identical
112
  assert expected == observed, "method 'compute_changes' does not behave correctly: "
113
 
114
+ def test__str__len(self):
115
  """
116
  Test that the __str__ method behaves as expected.
117
  """
118
+ cube = Cube(3)
119
  repr = str(cube)
120
  assert len(repr), "__str__ method returns an empty representation"
121
+
122
+ @pytest.mark.parametrize("size", [3, 5, 8, 10])
123
+ def test__str__content(self, size: int):
124
+ """
125
+ Test that stringify behaves as expected.
126
+ """
127
+ cube = Cube(size=size)
128
+ repr = str(cube)
129
+ lens = {len(line) for line in repr.split("\n")}
130
+ assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
tests/unit/test_display.py DELETED
@@ -1,42 +0,0 @@
1
- import pytest
2
-
3
- import torch
4
-
5
- from rubik.cube import Cube
6
- from rubik.display import stringify, pad_colors
7
-
8
-
9
- @pytest.mark.parametrize(
10
- "colors, size",
11
- [
12
- [["U", "L", "C", "R", "B", "D"], 3],
13
- [["Up", "Left", "Center", "Right", "Back", "Down"], 5],
14
- [["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"], 10],
15
- ],
16
- )
17
- def test_stringify(colors: list[str], size: int):
18
- """
19
- Test that stringify behaves as expected.
20
- """
21
- cube = Cube(colors=colors, size=size)
22
- state = cube.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
23
- repr = stringify(state, colors, size)
24
- lens = {len(line) for line in repr.split("\n")}
25
- assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
26
-
27
-
28
- @pytest.mark.parametrize(
29
- "colors",
30
- [
31
- ["U", "L", "C", "R", "B", "D"],
32
- ["Up", "Left", "Center", "Right", "Back", "Down"],
33
- ["A", "BB", "CCC", "DDDD", "EEEEE", "FFFFFF"],
34
- ],
35
- )
36
- def test_pad_colors(colors: list[str]):
37
- """
38
- Test that pad_colors behaves as expected.
39
- """
40
- padded = pad_colors(colors)
41
- lengths = {len(color) for color in padded}
42
- assert len(lengths) == 1, f"'pad_colors' generates non-unique lengths: {lengths}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/unit/{test_tensor_utils.py → test_state.py} RENAMED
@@ -2,7 +2,7 @@ import pytest
2
 
3
  import torch
4
 
5
- from rubik.tensor_utils import build_cube_tensor, build_permutation_matrix
6
 
7
 
8
  @pytest.mark.parametrize("size", [2, 3, 5, 20])
@@ -10,7 +10,7 @@ def test_build_cube_tensor(size: int):
10
  """
11
  Test that build_cube_tensor behaves as expected.
12
  """
13
- tensor = build_cube_tensor(colors=["U", "L", "C", "R", "B", "D"], size=size)
14
  facets = tensor.to_dense().to(dtype=torch.int8) != 0
15
  x_sums = facets.sum(dim=(0, 2, 3)).tolist()
16
  y_sums = facets.sum(dim=(0, 1, 3)).tolist()
 
2
 
3
  import torch
4
 
5
+ from rubik.state import build_cube_tensor, build_permutation_matrix
6
 
7
 
8
  @pytest.mark.parametrize("size", [2, 3, 5, 20])
 
10
  """
11
  Test that build_cube_tensor behaves as expected.
12
  """
13
+ tensor = build_cube_tensor(size)
14
  facets = tensor.to_dense().to(dtype=torch.int8) != 0
15
  x_sums = facets.sum(dim=(0, 2, 3)).tolist()
16
  y_sums = facets.sum(dim=(0, 1, 3)).tolist()
uv.lock CHANGED
The diff for this file is too large to render. See raw diff