Spaces:
Sleeping
Sleeping
change base dtype from int8 to int16 to avoid overflow when cube size is >= 7
Browse files
- src/rubik/action.py +10 -10
- src/rubik/cube.py +5 -5
- src/rubik/tensor_utils.py +4 -4
src/rubik/action.py
CHANGED
@@ -15,7 +15,7 @@ POS_ROTATIONS = torch.stack(
|
|
15 |
[0, 0, 0, 1],
|
16 |
[0, 0, -1, 0],
|
17 |
],
|
18 |
-
dtype=torch.int8,
|
19 |
),
|
20 |
# rot about Y: X -> Z
|
21 |
torch.tensor(
|
@@ -25,7 +25,7 @@ POS_ROTATIONS = torch.stack(
|
|
25 |
[0, 0, 1, 0],
|
26 |
[0, 1, 0, 0],
|
27 |
],
|
28 |
-
dtype=torch.int8,
|
29 |
),
|
30 |
# rot about Z: Y -> X
|
31 |
torch.tensor(
|
@@ -35,7 +35,7 @@ POS_ROTATIONS = torch.stack(
|
|
35 |
[0, -1, 0, 0],
|
36 |
[0, 0, 0, 1],
|
37 |
],
|
38 |
-
dtype=torch.int8,
|
39 |
),
|
40 |
]
|
41 |
)
|
@@ -46,7 +46,7 @@ POS_SHIFTS = torch.tensor(
|
|
46 |
[0, 1, 0, 0],
|
47 |
[0, 0, 1, 0],
|
48 |
],
|
49 |
-
dtype=torch.int8,
|
50 |
)
|
51 |
|
52 |
|
@@ -74,7 +74,7 @@ def build_actions_tensor(size: int) -> torch.Tensor:
|
|
74 |
for inverse in range(2)
|
75 |
],
|
76 |
dim=0,
|
77 |
-
).sum(dim=0, dtype=torch.int8)
|
78 |
|
79 |
|
80 |
def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
|
@@ -87,7 +87,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
|
|
87 |
length = 6 * (size**2)
|
88 |
|
89 |
# extract faces impacted by the move
|
90 |
-
indices = tensor.indices().to(dtype=torch.int8)  # size = (4, length)
|
91 |
changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
92 |
extract = indices[:, changes] # size = (4, n)
|
93 |
|
@@ -97,7 +97,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
|
|
97 |
rotated = rotated + offsets # size = (4, n)
|
98 |
|
99 |
# apply face rotation
|
100 |
-
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int8) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
|
101 |
|
102 |
# from this point on, convert rotation into a position-based permutation of colors
|
103 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
@@ -116,11 +116,11 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
|
|
116 |
# convert permutation dict into sparse tensor
|
117 |
perm_indices = torch.tensor(
|
118 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
119 |
-
dtype=torch.int8,
|
120 |
)
|
121 |
-
perm_values = torch.tensor([1] * length, dtype=torch.int8)
|
122 |
perm_size = (3, size, 2, length, length)
|
123 |
-
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int8)
|
124 |
|
125 |
|
126 |
def parse_action_str(move: str) -> tuple[int, ...]:
|
|
|
15 |
[0, 0, 0, 1],
|
16 |
[0, 0, -1, 0],
|
17 |
],
|
18 |
+
dtype=torch.int16,
|
19 |
),
|
20 |
# rot about Y: X -> Z
|
21 |
torch.tensor(
|
|
|
25 |
[0, 0, 1, 0],
|
26 |
[0, 1, 0, 0],
|
27 |
],
|
28 |
+
dtype=torch.int16,
|
29 |
),
|
30 |
# rot about Z: Y -> X
|
31 |
torch.tensor(
|
|
|
35 |
[0, -1, 0, 0],
|
36 |
[0, 0, 0, 1],
|
37 |
],
|
38 |
+
dtype=torch.int16,
|
39 |
),
|
40 |
]
|
41 |
)
|
|
|
46 |
[0, 1, 0, 0],
|
47 |
[0, 0, 1, 0],
|
48 |
],
|
49 |
+
dtype=torch.int16,
|
50 |
)
|
51 |
|
52 |
|
|
|
74 |
for inverse in range(2)
|
75 |
],
|
76 |
dim=0,
|
77 |
+
).sum(dim=0, dtype=torch.int16)
|
78 |
|
79 |
|
80 |
def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
|
|
|
87 |
length = 6 * (size**2)
|
88 |
|
89 |
# extract faces impacted by the move
|
90 |
+
indices = tensor.indices().to(dtype=torch.int16) # size = (4, length)
|
91 |
changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
92 |
extract = indices[:, changes] # size = (4, n)
|
93 |
|
|
|
97 |
rotated = rotated + offsets # size = (4, n)
|
98 |
|
99 |
# apply face rotation
|
100 |
+
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
|
101 |
|
102 |
# from this point on, convert rotation into a position-based permutation of colors
|
103 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
|
|
116 |
# convert permutation dict into sparse tensor
|
117 |
perm_indices = torch.tensor(
|
118 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
119 |
+
dtype=torch.int16,
|
120 |
)
|
121 |
+
perm_values = torch.tensor([1] * length, dtype=torch.int16)
|
122 |
perm_size = (3, size, 2, length, length)
|
123 |
+
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
|
124 |
|
125 |
|
126 |
def parse_action_str(move: str) -> tuple[int, ...]:
|
src/rubik/cube.py
CHANGED
@@ -28,8 +28,8 @@ class Cube:
|
|
28 |
cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
29 |
"""
|
30 |
tensor = build_cube_tensor(colors, size)
|
31 |
-
self.coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
|
32 |
-
self.state = F.one_hot(tensor.values().long()).to(torch.int8)
|
33 |
self.actions = build_actions_tensor(size)
|
34 |
self.history: list[list[int]] = []
|
35 |
self.colors = colors
|
@@ -37,7 +37,7 @@ class Cube:
|
|
37 |
|
38 |
def to(self, device: str | torch.device) -> "Cube":
|
39 |
device = torch.device(device)
|
40 |
-
dtype = torch.int8 if device == torch.device("cpu") else torch.float32
|
41 |
self.coordinates = self.coordinates.to(device=device, dtype=dtype)
|
42 |
self.state = self.state.to(device=device, dtype=dtype)
|
43 |
self.actions = self.actions.to(device=device, dtype=dtype)
|
@@ -84,7 +84,7 @@ class Cube:
|
|
84 |
"""
|
85 |
actions = parse_actions_str(moves)
|
86 |
tensors = [self.actions[*action].to(torch.float32) for action in actions]
|
87 |
-
result = reduce(lambda A, B: A @ B, tensors).to(torch.int8)
|
88 |
return dict(result.indices().transpose(0, 1).tolist())
|
89 |
|
90 |
def solve(self, policy: str) -> None:
|
@@ -97,5 +97,5 @@ class Cube:
|
|
97 |
"""
|
98 |
Compute a string representation of a cube.
|
99 |
"""
|
100 |
-
state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int8)
|
101 |
return stringify(state, self.colors, self.size)
|
|
|
28 |
cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
29 |
"""
|
30 |
tensor = build_cube_tensor(colors, size)
|
31 |
+
self.coordinates = tensor.indices().transpose(0, 1).to(torch.int16)
|
32 |
+
self.state = F.one_hot(tensor.values().long()).to(torch.int16)
|
33 |
self.actions = build_actions_tensor(size)
|
34 |
self.history: list[list[int]] = []
|
35 |
self.colors = colors
|
|
|
37 |
|
38 |
def to(self, device: str | torch.device) -> "Cube":
|
39 |
device = torch.device(device)
|
40 |
+
dtype = torch.int16 if device == torch.device("cpu") else torch.float32
|
41 |
self.coordinates = self.coordinates.to(device=device, dtype=dtype)
|
42 |
self.state = self.state.to(device=device, dtype=dtype)
|
43 |
self.actions = self.actions.to(device=device, dtype=dtype)
|
|
|
84 |
"""
|
85 |
actions = parse_actions_str(moves)
|
86 |
tensors = [self.actions[*action].to(torch.float32) for action in actions]
|
87 |
+
result = reduce(lambda A, B: A @ B, tensors).to(torch.int16)
|
88 |
return dict(result.indices().transpose(0, 1).tolist())
|
89 |
|
90 |
def solve(self, policy: str) -> None:
|
|
|
97 |
"""
|
98 |
Compute a string representation of a cube.
|
99 |
"""
|
100 |
+
state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
|
101 |
return stringify(state, self.colors, self.size)
|
src/rubik/tensor_utils.py
CHANGED
@@ -10,7 +10,7 @@ def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
|
|
10 |
|
11 |
# build dense tensor filled with colors
|
12 |
n = size - 1
|
13 |
-
tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
|
14 |
tensor[0, :, :, n] = 1 # up
|
15 |
tensor[1, 0, :, :] = 2 # left
|
16 |
tensor[2, :, n, :] = 3 # front
|
@@ -26,6 +26,6 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
|
|
26 |
"""
|
27 |
perm_list = [int(p) for p in (perm + perm[0])]
|
28 |
perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
|
29 |
-
indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int8)
|
30 |
-
values = torch.tensor([1] * size, dtype=torch.int8)
|
31 |
-
return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int8)
|
|
|
10 |
|
11 |
# build dense tensor filled with colors
|
12 |
n = size - 1
|
13 |
+
tensor = torch.zeros([6, size, size, size], dtype=torch.int16)
|
14 |
tensor[0, :, :, n] = 1 # up
|
15 |
tensor[1, 0, :, :] = 2 # left
|
16 |
tensor[2, :, n, :] = 3 # front
|
|
|
26 |
"""
|
27 |
perm_list = [int(p) for p in (perm + perm[0])]
|
28 |
perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
|
29 |
+
indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
|
30 |
+
values = torch.tensor([1] * size, dtype=torch.int16)
|
31 |
+
return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16)
|