JBAujogue commited on
Commit
580f78c
·
1 Parent(s): ea09e61

interface for shuffling and rotating a cube

Browse files
README.md CHANGED
@@ -12,7 +12,9 @@ uv sync
12
  pre-commit install
13
  ```
14
 
15
- ## Basic usage
 
 
16
 
17
  ```python
18
  from rubik.cube import Cube
@@ -30,6 +32,20 @@ print(cube)
30
  # DDD
31
  ```
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ## Roadmap
34
 
35
  #### Fully tensorized Rubik Cube model
 
12
  pre-commit install
13
  ```
14
 
15
+ ## Usage
16
+
17
+ ### Create a cube
18
 
19
  ```python
20
  from rubik.cube import Cube
 
32
  # DDD
33
  ```
34
 
35
+ ### Perform basic moves
36
+
37
+ ```python
38
+ # shuffle the cube using 1000 random moves
39
+ cube.shuffle(num_moves=1000, seed=0)
40
+ print(cube)
41
+ print(cube.history)
42
+
43
+ # rotate it in some way
44
+ cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
45
+ print(cube)
46
+ print(cube.history)
47
+ ```
48
+
49
  ## Roadmap
50
 
51
  #### Fully tensorized Rubik Cube model
src/rubik/{moves.py → action.py} RENAMED
@@ -1,11 +1,10 @@
 
1
  import torch
2
  import torch.nn.functional as F
3
 
4
- from rubik.cube import Cube
5
 
6
 
7
- INT8 = torch.int8
8
-
9
  POS_ROTATIONS = torch.stack(
10
  [
11
  # rot about X: Z -> Y
@@ -16,7 +15,7 @@ POS_ROTATIONS = torch.stack(
16
  [0, 0, 0, 1],
17
  [0, 0, -1, 0],
18
  ],
19
- dtype=INT8,
20
  ),
21
  # rot about Y: X -> Z
22
  torch.tensor(
@@ -26,7 +25,7 @@ POS_ROTATIONS = torch.stack(
26
  [0, 0, 1, 0],
27
  [0, 1, 0, 0],
28
  ],
29
- dtype=INT8,
30
  ),
31
  # rot about Z: Y -> X
32
  torch.tensor(
@@ -36,7 +35,7 @@ POS_ROTATIONS = torch.stack(
36
  [0, -1, 0, 0],
37
  [0, 0, 0, 1],
38
  ],
39
- dtype=INT8,
40
  ),
41
  ]
42
  )
@@ -47,49 +46,20 @@ POS_SHIFTS = torch.tensor(
47
  [0, 1, 0, 0],
48
  [0, 0, 1, 0],
49
  ],
50
- dtype=INT8,
51
  )
52
 
53
- FACE_PERMS = torch.stack(
 
 
 
 
54
  [
55
- # rotation about X axis: Up -> Front -> Down -> Back -> Up
56
- torch.tensor(
57
- [
58
- [0, 0, 0, 0, 1, 0],
59
- [0, 1, 0, 0, 0, 0],
60
- [1, 0, 0, 0, 0, 0],
61
- [0, 0, 0, 1, 0, 0],
62
- [0, 0, 0, 0, 0, 1],
63
- [0, 0, 1, 0, 0, 0],
64
- ],
65
- dtype=INT8,
66
- ),
67
- # rotation about Y axis: Up -> Left -> Down -> Right -> Up
68
- torch.tensor(
69
- [
70
- [0, 0, 0, 1, 0, 0],
71
- [1, 0, 0, 0, 0, 0],
72
- [0, 0, 1, 0, 0, 0],
73
- [0, 0, 0, 0, 0, 1],
74
- [0, 0, 0, 0, 1, 0],
75
- [0, 1, 0, 0, 0, 0],
76
- ],
77
- dtype=INT8,
78
- ),
79
- # rotation about Z axis: Left -> Front -> Right -> Back -> Left
80
- torch.tensor(
81
- [
82
- [1, 0, 0, 0, 0, 0],
83
- [0, 0, 0, 0, 1, 0],
84
- [0, 1, 0, 0, 0, 0],
85
- [0, 0, 1, 0, 0, 0],
86
- [0, 0, 0, 1, 0, 0],
87
- [0, 0, 0, 0, 0, 1],
88
- ],
89
- dtype=INT8,
90
- ),
91
  ]
92
- ).transpose(1, 2)
93
 
