Spaces:
Sleeping
Sleeping
import re | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor | |
# Signed permutation matrices, one per rotation axis (index 0 = X, 1 = Y, 2 = Z), that are
# left-multiplied onto 4-row index vectors in build_action_tensor. Row 0 (used there as the
# face index) is left untouched; rows 1-3 are permuted, with one of them negated.
# NOTE(review): the (face, x, y, z) layout of the index vectors is assumed from how
# build_action_tensor consumes them -- confirm against build_cube_tensor.
POS_ROTATIONS = torch.tensor(
    [
        # rot about X: Z -> Y
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
            [0, 0, -1, 0],
        ],
        # rot about Y: X -> Z
        [
            [1, 0, 0, 0],
            [0, 0, 0, -1],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
        ],
        # rot about Z: Y -> X
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, -1, 0, 0],
            [0, 0, 0, 1],
        ],
    ],
    dtype=torch.int16,
)
# Per-axis translation (index 0 = X, 1 = Y, 2 = Z) that build_action_tensor scales by
# (size - 1) and adds after applying POS_ROTATIONS[axis]. The single 1 in each row sits on
# the coordinate that the matching rotation negates, so the negated value c becomes
# (size - 1) - c.
POS_SHIFTS = torch.tensor(
    (
        (0, 0, 0, 1),  # X turn negates the last coordinate
        (0, 1, 0, 0),  # Y turn negates coordinate 1
        (0, 0, 1, 0),  # Z turn negates coordinate 2
    ),
    dtype=torch.int16,
)
# Face relabelling that accompanies a turn about each axis (index 0 = X, 1 = Y, 2 = Z).
# Each 4-character string lists the face indices of one cycle, in visiting order:
#   rotation about X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
#   rotation about Y axis: 0 (Up) -> 1 (Left) -> 5 (Down) -> 3 (Right) -> 0 (Up)
#   rotation about Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
# NOTE(review): each matrix is presumably 6x6 and int16-compatible, since
# build_action_tensor right-multiplies a one-hot (n, 6) int16 face encoding by it -- confirm
# against build_permutation_matrix.
FACE_ROTATIONS = torch.stack(
    [build_permutation_matrix(size=6, perm=cycle) for cycle in ("0254", "0153", "1234")]
)
def build_actions_tensor(size: int) -> torch.Tensor:
    """
    Build the 5D tensor carrying all rotations of a cube as matrix multiplication.

    One sparse tensor is built per (axis, slice, inverse) triple, for 3 axes, `size` slices
    and 2 orientations. Each of them has shape (3, size, 2, 6 * size**2, 6 * size**2) and only
    populates the slab addressed by its own triple, so summing them over the stacking
    dimension merges all moves into a single tensor of that same shape.
    """
    per_move = [
        build_action_tensor(size=size, axis=axis_id, slice=layer, inverse=flag)
        for axis_id in range(3)
        for layer in range(size)
        for flag in range(2)
    ]
    stacked = torch.stack(per_move, dim=0)
    return stacked.sum(dim=0, dtype=torch.int16)
