JBAujogue commited on
Commit
569ff0b
·
1 Parent(s): 3ab9d2b

add builder for tensor of actions

Browse files
README.md CHANGED
@@ -17,7 +17,7 @@ pre-commit install
17
  ```python
18
  from rubik.cube import Cube
19
 
20
- cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
21
  print(cube)
22
  # UUU
23
  # UUU
 
17
  ```python
18
  from rubik.cube import Cube
19
 
20
+ cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
21
  print(cube)
22
  # UUU
23
  # UUU
src/rubik/__main__.py CHANGED
@@ -1,6 +1,4 @@
1
- from fire import Fire
2
 
3
- from rubik.hello import hello_world
4
 
5
-
6
- Fire({"hello": hello_world})
 
1
+ # from fire import Fire
2
 
 
3
 
4
+ # Fire({"hello": hello_world})
 
src/rubik/cube.py CHANGED
@@ -6,38 +6,50 @@ import torch
6
  @dataclass
7
  class Cube:
8
  """
9
- A 5D tensor filled with 0 or 1. Dimensions have the following interpretation:
 
10
  - X coordinate (from 0 to self.size - 1, from Left to Right).
11
  - Y coordinate (from 0 to self.size - 1, from Back to Front).
12
  - Z coordinate (from 0 to self.size - 1, from Down to Up).
13
- - Face (from 0 to 5, with 0 = "Up", 1 = "Left", 2 = "Front", 3 = "Right", 4 = "Back", 5 = "Down").
14
- - Color (from 0 to 6, 0 being the "dark" color, the rest according to order given in "colors" attribute).
 
15
  """
16
 
 
17
  state: torch.Tensor
18
  colors: list[str]
19
  size: int
20
 
21
  @classmethod
22
- def from_default(cls, colors: list[str], size: int) -> "Cube":
23
  """
24
  Create Cube from a given list of 6 colors and size.
25
  Example:
26
- cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
27
  """
28
  assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
29
  assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
30
 
31
- # build tensor filled with 0's, and fill the faces with 1's
32
  n = size - 1
33
- state = torch.zeros([size, size, size, 6, 7], dtype=torch.int8)
34
- state[:, :, n, 0, 1] = 1 # up
35
- state[0, :, :, 1, 2] = 1 # left
36
- state[:, 0, :, 2, 3] = 1 # front
37
- state[n, :, :, 3, 4] = 1 # right
38
- state[:, n, :, 4, 5] = 1 # back
39
- state[:, :, 0, 5, 6] = 1 # down
40
- return cls(state, colors, size)
 
 
 
 
 
 
 
 
 
41
 
42
  @staticmethod
43
  def pad_colors(colors: list[str]) -> list[str]:
@@ -47,41 +59,35 @@ class Cube:
47
  max_len = max(len(c) for c in colors)
48
  return [c + " " * (max_len - len(c)) for c in colors]
49
 
50
- def to_grid(self, pad_colors: bool = False) -> list[list[list[str]]]:
 
51
  """
52
- Convert Cube into a 3D grid representation.
53
  """
54
- n = self.size - 1
55
- colors = self.pad_colors(self.colors) if pad_colors else self.colors
56
- grid = [
57
- self.state[:, :, n, 0, :].argmax(dim=-1), # up
58
- self.state[0, :, :, 1, :].argmax(dim=-1), # left
59
- self.state[:, 0, :, 2, :].argmax(dim=-1), # front
60
- self.state[n, :, :, 3, :].argmax(dim=-1), # right
61
- self.state[:, n, :, 4, :].argmax(dim=-1), # back
62
- self.state[:, :, 0, 5, :].argmax(dim=-1), # down
63
- ]
64
- return [[[colors[i - 1] for i in row] for row in face.tolist()] for face in grid]
 
 
 
65
 
66
  def __str__(self):
67
  """
68
  Compute a string representation of a cube.
69
- Example:
70
- cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
71
- print(cube)
72
- # UUU
73
- # UUU
74
- # UUU
75
- # LLL CCC RRR BBB
76
- # LLL CCC RRR BBB
77
- # LLL CCC RRR BBB
78
- # DDD
79
- # DDD
80
- # DDD
81
  """
82
- grid = self.to_grid(pad_colors=True)
 
 
83
  void = " " * max(len(c) for c in self.colors) * self.size
84
- l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in grid[0])
85
- l2 = "\n".join(" ".join("".join(grid[face_i][row_i]) for face_i in range(1, 5)) for row_i in range(self.size))
86
- l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in grid[-1])
87
  return "\n".join([l1, l2, l3])
 
6
  @dataclass
7
  class Cube:
