Spaces:
Sleeping
Sleeping
interface for shuffling and rotating a cube
Browse files- README.md +17 -1
- src/rubik/{moves.py → action.py} +48 -58
- src/rubik/cube.py +42 -47
- src/rubik/display.py +20 -0
- src/rubik/tensor_utils.py +31 -0
README.md
CHANGED
@@ -12,7 +12,9 @@ uv sync
|
|
12 |
pre-commit install
|
13 |
```
|
14 |
|
15 |
-
##
|
|
|
|
|
16 |
|
17 |
```python
|
18 |
from rubik.cube import Cube
|
@@ -30,6 +32,20 @@ print(cube)
|
|
30 |
# DDD
|
31 |
```
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
## Roadmap
|
34 |
|
35 |
#### Fully tensorized Rubik Cube model
|
|
|
12 |
pre-commit install
|
13 |
```
|
14 |
|
15 |
+
## Usage
|
16 |
+
|
17 |
+
### Create a cube
|
18 |
|
19 |
```python
|
20 |
from rubik.cube import Cube
|
|
|
32 |
# DDD
|
33 |
```
|
34 |
|
35 |
+
### Perform basic moves
|
36 |
+
|
37 |
+
```python
|
38 |
+
# shuffle the cube using 1000 random moves
|
39 |
+
cube.shuffle(num_moves=1000, seed=0)
|
40 |
+
print(cube)
|
41 |
+
print(cube.history)
|
42 |
+
|
43 |
+
# rotate it in some way
|
44 |
+
cube.rotate('X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i')
|
45 |
+
print(cube)
|
46 |
+
print(cube.history)
|
47 |
+
```
|
48 |
+
|
49 |
## Roadmap
|
50 |
|
51 |
#### Fully tensorized Rubik Cube model
|
src/rubik/{moves.py → action.py}
RENAMED
@@ -1,11 +1,10 @@
|
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
|
4 |
-
from rubik.
|
5 |
|
6 |
|
7 |
-
INT8 = torch.int8
|
8 |
-
|
9 |
POS_ROTATIONS = torch.stack(
|
10 |
[
|
11 |
# rot about X: Z -> Y
|
@@ -16,7 +15,7 @@ POS_ROTATIONS = torch.stack(
|
|
16 |
[0, 0, 0, 1],
|
17 |
[0, 0, -1, 0],
|
18 |
],
|
19 |
-
dtype=
|
20 |
),
|
21 |
# rot about Y: X -> Z
|
22 |
torch.tensor(
|
@@ -26,7 +25,7 @@ POS_ROTATIONS = torch.stack(
|
|
26 |
[0, 0, 1, 0],
|
27 |
[0, 1, 0, 0],
|
28 |
],
|
29 |
-
dtype=
|
30 |
),
|
31 |
# rot about Z: Y -> X
|
32 |
torch.tensor(
|
@@ -36,7 +35,7 @@ POS_ROTATIONS = torch.stack(
|
|
36 |
[0, -1, 0, 0],
|
37 |
[0, 0, 0, 1],
|
38 |
],
|
39 |
-
dtype=
|
40 |
),
|
41 |
]
|
42 |
)
|
@@ -47,49 +46,20 @@ POS_SHIFTS = torch.tensor(
|
|
47 |
[0, 1, 0, 0],
|
48 |
[0, 0, 1, 0],
|
49 |
],
|
50 |
-
dtype=
|
51 |
)
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
54 |
[
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
[0, 0, 0, 0, 1, 0],
|
59 |
-
[0, 1, 0, 0, 0, 0],
|
60 |
-
[1, 0, 0, 0, 0, 0],
|
61 |
-
[0, 0, 0, 1, 0, 0],
|
62 |
-
[0, 0, 0, 0, 0, 1],
|
63 |
-
[0, 0, 1, 0, 0, 0],
|
64 |
-
],
|
65 |
-
dtype=INT8,
|
66 |
-
),
|
67 |
-
# rotation about Y axis: Up -> Left -> Down -> Right -> Up
|
68 |
-
torch.tensor(
|
69 |
-
[
|
70 |
-
[0, 0, 0, 1, 0, 0],
|
71 |
-
[1, 0, 0, 0, 0, 0],
|
72 |
-
[0, 0, 1, 0, 0, 0],
|
73 |
-
[0, 0, 0, 0, 0, 1],
|
74 |
-
[0, 0, 0, 0, 1, 0],
|
75 |
-
[0, 1, 0, 0, 0, 0],
|
76 |
-
],
|
77 |
-
dtype=INT8,
|
78 |
-
),
|
79 |
-
# rotation about Z axis: Left -> Front -> Right -> Back -> Left
|
80 |
-
torch.tensor(
|
81 |
-
[
|
82 |
-
[1, 0, 0, 0, 0, 0],
|
83 |
-
[0, 0, 0, 0, 1, 0],
|
84 |
-
[0, 1, 0, 0, 0, 0],
|
85 |
-
[0, 0, 1, 0, 0, 0],
|
86 |
-
[0, 0, 0, 1, 0, 0],
|
87 |
-
[0, 0, 0, 0, 0, 1],
|
88 |
-
],
|
89 |
-
dtype=INT8,
|
90 |
-
),
|
91 |
]
|
92 |
-
)
|
93 |
|
94 |
|
95 |
def build_actions_tensor(size: int) -> torch.Tensor:
|
@@ -98,29 +68,28 @@ def build_actions_tensor(size: int) -> torch.Tensor:
|
|
98 |
"""
|
99 |
return torch.stack(
|
100 |
[
|
101 |
-
|
102 |
for axis in range(3)
|
103 |
for slice in range(size)
|
104 |
for inverse in range(2)
|
105 |
],
|
106 |
dim=0,
|
107 |
-
).sum(dim=0, dtype=
|
108 |
|
109 |
|
110 |
-
def
|
111 |
"""
|
112 |
Compute the sparse permutation tensor whose effect on a position-frozen color vector
|
113 |
is the rotation along the specified axis, within the specified slice and the specified
|
114 |
orientation.
|
115 |
"""
|
116 |
-
|
117 |
length = 6 * (size**2)
|
118 |
|
119 |
# extract faces impacted by the move
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
extract = transposed[:, indices] # size = (4, n)
|
124 |
|
125 |
# apply coordinate rotation
|
126 |
rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
|
@@ -128,7 +97,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
128 |
rotated = rotated + offsets # size = (4, n)
|
129 |
|
130 |
# apply face rotation
|
131 |
-
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(
|
132 |
|
133 |
# from this point on, convert rotation into a position-based permutation of colors
|
134 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
@@ -136,7 +105,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
136 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
137 |
|
138 |
# compute position-based permutation of colors equivalent to rotation converting inputs into outputs
|
139 |
-
local_to_total = dict(enumerate(
|
140 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
141 |
|
142 |
local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
|
@@ -147,8 +116,29 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
147 |
# convert permutation dict into sparse tensor
|
148 |
perm_indices = torch.tensor(
|
149 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
150 |
-
dtype=
|
151 |
)
|
152 |
-
perm_values = torch.tensor([1] * length, dtype=
|
153 |
perm_size = (3, size, 2, length, length)
|
154 |
-
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
+
from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor
|
6 |
|
7 |
|
|
|
|
|
8 |
POS_ROTATIONS = torch.stack(
|
9 |
[
|
10 |
# rot about X: Z -> Y
|
|
|
15 |
[0, 0, 0, 1],
|
16 |
[0, 0, -1, 0],
|
17 |
],
|
18 |
+
dtype=torch.int8,
|
19 |
),
|
20 |
# rot about Y: X -> Z
|
21 |
torch.tensor(
|
|
|
25 |
[0, 0, 1, 0],
|
26 |
[0, 1, 0, 0],
|
27 |
],
|
28 |
+
dtype=torch.int8,
|
29 |
),
|
30 |
# rot about Z: Y -> X
|
31 |
torch.tensor(
|
|
|
35 |
[0, -1, 0, 0],
|
36 |
[0, 0, 0, 1],
|
37 |
],
|
38 |
+
dtype=torch.int8,
|
39 |
),
|
40 |
]
|
41 |
)
|
|
|
46 |
[0, 1, 0, 0],
|
47 |
[0, 0, 1, 0],
|
48 |
],
|
49 |
+
dtype=torch.int8,
|
50 |
)
|
51 |
|
52 |
+
|
53 |
+
# rotation about X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
|
54 |
+
# rotation about Y axis: 0 (Up) -> 1 (Left) -> 5 (Down) -> 3 (Right) -> 0 (Up)
|
55 |
+
# rotation about Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
|
56 |
+
FACE_ROTATIONS = torch.stack(
|
57 |
[
|
58 |
+
build_permutation_matrix(size=6, perm="0254"),
|
59 |
+
build_permutation_matrix(size=6, perm="0153"),
|
60 |
+
build_permutation_matrix(size=6, perm="1234"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
]
|
62 |
+
)
|
63 |
|
64 |
|
65 |
def build_actions_tensor(size: int) -> torch.Tensor:
|
|
|
68 |
"""
|
69 |
return torch.stack(
|
70 |
[
|
71 |
+
build_action_tensor(size=size, axis=axis, slice=slice, inverse=inverse)
|
72 |
for axis in range(3)
|
73 |
for slice in range(size)
|
74 |
for inverse in range(2)
|
75 |
],
|
76 |
dim=0,
|
77 |
+
).sum(dim=0, dtype=torch.int8)
|
78 |
|
79 |
|
80 |
+
def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
|
81 |
"""
|
82 |
Compute the sparse permutation tensor whose effect on a position-frozen color vector
|
83 |
is the rotation along the specified axis, within the specified slice and the specified
|
84 |
orientation.
|
85 |
"""
|
86 |
+
tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
|
87 |
length = 6 * (size**2)
|
88 |
|
89 |
# extract faces impacted by the move
|
90 |
+
indices = tensor.indices().to(dtype=torch.int8) # size = (4, length)
|
91 |
+
changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
|
92 |
+
extract = indices[:, changes] # size = (4, n)
|
|
|
93 |
|
94 |
# apply coordinate rotation
|
95 |
rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
|
|
|
97 |
rotated = rotated + offsets # size = (4, n)
|
98 |
|
99 |
# apply face rotation
|
100 |
+
rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int8) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
|
101 |
|
102 |
# from this point on, convert rotation into a position-based permutation of colors
|
103 |
(inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
|
|
|
105 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
106 |
|
107 |
# compute position-based permutation of colors equivalent to rotation converting inputs into outputs
|
108 |
+
local_to_total = dict(enumerate(changes.tolist()))
|
109 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
110 |
|
111 |
local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
|
|
|
116 |
# convert permutation dict into sparse tensor
|
117 |
perm_indices = torch.tensor(
|
118 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
119 |
+
dtype=torch.int8,
|
120 |
)
|
121 |
+
perm_values = torch.tensor([1] * length, dtype=torch.int8)
|
122 |
perm_size = (3, size, 2, length, length)
|
123 |
+
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int8)
|
124 |
+
|
125 |
+
|
126 |
+
def parse_action_str(name: str) -> tuple[int, ...]:
|
127 |
+
"""
|
128 |
+
Convert the name of an action into a triple (axis, slice, inverse).
|
129 |
+
Examples:
|
130 |
+
'X1' -> (0, 1, 0)
|
131 |
+
'X2i' -> (0, 2, 1)
|
132 |
+
"""
|
133 |
+
return ("XYZ".index(name[0]), int(name[1]), int(len(name) >= 3))
|
134 |
+
|
135 |
+
|
136 |
+
def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
|
137 |
+
"""
|
138 |
+
Generate a string containing moves that are randomly sampled.
|
139 |
+
"""
|
140 |
+
rng = np.random.default_rng(seed=seed)
|
141 |
+
axes = rng.choice(["X", "Y", "Z"], size=num_moves)
|
142 |
+
slices = rng.choice([str(i) for i in range(size)], size=num_moves)
|
143 |
+
orients = rng.choice(["", "i"], size=num_moves)
|
144 |
+
return " ".join("".join(move) for move in zip(axes, slices, orients))
|
src/rubik/cube.py
CHANGED
@@ -2,6 +2,10 @@ from dataclasses import dataclass
|
|
2 |
|
3 |
import torch
|
4 |
|
|
|
|
|
|
|
|
|
5 |
|
6 |
@dataclass
|
7 |
class Cube:
|
@@ -18,6 +22,8 @@ class Cube:
|
|
18 |
|
19 |
coordinates: torch.Tensor
|
20 |
state: torch.Tensor
|
|
|
|
|
21 |
colors: list[str]
|
22 |
size: int
|
23 |
|
@@ -28,66 +34,55 @@ class Cube:
|
|
28 |
Example:
|
29 |
cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
30 |
"""
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
tensor[0, :, :, n] = 1 # up
|
38 |
-
tensor[1, 0, :, :] = 2 # left
|
39 |
-
tensor[2, :, n, :] = 3 # front
|
40 |
-
tensor[3, n, :, :] = 4 # right
|
41 |
-
tensor[4, :, 0, :] = 5 # back
|
42 |
-
tensor[5, :, :, 0] = 6 # down
|
43 |
-
return cls.from_sparse(tensor.to_sparse(), colors, size)
|
44 |
-
|
45 |
-
def shuffle(self, num_moves: int):
|
46 |
-
raise NotImplementedError
|
47 |
|
48 |
-
def
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
def
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
def pad_colors(colors: list[str]) -> list[str]:
|
56 |
"""
|
57 |
-
|
58 |
"""
|
59 |
-
|
60 |
-
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
def from_sparse(cls, tensor: torch.Tensor, colors: list[str], size: int) -> "Cube":
|
64 |
"""
|
65 |
-
|
66 |
"""
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
-
def
|
72 |
"""
|
73 |
-
|
74 |
"""
|
75 |
-
|
76 |
-
indices=self.coordinates.transpose(0, 1),
|
77 |
-
values=self.state,
|
78 |
-
size=(6, self.size, self.size, self.size),
|
79 |
-
dtype=torch.int8,
|
80 |
-
)
|
81 |
|
82 |
def __str__(self):
|
83 |
"""
|
84 |
Compute a string representation of a cube.
|
85 |
"""
|
86 |
-
|
87 |
-
faces = self.state.reshape(6, self.size, self.size).transpose(1, 2)
|
88 |
-
faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
|
89 |
-
void = " " * max(len(c) for c in self.colors) * self.size
|
90 |
-
l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in faces[0])
|
91 |
-
l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(self.size))
|
92 |
-
l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in faces[-1])
|
93 |
-
return "\n".join([l1, l2, l3])
|
|
|
2 |
|
3 |
import torch
|
4 |
|
5 |
+
from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
|
6 |
+
from rubik.display import stringify
|
7 |
+
from rubik.tensor_utils import build_cube_tensor
|
8 |
+
|
9 |
|
10 |
@dataclass
|
11 |
class Cube:
|
|
|
22 |
|
23 |
coordinates: torch.Tensor
|
24 |
state: torch.Tensor
|
25 |
+
actions: torch.Tensor
|
26 |
+
history: list[list[int]]
|
27 |
colors: list[str]
|
28 |
size: int
|
29 |
|
|
|
34 |
Example:
|
35 |
cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
|
36 |
"""
|
37 |
+
tensor = build_cube_tensor(colors, size)
|
38 |
+
coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
|
39 |
+
values = tensor.values()
|
40 |
+
actions = build_actions_tensor(size)
|
41 |
+
history: list[list[int]] = []
|
42 |
+
return cls(coordinates, values, actions, history, colors, size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
+
def reset_history(self) -> None:
|
45 |
+
"""
|
46 |
+
Reset internal history of moves.
|
47 |
+
"""
|
48 |
+
self.history = []
|
49 |
+
return
|
50 |
|
51 |
+
def shuffle(self, num_moves: int, seed: int = 0) -> None:
|
52 |
+
"""
|
53 |
+
Randomly shuffle the cube by the supplied number of steps, and reset history of moves.
|
54 |
+
"""
|
55 |
+
moves = sample_actions_str(num_moves, self.size, seed=seed)
|
56 |
+
self.rotate(moves)
|
57 |
+
self.reset_history()
|
58 |
+
return
|
59 |
|
60 |
+
def rotate(self, moves: str) -> None:
|
|
|
61 |
"""
|
62 |
+
Apply a sequence of moves (defined as plain string) to the cube.
|
63 |
"""
|
64 |
+
actions = [parse_action_str(move) for move in moves.strip().split()]
|
65 |
+
for action in actions:
|
66 |
+
self.rotate_once(*action)
|
67 |
+
return
|
68 |
|
69 |
+
def rotate_once(self, axis: int, slice: int, inverse: int) -> None:
|
|
|
70 |
"""
|
71 |
+
Apply a move (defined as 3 coordinates) to the cube.
|
72 |
"""
|
73 |
+
action = self.actions[axis, slice, inverse]
|
74 |
+
self.state = action @ self.state
|
75 |
+
self.history.append([axis, slice, inverse])
|
76 |
+
return
|
77 |
|
78 |
+
def solve(self, policy: str) -> None:
|
79 |
"""
|
80 |
+
Apply the specified solving policy to the cube.
|
81 |
"""
|
82 |
+
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
def __str__(self):
|
85 |
"""
|
86 |
Compute a string representation of a cube.
|
87 |
"""
|
88 |
+
return stringify(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rubik/display.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def stringify(cube) -> str:
|
2 |
+
"""
|
3 |
+
Compute a string representation of a cube.
|
4 |
+
"""
|
5 |
+
colors = pad_colors(cube.colors)
|
6 |
+
faces = cube.state.reshape(6, cube.size, cube.size).transpose(1, 2)
|
7 |
+
faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
|
8 |
+
space = " " * max(len(c) for c in cube.colors) * cube.size
|
9 |
+
l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
|
10 |
+
l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(cube.size))
|
11 |
+
l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
|
12 |
+
return "\n".join([l1, l2, l3])
|
13 |
+
|
14 |
+
|
15 |
+
def pad_colors(colors: list[str]) -> list[str]:
|
16 |
+
"""
|
17 |
+
Pad color names to strings of equal length.
|
18 |
+
"""
|
19 |
+
max_len = max(len(c) for c in colors)
|
20 |
+
return [c + " " * (max_len - len(c)) for c in colors]
|
src/rubik/tensor_utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
|
5 |
+
"""
|
6 |
+
Convert a list of 6 colors and size into a sparse 4D tensor representing a cube.
|
7 |
+
"""
|
8 |
+
assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
|
9 |
+
assert isinstance(size, int) and size > 1, f"Expected non-zero integrer size, got {size}"
|
10 |
+
|
11 |
+
# build dense tensor filled with colors
|
12 |
+
n = size - 1
|
13 |
+
tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
|
14 |
+
tensor[0, :, :, n] = 1 # up
|
15 |
+
tensor[1, 0, :, :] = 2 # left
|
16 |
+
tensor[2, :, n, :] = 3 # front
|
17 |
+
tensor[3, n, :, :] = 4 # right
|
18 |
+
tensor[4, :, 0, :] = 5 # back
|
19 |
+
tensor[5, :, :, 0] = 6 # down
|
20 |
+
return tensor.to_sparse()
|
21 |
+
|
22 |
+
|
23 |
+
def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Convert a permutation sting into a sparse 2D matrix.
|
26 |
+
"""
|
27 |
+
perm_list = [int(p) for p in (perm + perm[0])]
|
28 |
+
perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
|
29 |
+
indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int8)
|
30 |
+
values = torch.tensor([1] * size, dtype=torch.int8)
|
31 |
+
return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int8)
|