def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
    """
    Build the sparse permutation tensor for a single move of the cube.

    The move is the rotation about `axis` (0, 1 or 2) of the layer whose coordinate along
    that axis equals `slice`, in the direct (`inverse` = 0) or opposite (`inverse` = 1)
    orientation.

    The result has shape (3, size, 2, L, L) with L = 6 * size**2 and holds exactly L ones, all
    located in the slab [axis, slice, inverse]. Within that slab, row i has its single 1 in
    column j, where j is the stored cube entry whose index vector equals the image of entry i
    under the rotation (its pre-image when `inverse` is set). Entries outside the rotated
    layer are mapped to themselves.
    """
    num_cells = 6 * (size**2)
    cube = build_cube_tensor(colors=list("ULCRBD"), size=size)

    # Index vectors of the stored cube entries, one per column; row `axis + 1` holds the
    # coordinate along the rotation axis, so comparing it to `slice` selects the moved layer.
    coords = cube.indices().to(dtype=torch.int16)  # size = (4, num_cells)
    moved = (coords[axis + 1] == slice).nonzero().reshape(-1)  # size = (n,), n < num_cells
    before = coords[:, moved]  # size = (4, n)

    # Rotate the coordinates, then translate the negated one back into [0, size - 1].
    shift = POS_SHIFTS[axis] * (size - 1)  # size = (4,)
    after = POS_ROTATIONS[axis] @ before + shift.unsqueeze(-1)  # size = (4, n)

    # Relabel the face index (row 0) through the face permutation matrix of this axis.
    face_one_hot = F.one_hot(after[0].long(), num_classes=6).to(torch.int16)  # size = (n, 6)
    after[0] = (face_one_hot @ FACE_ROTATIONS[axis]).argmax(dim=-1)

    # The opposite orientation simply swaps the roles of the two coordinate sets.
    if inverse:
        sources, targets = after, before
    else:
        sources, targets = before, after
    sources = sources.t().tolist()  # size = (n, 4)
    targets = targets.t().tolist()  # size = (n, 4)

    # Permutation restricted to the moved layer: for each local entry, the local entry whose
    # source coordinates match its target coordinates.
    local_perm = [sources.index(cell) for cell in targets]

    # Lift the local permutation to all entries; anything outside the layer stays put.
    global_ids = moved.tolist()
    local_of = {global_id: local_id for local_id, global_id in enumerate(global_ids)}
    permuted = []
    for cell_id in range(num_cells):
        local_id = local_of.get(cell_id)
        permuted.append(cell_id if local_id is None else global_ids[local_perm[local_id]])

    # Encode the permutation as a sparse tensor living in the (axis, slice, inverse) slab.
    perm_indices = torch.tensor(
        [
            [axis] * num_cells,
            [slice] * num_cells,
            [inverse] * num_cells,
            list(range(num_cells)),
            permuted,
        ],
        dtype=torch.int16,
    )
    perm_values = torch.ones(num_cells, dtype=torch.int16)
    return torch.sparse_coo_tensor(
        indices=perm_indices,
        values=perm_values,
        size=(3, size, 2, num_cells, num_cells),
        dtype=torch.int16,
    )
def parse_action_str(move: str) -> tuple[int, int, int]:
    """
    Convert the name of an action into a triple (axis, slice, inverse).

    An action name is an axis letter among 'X', 'Y', 'Z', followed by the decimal index of
    the slice, optionally followed by an inverse marker (conventionally 'i'). Any non-empty
    text after the digits is treated as the inverse marker.

    Examples:
        'X1' -> (0, 1, 0)
        'X2i' -> (0, 2, 1)
        'Y01' -> (1, 1, 0)

    Raises:
        ValueError: if `move` does not start with an axis letter followed by at least one digit.
    """
    parsed = re.match(r"([XYZ])(\d+)", move)
    if parsed is None:
        raise ValueError(f"invalid action {move!r}: expected an axis in 'XYZ' followed by a slice index")
    axis_name, digits = parsed.groups()
    # Measure the suffix from where the digits actually end: re-rendering the parsed integer
    # would drop leading zeros and make a zero-padded slice such as 'Y01' look inverted.
    suffix = move[parsed.end():]
    return ("XYZ".index(axis_name), int(digits), int(bool(suffix)))
def parse_actions_str(moves: str) -> list[tuple[int, int, int]]:
    """
    Convert a whitespace-separated sequence of action names into a list of
    (axis, slice, inverse) triples, in the order the actions appear.

    Examples:
        'X1 X2i' -> [(0, 1, 0), (0, 2, 1)]
        '' -> []
    """
    # split() without arguments already discards leading and trailing whitespace.
    tokens = moves.split()
    return [parse_action_str(token) for token in tokens]
def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
    """
    Generate a space-separated string of `num_moves` randomly sampled moves.

    Each move is an axis letter ('X', 'Y' or 'Z'), a slice index in [0, size - 1] and an
    optional 'i' suffix, each drawn independently from a generator seeded with `seed`; the
    same arguments therefore always give the same string.
    """
    generator = np.random.default_rng(seed=seed)
    # The draw order (axes, then slices, then orientations) determines the seeded output.
    axis_draws = generator.choice(["X", "Y", "Z"], size=num_moves)
    slice_draws = generator.choice(list(map(str, range(size))), size=num_moves)
    orient_draws = generator.choice(["", "i"], size=num_moves)
    moves = [
        f"{axis_name}{layer}{suffix}"
        for axis_name, layer, suffix in zip(axis_draws, slice_draws, orient_draws)
    ]
    return " ".join(moves)