Spaces:
Sleeping
Sleeping
File size: 5,351 Bytes
3e056bc 580f78c 569ff0b 580f78c 569ff0b 7011a7d 569ff0b 7011a7d 569ff0b 7011a7d 569ff0b 7583934 569ff0b 7011a7d 569ff0b 580f78c 569ff0b 580f78c 569ff0b 580f78c 569ff0b 580f78c 569ff0b 7011a7d 569ff0b 580f78c 569ff0b 580f78c 569ff0b 7011a7d 580f78c 569ff0b 7583934 569ff0b 7011a7d 569ff0b 5e5c08a 580f78c d756230 569ff0b d756230 569ff0b 5e5c08a 569ff0b d756230 7011a7d 569ff0b 7011a7d 569ff0b 7011a7d 580f78c 3e056bc 580f78c 3e056bc b24b865 3e056bc b24b865 580f78c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import re
import numpy as np
import torch
import torch.nn.functional as F
from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor
# One 4x4 integer map per rotation axis (0 = X, 1 = Y, 2 = Z), applied to columns of
# cube-tensor indices laid out as (face, x, y, z). Row/column 0 keeps the face index
# unchanged: faces are permuted separately by FACE_ROTATIONS. Each matrix negates one
# coordinate; build_action_tensor re-centers it by adding POS_SHIFTS[axis] * (size - 1).
POS_ROTATIONS = torch.tensor(
    [
        # about X: (x, y, z) -> (x, z, -y), i.e. Z -> Y
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]],
        # about Y: (x, y, z) -> (-z, y, x), i.e. X -> Z
        [[1, 0, 0, 0], [0, 0, 0, -1], [0, 0, 1, 0], [0, 1, 0, 0]],
        # about Z: (x, y, z) -> (y, -x, z), i.e. Y -> X
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
    ],
    dtype=torch.int16,
)
# Per-axis shift direction, scaled by (size - 1) in build_action_tensor. Each row is
# one-hot on the coordinate that POS_ROTATIONS[axis] negates (z for X, x for Y, y for Z),
# so that coordinate c becomes size - 1 - c instead of -c.
POS_SHIFTS = torch.eye(4, dtype=torch.int16)[[3, 1, 2]]
# Face permutations (6 faces) that accompany each axis rotation; each perm string spells
# the 4-cycle of face indices moved by that rotation:
#   X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
#   Y axis: 0 (Up) -> 1 (Left)  -> 5 (Down) -> 3 (Right) -> 0 (Up)
#   Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
_FACE_CYCLES = ("0254", "0153", "1234")
FACE_ROTATIONS = torch.stack([build_permutation_matrix(size=6, perm=cycle) for cycle in _FACE_CYCLES])
def build_actions_tensor(size: int) -> torch.Tensor:
    """
    Build the 5D sparse tensor holding every move of a cube as a permutation matrix.

    One action tensor is produced per (axis, slice, inverse) triple with
    build_action_tensor; they are stacked along a new leading dimension and summed
    over it, yielding a single int16 tensor of shape (3, size, 2, L, L).
    """
    per_action = [
        build_action_tensor(size=size, axis=axis, slice=layer, inverse=inverse)
        for axis in range(3)
        for layer in range(size)
        for inverse in (0, 1)
    ]
    stacked = torch.stack(per_action, dim=0)
    return stacked.sum(dim=0, dtype=torch.int16)