94
 
95
  def build_actions_tensor(size: int) -> torch.Tensor:
@@ -98,29 +68,28 @@ def build_actions_tensor(size: int) -> torch.Tensor:
98
  """
99
  return torch.stack(
100
  [
101
- build_permunation_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
102
  for axis in range(3)
103
  for slice in range(size)
104
  for inverse in range(2)
105
  ],
106
  dim=0,
107
- ).sum(dim=0, dtype=INT8)
108
 
109
 
110
- def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
111
  """
112
  Compute the sparse permutation tensor whose effect on a position-frozen color vector
113
  is the rotation along the specified axis, within the specified slice and the specified
114
  orientation.
115
  """
116
- cube = Cube.create(["U", "L", "C", "R", "B", "D"], size=size)
117
  length = 6 * (size**2)
118
 
119
  # extract faces impacted by the move
120
- coordinates: torch.Tensor = cube.coordinates # size = (length, 4)
121
- transposed = coordinates.transpose(0, 1) # size = (4, length)
122
- indices = (transposed[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
123
- extract = transposed[:, indices] # size = (4, n)
124
 
125
  # apply coordinate rotation
126
  rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
@@ -128,7 +97,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
128
  rotated = rotated + offsets # size = (4, n)
129
 
130
  # apply face rotation
131
- rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(INT8) @ FACE_PERMS[axis]).argmax(dim=-1)
132
 
133
  # from this point on, convert rotation into a position-based permutation of colors
134
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
@@ -136,7 +105,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
136
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
 
138
  # compute position-based permutation of colors equivalent to rotation converting inputs into outputs
139
- local_to_total = dict(enumerate(indices.tolist()))
140
  total_to_local = {ind: i for i, ind in local_to_total.items()}
141
 
142
  local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
@@ -147,8 +116,29 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
147
  # convert permutation dict into sparse tensor
148
  perm_indices = torch.tensor(
149
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
150
- dtype=INT8,
151
  )
152
- perm_values = torch.tensor([1] * length, dtype=INT8)
153
  perm_size = (3, size, 2, length, length)
154
- return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
 
5
+ from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor
6
 
7
 
 
 
8
  POS_ROTATIONS = torch.stack(
9
  [
10
  # rot about X: Z -> Y
 
15
  [0, 0, 0, 1],
16
  [0, 0, -1, 0],
17
  ],
18
+ dtype=torch.int8,
19
  ),
20
  # rot about Y: X -> Z
21
  torch.tensor(
 
25
  [0, 0, 1, 0],
26
  [0, 1, 0, 0],
27
  ],
28
+ dtype=torch.int8,
29
  ),
30
  # rot about Z: Y -> X
31
  torch.tensor(
 
35
  [0, -1, 0, 0],
36
  [0, 0, 0, 1],
37
  ],
38
+ dtype=torch.int8,
39
  ),
40
  ]
41
  )
 
46
  [0, 1, 0, 0],
47
  [0, 0, 1, 0],
48
  ],
49
+ dtype=torch.int8,
50
  )
51
 
52
+
53
+ # rotation about X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
54
+ # rotation about Y axis: 0 (Up) -> 1 (Left) -> 5 (Down) -> 3 (Right) -> 0 (Up)
55
+ # rotation about Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
56
+ FACE_ROTATIONS = torch.stack(
57
  [
58
+ build_permutation_matrix(size=6, perm="0254"),
59
+ build_permutation_matrix(size=6, perm="0153"),
60
+ build_permutation_matrix(size=6, perm="1234"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ]
62
+ )
63
 
64
 
65
  def build_actions_tensor(size: int) -> torch.Tensor:
 
68
  """
69
  return torch.stack(
70
  [
71
+ build_action_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
72
  for axis in range(3)
73
  for slice in range(size)
74
  for inverse in range(2)
75
  ],
76
  dim=0,
77
+ ).sum(dim=0, dtype=torch.int8)
78
 
79
 
80
+ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
81
  """
82
  Compute the sparse permutation tensor whose effect on a position-frozen color vector
83
  is the rotation along the specified axis, within the specified slice and the specified
84
  orientation.
85
  """
86
+ tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
87
  length = 6 * (size**2)
88
 
89
  # extract faces impacted by the move
90
+ indices = tensor.indices().to(dtype=torch.int8) # size = (4, length)
91
+ changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
92
+ extract = indices[:, changes] # size = (4, n)
 
