"""Tensor encodings of cube moves, plus helpers to parse and sample move strings."""

import re

import numpy as np
import torch
import torch.nn.functional as F

from rubik.tensor_utils import build_permutation_matrix, build_cube_tensor


# One 4x4 integer matrix per axis (index 0 = X, 1 = Y, 2 = Z), applied on the left
# of (4, n) index columns. Row 0 leaves the first coordinate untouched; rows 1..3
# swap two of the remaining coordinates and negate one of them.
POS_ROTATIONS = torch.stack(
    [
        # rot about X: Z -> Y
        torch.tensor(
            [
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1],
                [0, 0, -1, 0],
            ],
            dtype=torch.int16,
        ),
        # rot about Y: X -> Z
        torch.tensor(
            [
                [1, 0, 0, 0],
                [0, 0, 0, -1],
                [0, 0, 1, 0],
                [0, 1, 0, 0],
            ],
            dtype=torch.int16,
        ),
        # rot about Z: Y -> X
        torch.tensor(
            [
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, -1, 0, 0],
                [0, 0, 0, 1],
            ],
            dtype=torch.int16,
        ),
    ]
)

# One row per axis: marks the coordinate that the matching rotation matrix negates.
# Scaled by (size - 1) and added after the rotation, it brings that coordinate back
# into the range [0, size - 1].
POS_SHIFTS = torch.tensor(
    [
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
    ],
    dtype=torch.int16,
)


# rotation about X axis: 0 (Up) -> 2 (Front) -> 5 (Down) -> 4 (Back) -> 0 (Up)
# rotation about Y axis: 0 (Up) -> 1 (Left) -> 5 (Down) -> 3 (Right) -> 0 (Up)
# rotation about Z axis: 1 (Left) -> 2 (Front) -> 3 (Right) -> 4 (Back) -> 1 (Left)
# NOTE(review): presumably each entry is a 6x6 permutation matrix encoding the cycle
# written above -- confirm against rubik.tensor_utils.build_permutation_matrix.
FACE_ROTATIONS = torch.stack(
    [
        build_permutation_matrix(size=6, perm="0254"),
        build_permutation_matrix(size=6, perm="0153"),
        build_permutation_matrix(size=6, perm="1234"),
    ]
)


def build_actions_tensor(size: int) -> torch.Tensor:
    """
    Build the 5D tensor carrying all rotations of a cube as matrix multiplication.

    One sparse tensor is built per (axis, slice, inverse) triple, with axis in
    range(3), slice in range(size) and inverse in range(2); they are stacked and
    summed into a single int16 tensor of size (3, size, 2, length, length), where
    length = 6 * size**2.
    """
    per_move = [
        build_action_tensor(size=size, axis=axis, slice=slice_index, inverse=inverse)
        for axis in range(3)
        for slice_index in range(size)
        for inverse in range(2)
    ]
    return torch.stack(per_move, dim=0).sum(dim=0, dtype=torch.int16)