8
  """
9
+ A 4D tensor filled with colors. Dimensions have the following interpretation:
10
+ - Face (from 0 to 5, with 0 = "Up", 1 = "Left", 2 = "Front", 3 = "Right", 4 = "Back", 5 = "Down").
11
  - X coordinate (from 0 to self.size - 1, from Left to Right).
12
  - Y coordinate (from 0 to self.size - 1, from Back to Front).
13
  - Z coordinate (from 0 to self.size - 1, from Down to Up).
14
+
15
+ Colors filling each tensor cell are from 0 to 6, 0 being the "dark" color,
16
+ the rest according to order given in "colors" attribute.
17
  """
18
 
19
+ coordinates: torch.Tensor
20
  state: torch.Tensor
21
  colors: list[str]
22
  size: int
23
 
24
  @classmethod
25
+ def create(cls, colors: list[str], size: int) -> "Cube":
26
  """
27
  Create Cube from a given list of 6 colors and size.
28
  Example:
29
+ cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
30
  """
31
  assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
32
  assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
33
 
34
+ # build dense tensor filled with colors
35
  n = size - 1
36
+ tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
37
+ tensor[0, :, :, n] = 1 # up
38
+ tensor[1, 0, :, :] = 2 # left
39
+ tensor[2, :, n, :] = 3 # front
40
+ tensor[3, n, :, :] = 4 # right
41
+ tensor[4, :, 0, :] = 5 # back
42
+ tensor[5, :, :, 0] = 6 # down
43
+ return cls.from_sparse(tensor.to_sparse(), colors, size)
44
+
45
+ def shuffle(self, num_moves: int):
46
+ raise NotImplementedError
47
+
48
+ def rotate(self, moves: str):
49
+ raise NotImplementedError
50
+
51
+ def solve(slef, policy: str):
52
+ raise NotImplementedError
53
 
54
  @staticmethod
55
  def pad_colors(colors: list[str]) -> list[str]:
 
59
  max_len = max(len(c) for c in colors)
60
  return [c + " " * (max_len - len(c)) for c in colors]
61
 
62
+ @classmethod
63
+ def from_sparse(cls, tensor: torch.Tensor, colors: list[str], size: int) -> "Cube":
64
  """
65
+ Gather cube attributes into a torch sparse tensor.
66
  """
67
+ coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
68
+ values = tensor.values()
69
+ return cls(coordinates, values, colors, size)
70
+
71
+ def to_sparse(self) -> torch.Tensor:
72
+ """
73
+ Gather cube attributes into a torch sparse tensor.
74
+ """
75
+ return torch.sparse_coo_tensor(
76
+ indices=self.coordinates.transpose(0, 1),
77
+ values=self.state,
78
+ size=(6, self.size, self.size, self.size),
79
+ dtype=torch.int8,
80
+ )
81
 
82
  def __str__(self):
83
  """
84
  Compute a string representation of a cube.
 
 
 
 
 
 
 
 
 
 
 
 
85
  """
86
+ colors = self.pad_colors(self.colors)
87
+ faces = self.state.reshape(6, self.size, self.size).transpose(1, 2)
88
+ faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
89
  void = " " * max(len(c) for c in self.colors) * self.size
90
+ l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in faces[0])
91
+ l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(self.size))
92
+ l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in faces[-1])
93
  return "\n".join([l1, l2, l3])
src/rubik/hello.py DELETED
@@ -1,2 +0,0 @@
1
- def hello_world(s: str) -> None:
2
- print(f"Hello {s} !")
 
 
 