def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
    """
    Compute the sparse permutation tensor of a single move.

    The effect of the returned permutation on a position-frozen color vector is the
    rotation along the specified axis, within the specified slice and the specified
    orientation.

    Args:
        size: number of slices per axis of the cube.
        axis: rotation axis, indexing POS_ROTATIONS / POS_SHIFTS / FACE_ROTATIONS
            (0 = X, 1 = Y, 2 = Z).
        slice: index of the rotated slice along `axis` (shadows the builtin, kept for
            keyword callers).
        inverse: 0 for the direct rotation, non-zero for its inverse.

    Returns:
        An int16 sparse COO tensor of shape (3, size, 2, L, L) with L = 6 * size**2.
        Its only non-zero entries lie in the [axis, slice, inverse] sub-matrix, which
        has a single 1 in every row. Facelets outside the slice map to themselves.
        For inverse=0, row i has its 1 in the column of the slot facelet i is carried
        to; inverse=1 swaps the roles of original and rotated positions.

    Raises:
        ValueError: if a rotated position matches no facelet of the slice, which
            would mean the rotation constants are inconsistent with the cube layout.
    """
    tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
    length = 6 * (size**2)
    # extract facelets impacted by the move
    indices = tensor.indices().to(dtype=torch.int16)  # size = (4, length)
    changes = (indices[axis + 1] == slice).nonzero().reshape(-1)  # size = (n,), n < length
    extract = indices[:, changes]  # size = (4, n)
    # apply coordinate rotation, then shift the negated coordinate c -> size - 1 - c;
    # the (4, 1) offset column broadcasts over the n facelets
    rotated = POS_ROTATIONS[axis] @ extract  # size = (4, n)
    rotated = rotated + (POS_SHIFTS[axis] * (size - 1)).unsqueeze(-1)  # size = (4, n)
    # apply face rotation: send each face index through this axis' face permutation
    rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
    # from this point on, convert rotation into a position-based permutation of colors
    (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
    inputs = inputs.transpose(0, 1).tolist()  # size = (n, 4)
    outputs = outputs.transpose(0, 1).tolist()  # size = (n, 4)
    # index every input position once (first occurrence wins, as with list.index) so the
    # matching below is a dict lookup instead of an O(n) scan per facelet
    position_to_local: dict[tuple[int, ...], int] = {}
    for local, coords in enumerate(inputs):
        position_to_local.setdefault(tuple(coords), local)
    unmatched = [coords for coords in outputs if tuple(coords) not in position_to_local]
    if unmatched:
        raise ValueError(
            f"rotated positions {unmatched} match no facelet of slice {slice} on axis {axis}"
        )
    local_perm = [position_to_local[tuple(coords)] for coords in outputs]
    # lift the local permutation to all facelets; untouched facelets map to themselves
    local_to_total = changes.tolist()
    total_to_local = {total: local for local, total in enumerate(local_to_total)}
    total_perm = [
        local_to_total[local_perm[total_to_local[i]]] if i in total_to_local else i for i in range(length)
    ]
    # convert the permutation into a sparse tensor; indices are int64 (what
    # sparse_coo_tensor uses internally) because int16 overflows once length > 32767
    perm_indices = torch.tensor(
        [[axis] * length, [slice] * length, [inverse] * length, list(range(length)), total_perm],
        dtype=torch.long,
    )
    perm_values = torch.ones(length, dtype=torch.int16)
    perm_size = (3, size, 2, length, length)
    return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
def parse_action_str(move: str) -> tuple[int, int, int]:
    """
    Convert the name of an action into a triple (axis, slice, inverse).

    An action is an axis letter among 'X', 'Y', 'Z', followed by the slice index in
    decimal digits, optionally followed by a suffix (conventionally 'i') marking the
    inverse rotation. Any non-empty suffix counts as inverse.

    Examples:
        'X1' -> (0, 1, 0)
        'X2i' -> (0, 2, 1)
        'Z01' -> (2, 1, 0)

    Raises:
        ValueError: if `move` does not start with an axis letter followed by digits.
    """
    match = re.match(r"([XYZ])(\d+)", move)
    if match is None:
        raise ValueError(
            f"Invalid action {move!r}: expected an axis among 'XYZ' followed by a slice index, e.g. 'X1' or 'X2i'"
        )
    axis = "XYZ".index(match.group(1))
    layer = int(match.group(2))
    # compare against the end of the matched digits rather than len(str(layer)),
    # which miscounts zero-padded indices such as 'X01' as carrying a suffix
    inverse = int(match.end() < len(move))
    return (axis, layer, inverse)
def parse_actions_str(moves: str) -> list[tuple[int, int, int]]:
    """
    Parse a whitespace-separated sequence of actions.

    Every token is handed to parse_action_str, so the result holds one
    (axis, slice, inverse) triple per token, in input order.

    Examples:
        'X1 X2i' -> [(0, 1, 0), (0, 2, 1)]
    """
    # split() without arguments already discards leading/trailing whitespace
    tokens = moves.split()
    return list(map(parse_action_str, tokens))
def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
    """
    Draw `num_moves` random actions and return them as one space-separated string.

    Each action is an axis letter ('X', 'Y' or 'Z'), a slice index in range(size) and
    an optional 'i' suffix for the inverse rotation, e.g. 'X1 Z0i'. The draws come
    from a NumPy Generator seeded with `seed`, so equal arguments give equal strings.
    """
    generator = np.random.default_rng(seed=seed)
    slice_names = list(map(str, range(size)))
    # keep the draw order (axes, slices, suffixes): each call consumes the same stream
    axis_draws = generator.choice(["X", "Y", "Z"], size=num_moves)
    slice_draws = generator.choice(slice_names, size=num_moves)
    suffix_draws = generator.choice(["", "i"], size=num_moves)
    tokens = [f"{axis}{layer}{suffix}" for axis, layer, suffix in zip(axis_draws, slice_draws, suffix_draws)]
    return " ".join(tokens)
|