def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
    """
    Compute the sparse permutation tensor whose effect on a position-frozen color vector is the rotation
    along the specified axis, within the specified slice and the specified orientation.

    Args:
        size: number of facelets along one edge of the cube.
        axis: rotation axis, 0 (X), 1 (Y) or 2 (Z).
        slice: coordinate, along `axis`, of the layer being rotated.
        inverse: falsy for the direct rotation, truthy for the opposite orientation.

    Returns:
        A sparse COO int16 tensor of size (3, size, 2, length, length), with length = 6 * size**2,
        whose only non-zero entries lie at [axis, slice, inverse, i, perm(i)].

    Raises:
        ValueError: if a rotated position does not coincide with any position of the moved slice.
    """
    tensor = build_cube_tensor(colors=list("ULCRBD"), size=size)
    length = 6 * (size**2)

    # extract faces impacted by the move
    # NOTE(review): assumes build_cube_tensor returns a sparse tensor whose indices are
    # (face, x, y, z) columns -- confirm against rubik.tensor_utils.
    indices = tensor.indices().to(dtype=torch.int16)  # expected size = (4, length)
    changes = (indices[axis + 1] == slice).nonzero().reshape(-1)  # size = (n,), n < length
    extract = indices[:, changes]  # size = (4, n)

    # apply coordinate rotation, then shift the negated coordinate back into [0, size - 1]
    rotated = POS_ROTATIONS[axis] @ extract  # size = (4, n)
    offsets = (POS_SHIFTS[axis] * (size - 1)).repeat(extract.shape[-1], 1).transpose(0, 1)  # size = (4, n)
    rotated = rotated + offsets  # size = (4, n)

    # apply face rotation: one-hot encode the face index, permute it, decode it back
    rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)

    # from this point on, convert rotation into a position-based permutation of colors
    (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
    inputs = inputs.transpose(0, 1).tolist()  # size = (n, 4)
    outputs = outputs.transpose(0, 1).tolist()  # size = (n, 4)

    # compute position-based permutation of colors equivalent to rotation converting inputs into outputs
    local_to_total = dict(enumerate(changes.tolist()))
    total_to_local = {ind: i for i, ind in local_to_total.items()}

    # A dict lookup replaces a per-element list scan; setdefault keeps the first
    # occurrence, which is what list.index would have returned.
    first_index: dict[tuple[int, ...], int] = {}
    for local, position in enumerate(inputs):
        first_index.setdefault(tuple(position), local)
    try:
        local_perm = {local: first_index[tuple(position)] for local, position in enumerate(outputs)}
    except KeyError as err:
        raise ValueError(f"rotated position {err.args[0]} does not belong to the moved slice") from err

    # positions outside the moved slice are mapped onto themselves
    total_perm = {
        i: (local_to_total[local_perm[total_to_local[i]]] if i in total_to_local else i) for i in range(length)
    }

    # convert permutation dict into sparse tensor
    # Sparse COO indices are stored as int64; building them as int64 directly avoids
    # the int16 overflow that occurs once length exceeds 32767.
    perm_indices = torch.tensor(
        [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
        dtype=torch.long,
    )
    perm_values = torch.ones(length, dtype=torch.int16)
    perm_size = (3, size, 2, length, length)
    return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)


def parse_action_str(move: str) -> tuple[int, int, int]:
    """
    Convert the name of an action into a triple (axis, slice, inverse).

    The first character selects the axis ('X', 'Y' or 'Z'), the digits that follow give
    the slice, and any trailing character after the digits marks the inverse move.

    Examples:
        'X1' -> (0, 1, 0)
        'X2i' -> (0, 2, 1)
        'X01' -> (0, 1, 0)

    Raises:
        ValueError: if the first character is not one of 'X', 'Y', 'Z'.
        IndexError: if `move` is empty or no digit follows the axis letter.
    """
    axis = "XYZ".index(move[0])
    digits = re.findall(r"^\d+", move[1:])[0]
    # Compare against the matched digit run, not str(int(digits)): zero-padded
    # slices such as 'X01' must not be mistaken for inverse moves.
    inverse = int(len(move) > (1 + len(digits)))
    return (axis, int(digits), inverse)


def parse_actions_str(moves: str) -> list[tuple[int, int, int]]:
    """
    Convert a sequence of actions in a string into a list of triples (axis, slice, inverse).

    Moves are separated by whitespace; leading and trailing whitespace is ignored.

    Examples:
        'X1 X2i' -> [(0, 1, 0), (0, 2, 1)]
    """
    return [parse_action_str(move) for move in moves.strip().split()]


def sample_actions_str(num_moves: int, size: int, seed: int = 0) -> str:
    """
    Generate a string containing moves that are randomly sampled.

    Args:
        num_moves: number of moves to draw.
        size: number of slices per axis; slices are drawn from range(size).
        seed: seed of the NumPy random generator, so a given seed always yields the same string.

    Returns:
        The sampled moves, separated by single spaces, in the format read by `parse_actions_str`.
    """
    rng = np.random.default_rng(seed=seed)
    # The three draws are made in this fixed order so that seeded output stays stable.
    axes = rng.choice(["X", "Y", "Z"], size=num_moves)
    slices = rng.choice([str(i) for i in range(size)], size=num_moves)
    orients = rng.choice(["", "i"], size=num_moves)
    return " ".join("".join(move) for move in zip(axes, slices, orients))