Spaces:
Sleeping
Sleeping
add builder for tensor of actions
Browse files- README.md +1 -1
- src/rubik/__main__.py +2 -4
- src/rubik/cube.py +49 -43
- src/rubik/hello.py +0 -2
- src/rubik/moves.py +152 -0
README.md
CHANGED
@@ -17,7 +17,7 @@ pre-commit install
|
|
17 |
```python
|
18 |
from rubik.cube import Cube
|
19 |
|
20 |
-
cube = Cube.
|
21 |
print(cube)
|
22 |
# UUU
|
23 |
# UUU
|
|
|
17 |
```python
|
18 |
from rubik.cube import Cube
|
19 |
|
20 |
+
cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
21 |
print(cube)
|
22 |
# UUU
|
23 |
# UUU
|
src/rubik/__main__.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1 |
-
from fire import Fire
|
2 |
|
3 |
-
from rubik.hello import hello_world
|
4 |
|
5 |
-
|
6 |
-
Fire({"hello": hello_world})
|
|
|
1 |
+
# from fire import Fire
|
2 |
|
|
|
3 |
|
4 |
+
# Fire({"hello": hello_world})
|
|
src/rubik/cube.py
CHANGED
@@ -6,38 +6,50 @@ import torch
|
|
6 |
@dataclass
|
7 |
class Cube:
|
8 |
"""
|
9 |
-
A
|
|
|
10 |
- X coordinate (from 0 to self.size - 1, from Left to Right).
|
11 |
- Y coordinate (from 0 to self.size - 1, from Back to Front).
|
12 |
- Z coordinate (from 0 to self.size - 1, from Down to Up).
|
13 |
-
|
14 |
-
|
|
|
15 |
"""
|
16 |
|
|
|
17 |
state: torch.Tensor
|
18 |
colors: list[str]
|
19 |
size: int
|
20 |
|
21 |
@classmethod
|
22 |
-
def
|
23 |
"""
|
24 |
Create Cube from a given list of 6 colors and size.
|
25 |
Example:
|
26 |
-
cube = Cube.
|
27 |
"""
|
28 |
assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
|
29 |
assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
|
30 |
|
31 |
-
# build tensor filled with
|
32 |
n = size - 1
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
return cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
@staticmethod
|
43 |
def pad_colors(colors: list[str]) -> list[str]:
|
@@ -47,41 +59,35 @@ class Cube:
|
|
47 |
max_len = max(len(c) for c in colors)
|
48 |
return [c + " " * (max_len - len(c)) for c in colors]
|
49 |
|
50 |
-
|
|
|
51 |
"""
|
52 |
-
|
53 |
"""
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
65 |
|
66 |
def __str__(self):
|
67 |
"""
|
68 |
Compute a string representation of a cube.
|
69 |
-
Example:
|
70 |
-
cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
71 |
-
print(cube)
|
72 |
-
# UUU
|
73 |
-
# UUU
|
74 |
-
# UUU
|
75 |
-
# LLL CCC RRR BBB
|
76 |
-
# LLL CCC RRR BBB
|
77 |
-
# LLL CCC RRR BBB
|
78 |
-
# DDD
|
79 |
-
# DDD
|
80 |
-
# DDD
|
81 |
"""
|
82 |
-
|
|
|
|
|
83 |
void = " " * max(len(c) for c in self.colors) * self.size
|
84 |
-
l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in
|
85 |
-
l2 = "\n".join(" ".join("".join(
|
86 |
-
l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in
|
87 |
return "\n".join([l1, l2, l3])
|
|
|
6 |
@dataclass
|
7 |
class Cube:
|
8 |
"""
|
9 |
+
A 4D tensor filled with colors. Dimensions have the following interpretation:
|
10 |
+
- Face (from 0 to 5, with 0 = "Up", 1 = "Left", 2 = "Front", 3 = "Right", 4 = "Back", 5 = "Down").
|
11 |
- X coordinate (from 0 to self.size - 1, from Left to Right).
|
12 |
- Y coordinate (from 0 to self.size - 1, from Back to Front).
|
13 |
- Z coordinate (from 0 to self.size - 1, from Down to Up).
|
14 |
+
|
15 |
+
Colors filling each tensor cell are from 0 to 6, 0 being the "dark" color,
|
16 |
+
the rest according to order given in "colors" attribute.
|
17 |
"""
|
18 |
|
19 |
+
coordinates: torch.Tensor
|
20 |
state: torch.Tensor
|
21 |
colors: list[str]
|
22 |
size: int
|
23 |
|
24 |
@classmethod
|
25 |
+
def create(cls, colors: list[str], size: int) -> "Cube":
|
26 |
"""
|
27 |
Create Cube from a given list of 6 colors and size.
|
28 |
Example:
|
29 |
+
cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
30 |
"""
|
31 |
assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
|
32 |
assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
|
33 |
|
34 |
+
# build dense tensor filled with colors
|
35 |
n = size - 1
|
36 |
+
tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
|
37 |
+
tensor[0, :, :, n] = 1 # up
|
38 |
+
tensor[1, 0, :, :] = 2 # left
|
39 |
+
tensor[2, :, n, :] = 3 # front
|
40 |
+
tensor[3, n, :, :] = 4 # right
|
41 |
+
tensor[4, :, 0, :] = 5 # back
|
42 |
+
tensor[5, :, :, 0] = 6 # down
|
43 |
+
return cls.from_sparse(tensor.to_sparse(), colors, size)
|
44 |
+
|
45 |
+
def shuffle(self, num_moves: int):
|
46 |
+
raise NotImplementedError
|
47 |
+
|
48 |
+
def rotate(self, moves: str):
|
49 |
+
raise NotImplementedError
|
50 |
+
|
51 |
+
def solve(slef, policy: str):
|
52 |
+
raise NotImplementedError
|
53 |
|
54 |
@staticmethod
|
55 |
def pad_colors(colors: list[str]) -> list[str]:
|
|
|
59 |
max_len = max(len(c) for c in colors)
|
60 |
return [c + " " * (max_len - len(c)) for c in colors]
|
61 |
|
62 |
+
@classmethod
|
63 |
+
def from_sparse(cls, tensor: torch.Tensor, colors: list[str], size: int) -> "Cube":
|
64 |
"""
|
65 |
+
Gather cube attributes into a torch sparse tensor.
|
66 |
"""
|
67 |
+
coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
|
68 |
+
values = tensor.values()
|
69 |
+
return cls(coordinates, values, colors, size)
|
70 |
+
|
71 |
+
def to_sparse(self) -> torch.Tensor:
|
72 |
+
"""
|
73 |
+
Gather cube attributes into a torch sparse tensor.
|
74 |
+
"""
|
75 |
+
return torch.sparse_coo_tensor(
|
76 |
+
indices=self.coordinates.transpose(0, 1),
|
77 |
+
values=self.state,
|
78 |
+
size=(6, self.size, self.size, self.size),
|
79 |
+
dtype=torch.int8,
|
80 |
+
)
|
81 |
|
82 |
def __str__(self):
|
83 |
"""
|
84 |
Compute a string representation of a cube.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
"""
|
86 |
+
colors = self.pad_colors(self.colors)
|
87 |
+
faces = self.state.reshape(6, self.size, self.size).transpose(1, 2)
|
88 |
+
faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
|
89 |
void = " " * max(len(c) for c in self.colors) * self.size
|
90 |
+
l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in faces[0])
|
91 |
+
l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(self.size))
|
92 |
+
l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in faces[-1])
|
93 |
return "\n".join([l1, l2, l3])
|
src/rubik/hello.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
def hello_world(s: str) -> None:
|
2 |
-
print(f"Hello {s} !")
|
|
|
|
|
|
src/rubik/moves.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from rubik.cube import Cube
|
5 |
+
|
6 |
+
|
7 |
+
INT8 = torch.int8
|
8 |
+
|
9 |
+
POS_ROTATIONS = torch.stack(
|
10 |
+
[
|
11 |
+
# rot about X: Z -> Y
|
12 |
+
torch.tensor(
|
13 |
+
[
|
14 |
+
[1, 0, 0, 0],
|
15 |
+
[0, 1, 0, 0],
|
16 |
+
[0, 0, 0, 1],
|
17 |
+
[0, 0, -1, 0],
|
18 |
+
],
|
19 |
+
dtype=INT8,
|
20 |
+
),
|
21 |
+
# rot about Y: X -> Z
|
22 |
+
torch.tensor(
|
23 |
+
[
|
24 |
+
[1, 0, 0, 0],
|
25 |
+
[0, 0, 0, -1],
|
26 |
+
[0, 0, 1, 0],
|
27 |
+
[0, 1, 0, 0],
|
28 |
+
],
|
29 |
+
dtype=INT8,
|
30 |
+
),
|
31 |
+
# rot about Z: Y -> X
|
32 |
+
torch.tensor(
|
33 |
+
[
|
34 |
+
[1, 0, 0, 0],
|
35 |
+
[0, 0, 1, 0],
|
36 |
+
[0, -1, 0, 0],
|
37 |
+
[0, 0, 0, 1],
|
38 |
+
],
|
39 |
+
dtype=INT8,
|
40 |
+
),
|
41 |
+
]
|
42 |
+
)
|
43 |
+
|
44 |
+
POS_SHIFTS = torch.tensor(
|
45 |
+
[
|
46 |
+
[0, 0, 0, 2],
|
47 |
+
[0, 2, 0, 0],
|
48 |
+
[0, 0, 2, 0],
|
49 |
+
],
|
50 |
+
dtype=INT8,
|
51 |
+
)
|
52 |
+
|
53 |
+
FACE_PERMS = torch.stack(
|
54 |
+
[
|
55 |
+
# rotation about X axis: Up -> Front -> Down -> Back -> Up
|
56 |
+
torch.tensor(
|
57 |
+
[
|
58 |
+
[0, 0, 0, 0, 1, 0],
|
59 |
+
[0, 1, 0, 0, 0, 0],
|
60 |
+
[1, 0, 0, 0, 0, 0],
|
61 |
+
[0, 0, 0, 1, 0, 0],
|
62 |
+
[0, 0, 0, 0, 0, 1],
|
63 |
+
[0, 0, 1, 0, 0, 0],
|
64 |
+
],
|
65 |
+
dtype=INT8,
|
66 |
+
),
|
67 |
+
# rotation about Y axis: Up -> Left -> Down -> Right -> Up
|
68 |
+
torch.tensor(
|
69 |
+
[
|
70 |
+
[0, 0, 0, 1, 0, 0],
|
71 |
+
[1, 0, 0, 0, 0, 0],
|
72 |
+
[0, 0, 1, 0, 0, 0],
|
73 |
+
[0, 0, 0, 0, 0, 1],
|
74 |
+
[0, 0, 0, 0, 1, 0],
|
75 |
+
[0, 1, 0, 0, 0, 0],
|
76 |
+
],
|
77 |
+
dtype=INT8,
|
78 |
+
),
|
79 |
+
# rotation about Z axis: Left -> Front -> Right -> Back -> Left
|
80 |
+
torch.tensor(
|
81 |
+
[
|
82 |
+
[1, 0, 0, 0, 0, 0],
|
83 |
+
[0, 0, 0, 0, 1, 0],
|
84 |
+
[0, 1, 0, 0, 0, 0],
|
85 |
+
[0, 0, 1, 0, 0, 0],
|
86 |
+
[0, 0, 0, 1, 0, 0],
|
87 |
+
[0, 0, 0, 0, 0, 1],
|
88 |
+
],
|
89 |
+
dtype=INT8,
|
90 |
+
),
|
91 |
+
]
|
92 |
+
).transpose(1, 2)
|
93 |
+
|
94 |
+
|
95 |
+
def build_actions_tensor(size: int) -> torch.Tensor:
|
96 |
+
"""
|
97 |
+
Built the 5D tensor carrying all rotations of a cube as matrix multiplication.
|
98 |
+
"""
|
99 |
+
return torch.stack(
|
100 |
+
[
|
101 |
+
build_permunation_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
|
102 |
+
for axis in range(3)
|
103 |
+
for slice in range(size)
|
104 |
+
for inverse in range(2)
|
105 |
+
],
|
106 |
+
dim=0,
|
107 |
+
).sum(dim=0, dtype=INT8)
|
108 |
+
|
109 |
+
|
110 |
+
def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
|
111 |
+
"""
|
112 |
+
Compute the sparse permutation tensor whose effect on a position-frozen color vector
|
113 |
+
is the rotation along the specified axis, within the specified slice and the specified
|
114 |
+
orientation.
|
115 |
+
"""
|
116 |
+
cube = Cube.create(["U", "L", "C", "R", "B", "D"], size=size)
|
117 |
+
length = 6 * (size**2)
|
118 |
+
|
119 |
+
# extract faces impacted by the move
|
120 |
+
coordinates: torch.Tensor = cube.coordinates # size = (length, 4)
|
121 |
+
transposed = coordinates.transpose(0, 1) # size = (4, length)
|
122 |
+
indices = (transposed[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
123 |
+
extract = transposed[:, indices] # size = (4, n)
|
124 |
+
|
125 |
+
# apply coordinate rotation
|
126 |
+
rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
|
127 |
+
offsets = POS_SHIFTS[axis].repeat(extract.shape[-1], 1).transpose(0, 1) # size = (4, n)
|
128 |
+
rotated = rotated + offsets # size = (4, n)
|
129 |
+
|
130 |
+
# apply face rotation
|
131 |
+
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(INT8) @ FACE_PERMS[axis]).argmax(dim=-1)
|
132 |
+
|
133 |
+
# from this point on, convert rotation into a position-based permutation of colors
|
134 |
+
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
135 |
+
inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
|
136 |
+
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
137 |
+
|
138 |
+
extract_to_coordinates = dict(enumerate(indices.tolist()))
|
139 |
+
coordinates_to_extract = {ind: i for i, ind in extract_to_coordinates.items()}
|
140 |
+
|
141 |
+
extract_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
|
142 |
+
global_perm = {
|
143 |
+
i: (i if i not in coordinates_to_extract else extract_to_coordinates[extract_perm[coordinates_to_extract[i]]])
|
144 |
+
for i in range(length)
|
145 |
+
}
|
146 |
+
perm_indices = torch.tensor(
|
147 |
+
[[axis] * length, [slice] * length, [inverse] * length, list(global_perm.keys()), list(global_perm.values())],
|
148 |
+
dtype=INT8,
|
149 |
+
)
|
150 |
+
perm_values = torch.tensor([1] * length)
|
151 |
+
perm_size = (3, size, 2, length, length)
|
152 |
+
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)
|