src/rubik/moves.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from rubik.cube import Cube
5
+
6
+
7
+ INT8 = torch.int8
8
+
9
+ POS_ROTATIONS = torch.stack(
10
+ [
11
+ # rot about X: Z -> Y
12
+ torch.tensor(
13
+ [
14
+ [1, 0, 0, 0],
15
+ [0, 1, 0, 0],
16
+ [0, 0, 0, 1],
17
+ [0, 0, -1, 0],
18
+ ],
19
+ dtype=INT8,
20
+ ),
21
+ # rot about Y: X -> Z
22
+ torch.tensor(
23
+ [
24
+ [1, 0, 0, 0],
25
+ [0, 0, 0, -1],
26
+ [0, 0, 1, 0],
27
+ [0, 1, 0, 0],
28
+ ],
29
+ dtype=INT8,
30
+ ),
31
+ # rot about Z: Y -> X
32
+ torch.tensor(
33
+ [
34
+ [1, 0, 0, 0],
35
+ [0, 0, 1, 0],
36
+ [0, -1, 0, 0],
37
+ [0, 0, 0, 1],
38
+ ],
39
+ dtype=INT8,
40
+ ),
41
+ ]
42
+ )
43
+
44
+ POS_SHIFTS = torch.tensor(
45
+ [
46
+ [0, 0, 0, 2],
47
+ [0, 2, 0, 0],
48
+ [0, 0, 2, 0],
49
+ ],
50
+ dtype=INT8,
51
+ )
52
+
53
+ FACE_PERMS = torch.stack(
54
+ [
55
+ # rotation about X axis: Up -> Front -> Down -> Back -> Up
56
+ torch.tensor(
57
+ [
58
+ [0, 0, 0, 0, 1, 0],
59
+ [0, 1, 0, 0, 0, 0],
60
+ [1, 0, 0, 0, 0, 0],
61
+ [0, 0, 0, 1, 0, 0],
62
+ [0, 0, 0, 0, 0, 1],
63
+ [0, 0, 1, 0, 0, 0],
64
+ ],
65
+ dtype=INT8,
66
+ ),
67
+ # rotation about Y axis: Up -> Left -> Down -> Right -> Up
68
+ torch.tensor(
69
+ [
70
+ [0, 0, 0, 1, 0, 0],
71
+ [1, 0, 0, 0, 0, 0],
72
+ [0, 0, 1, 0, 0, 0],
73
+ [0, 0, 0, 0, 0, 1],
74
+ [0, 0, 0, 0, 1, 0],
75
+ [0, 1, 0, 0, 0, 0],
76
+ ],
77
+ dtype=INT8,
78
+ ),
79
+ # rotation about Z axis: Left -> Front -> Right -> Back -> Left
80
+ torch.tensor(
81
+ [
82
+ [1, 0, 0, 0, 0, 0],
83
+ [0, 0, 0, 0, 1, 0],
84
+ [0, 1, 0, 0, 0, 0],
85
+ [0, 0, 1, 0, 0, 0],
86
+ [0, 0, 0, 1, 0, 0],
87
+ [0, 0, 0, 0, 0, 1],
88
+ ],
89
+ dtype=INT8,
90
+ ),
91
+ ]
92
+ ).transpose(1, 2)
93
+
94
+
95
+ def build_actions_tensor(size: int) -> torch.Tensor:
96
+ """
97
+ Built the 5D tensor carrying all rotations of a cube as matrix multiplication.
98
+ """
99
+ return torch.stack(
100
+ [
101
+ build_permunation_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
102
+ for axis in range(3)
103
+ for slice in range(size)
104
+ for inverse in range(2)
105
+ ],
106
+ dim=0,
107
+ ).sum(dim=0, dtype=INT8)
108
+
109
+
110
+ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
111
+ """
112
+ Compute the sparse permutation tensor whose effect on a position-frozen color vector
113
+ is the rotation along the specified axis, within the specified slice and the specified
114
+ orientation.
115
+ """
116
+ cube = Cube.create(["U", "L", "C", "R", "B", "D"], size=size)
117
+ length = 6 * (size**2)
118
+
119
+ # extract faces impacted by the move
120
+ coordinates: torch.Tensor = cube.coordinates # size = (length, 4)
121
+ transposed = coordinates.transpose(0, 1) # size = (4, length)
122
+ indices = (transposed[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
123
+ extract = transposed[:, indices] # size = (4, n)
124
+
125
+ # apply coordinate rotation
126
+ rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
127
+ offsets = POS_SHIFTS[axis].repeat(extract.shape[-1], 1).transpose(0, 1) # size = (4, n)
128
+ rotated = rotated + offsets # size = (4, n)
129
+
130
+ # apply face rotation
131
+ rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(INT8) @ FACE_PERMS[axis]).argmax(dim=-1)
132
+
133
+ # from this point on, convert rotation into a position-based permutation of colors
134
+ (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
135
+ inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
136
+ outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
+
138
+ extract_to_coordinates = dict(enumerate(indices.tolist()))
139
+ coordinates_to_extract = {ind: i for i, ind in extract_to_coordinates.items()}
140
+
141
+ extract_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
142
+ global_perm = {
143
+ i: (i if i not in coordinates_to_extract else extract_to_coordinates[extract_perm[coordinates_to_extract[i]]])
144
+ for i in range(length)
145
+ }
146
+ perm_indices = torch.tensor(
147
+ [[axis] * length, [slice] * length, [inverse] * length, list(global_perm.keys()), list(global_perm.values())],
148
+ dtype=INT8,
149
+ )
150
+ perm_values = torch.tensor([1] * length)
151
+ perm_size = (3, size, 2, length, length)
152
+ return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)