93
 
94
  # apply coordinate rotation
95
  rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
 
97
  rotated = rotated + offsets # size = (4, n)
98
 
99
  # apply face rotation
100
+ rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int8) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
101
 
102
  # from this point on, convert rotation into a position-based permutation of colors
103
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
 
105
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
106
 
107
  # compute position-based permutation of colors equivalent to rotation converting inputs into outputs
108
+ local_to_total = dict(enumerate(changes.tolist()))
109
  total_to_local = {ind: i for i, ind in local_to_total.items()}
110
 
111
  local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
 
116
  # convert permutation dict into sparse tensor
117
  perm_indices = torch.tensor(
118
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
119
+ dtype=torch.int8,
120
  )
121
+ perm_values = torch.tensor([1] * length, dtype=torch.int8)
122
  perm_size = (3, size, 2, length, length)
123
+ return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int8)
124
+
125
+
126
+ def parse_action_str(name: str) -> tuple[int, ...]:
127
+ """
128
+ Convert the name of an action into a triple (axis, slice, inverse).
129
+ Examples:
130
+ 'X1' -> (0, 1, 0)
131
+ 'X2i' -> (0, 2, 1)
132
+ """
133
+ return ("XYZ".index(name[0]), int(name[1]), int(len(name) >= 3))
134
+
135
+
136
+ def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
137
+ """
138
+ Generate a string containing moves that are randomly sampled.
139
+ """
140
+ rng = np.random.default_rng(seed=seed)
141
+ axes = rng.choice(["X", "Y", "Z"], size=num_moves)
142
+ slices = rng.choice([str(i) for i in range(size)], size=num_moves)
143
+ orients = rng.choice(["", "i"], size=num_moves)
144
+ return " ".join("".join(move) for move in zip(axes, slices, orients))
src/rubik/cube.py CHANGED
@@ -2,6 +2,10 @@ from dataclasses import dataclass
2
 
3
  import torch
4
 
 
 
 
 
5
 
6
  @dataclass
7
  class Cube:
@@ -18,6 +22,8 @@ class Cube:
18
 
19
  coordinates: torch.Tensor
20
  state: torch.Tensor
 
 
21
  colors: list[str]
22
  size: int
23
 
@@ -28,66 +34,55 @@ class Cube:
28
  Example:
29
  cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
30
  """
31
- assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
32
- assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
33
-
34
- # build dense tensor filled with colors
35
- n = size - 1
36
- tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
37
- tensor[0, :, :, n] = 1 # up
38
- tensor[1, 0, :, :] = 2 # left
39
- tensor[2, :, n, :] = 3 # front
40
- tensor[3, n, :, :] = 4 # right
41
- tensor[4, :, 0, :] = 5 # back
42
- tensor[5, :, :, 0] = 6 # down
43
- return cls.from_sparse(tensor.to_sparse(), colors, size)
44
-
45
- def shuffle(self, num_moves: int):
46
- raise NotImplementedError
47
 
48
- def rotate(self, moves: str):
49
- raise NotImplementedError
 
 
 
 
50
 
51
- def solve(slef, policy: str):
52
- raise NotImplementedError
 
 
 
 
 
 
53
 
54
- @staticmethod
55
- def pad_colors(colors: list[str]) -> list[str]:
56
  """
57
- Pad color names to strings of equal length.
58
  """
59
- max_len = max(len(c) for c in colors)
60
- return [c + " " * (max_len - len(c)) for c in colors]
 
 
61
 
62
- @classmethod
63
- def from_sparse(cls, tensor: torch.Tensor, colors: list[str], size: int) -> "Cube":
64
  """
65
- Gather cube attributes into a torch sparse tensor.
66
  """
67
- coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
68
- values = tensor.values()
69
- return cls(coordinates, values, colors, size)
 
70
 
71
- def to_sparse(self) -> torch.Tensor:
72
  """
73
- Gather cube attributes into a torch sparse tensor.
74
  """
75
- return torch.sparse_coo_tensor(
76
- indices=self.coordinates.transpose(0, 1),
77
- values=self.state,
78
- size=(6, self.size, self.size, self.size),
79
- dtype=torch.int8,
80
- )
81
 
82
  def __str__(self):
83
  """
84
  Compute a string representation of a cube.
85
  """
86
- colors = self.pad_colors(self.colors)
87
- faces = self.state.reshape(6, self.size, self.size).transpose(1, 2)
88
- faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
89
- void = " " * max(len(c) for c in self.colors) * self.size
90
- l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in faces[0])
91
- l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(self.size))
92
- l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in faces[-1])
93
- return "\n".join([l1, l2, l3])
 
2
 
3
  import torch
4
 
5
+ from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
6
+ from rubik.display import stringify
7
+ from rubik.tensor_utils import build_cube_tensor
8
+
9
 
10
  @dataclass
11
  class Cube:
 
22
 
23
  coordinates: torch.Tensor
24
  state: torch.Tensor
25
+ actions: torch.Tensor
26
+ history: list[list[int]]
27
  colors: list[str]
28
  size: int
29
 
 
34
  Example:
35
  cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
36
  """
37
+ tensor = build_cube_tensor(colors, size)
38
+ coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
39
+ values = tensor.values()
40
+ actions = build_actions_tensor(size)
41
+ history: list[list[int]] = []
42
+ return cls(coordinates, values, actions, history, colors, size)
 
 
 
 
 
 
 
 
 
 
43
 
44
+ def reset_history(self) -> None:
45
+ """
46
+ Reset internal history of moves.
47
+ """
48
+ self.history = []
49
+ return
50
 
51
+ def shuffle(self, num_moves: int, seed: int = 0) -> None:
52
+ """
53
+ Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
54
+ """
55
+ moves = sample_actions_str(num_moves, self.size, seed=seed)
56
+ self.rotate(moves)
57
+ self.reset_history()
58
+ return
59
 
60
+ def rotate(self, moves: str) -> None:
 
61
  """
62
+ Apply a sequence of moves (defined as plain string) to the cube.
63
  """
64
+ actions = [parse_action_str(move) for move in moves.strip().split()]
65
+ for action in actions:
66
+ self.rotate_once(*action)
67
+ return
68
 
69
+ def rotate_once(self, axis: int, slice: int, inverse: int) -> None:
 
70
  """
71
+ Apply a move (defined as 3 coordinates) to the cube.
72
  """
73
+ action = self.actions[axis, slice, inverse]
74
+ self.state = action @ self.state
75
+ self.history.append([axis, slice, inverse])
76
+ return
77
 
78
+ def solve(self, policy: str) -> None:
79
  """
80
+ Apply the specified solving policy to the cube.
81
  """
82
+ raise NotImplementedError
 
 
 
 
 
83
 
84
  def __str__(self):
85
  """
86
  Compute a string representation of a cube.
87
  """
88
+ return stringify(self)
 
 
 
 
 
 
 
src/rubik/display.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def stringify(cube) -> str:
2
+ """
3
+ Compute a string representation of a cube.
4
+ """
5
+ colors = pad_colors(cube.colors)
6
+ faces = cube.state.reshape(6, cube.size, cube.size).transpose(1, 2)
7
+ faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
8
+ space = " " * max(len(c) for c in cube.colors) * cube.size
9
+ l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
10
+ l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(cube.size))
11
+ l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
12
+ return "\n".join([l1, l2, l3])
13
+
14
+
15
+ def pad_colors(colors: list[str]) -> list[str]:
16
+ """
17
+ Pad color names to strings of equal length.
18
+ """
19
+ max_len = max(len(c) for c in colors)
20
+ return [c + " " * (max_len - len(c)) for c in colors]
src/rubik/tensor_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
5
+ """
6
+ Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
7
+ """
8
+ assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
9
+ assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
10
+
11
+ # build dense tensor filled with colors
12
+ n = size - 1
13
+ tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
14
+ tensor[0, :, :, n] = 1 # up
15
+ tensor[1, 0, :, :] = 2 # left
16
+ tensor[2, :, n, :] = 3 # front
17
+ tensor[3, n, :, :] = 4 # right
18
+ tensor[4, :, 0, :] = 5 # back
19
+ tensor[5, :, :, 0] = 6 # down
20
+ return tensor.to_sparse()
21
+
22
+
23
+ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
24
+ """
25
+ Convert a permutation sting into a sparse 2D matrix.
26
+ """
27
+ perm_list = [int(p) for p in (perm + perm[0])]
28
+ perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
+ indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int8)
30
+ values = torch.tensor([1] * size, dtype=torch.int8)
31
+ return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int8)