Spaces:
Sleeping
Sleeping
File size: 5,904 Bytes
ff5d990 8f96832 68f8c07 27be29d 53d3965 ff5d990 4145e1a ff5d990 4145e1a ff5d990 cb37bd4 ff5d990 8f96832 ff5d990 8f96832 68f8c07 27be29d 68f8c07 53d3965 27be29d 53d3965 27be29d 53d3965 27be29d 53d3965 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import pytest
from typing import Iterable
import torch
from rubik.action import (
POS_ROTATIONS,
POS_SHIFTS,
FACE_ROTATIONS,
build_actions_tensor,
build_action_permutation,
parse_action_str,
parse_actions_str,
sample_actions_str,
)
def test_position_rotation_shape():
    """
    Check that POS_ROTATIONS stacks one 4x4 matrix per rotation axis (3 axes).

    The 4-component position vectors used below carry a leading 1, presumably
    homogeneous coordinates — confirm against rubik.action.
    """
    want_shape = (3, 4, 4)
    got_shape = POS_ROTATIONS.shape
    assert want_shape == got_shape, f"Position rotation tensor expected shape '{want_shape}', got '{got_shape}' instead"
@pytest.mark.parametrize(
    "axis, input, expected",
    [
        # axis 0: rotation about X
        (0, (1, 1, 0, 0), (1, 1, 0, 0)),  # X -> X
        (0, (1, 0, 1, 0), (1, 0, 0, -1)),  # Y -> -Z
        (0, (1, 0, 0, 1), (1, 0, 1, 0)),  # Z -> Y
        # axis 1: rotation about Y
        (1, (1, 1, 0, 0), (1, 0, 0, 1)),  # X -> Z
        (1, (1, 0, 1, 0), (1, 0, 1, 0)),  # Y -> Y
        (1, (1, 0, 0, 1), (1, -1, 0, 0)),  # Z -> -X
        # axis 2: rotation about Z
        (2, (1, 1, 0, 0), (1, 0, -1, 0)),  # X -> -Y
        (2, (1, 0, 1, 0), (1, 1, 0, 0)),  # Y -> X
        (2, (1, 0, 0, 1), (1, 0, 0, 1)),  # Z -> Z
    ],
)
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Check that applying POS_ROTATIONS[axis] to a unit position vector
    (column-vector convention: matrix @ vector) maps each basis direction
    to the one listed in the parametrization.
    """
    dtype = POS_ROTATIONS.dtype
    vector = torch.tensor(input, dtype=dtype)
    rotated = POS_ROTATIONS[axis].matmul(vector)
    target = torch.tensor(expected, dtype=dtype)
    assert torch.equal(rotated, target), f"Position rotation tensor is incorrect along axis {axis}: {rotated} != {target}"
@pytest.mark.parametrize(
    "axis, size, input, expected",
    [
        (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
        (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
        (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
    ],
)
def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
    """
    Check that rotating a scaled position with POS_ROTATIONS[axis] and then
    adding POS_SHIFTS[axis] (both scaled by ``size - 1``) lands on the
    expected scaled position.
    """
    scale = size - 1
    dtype = POS_ROTATIONS.dtype
    corner = torch.tensor(input, dtype=dtype) * scale
    # Rotate first, then translate by the axis-specific shift.
    shifted = POS_ROTATIONS[axis] @ corner + POS_SHIFTS[axis] * scale
    target = torch.tensor(expected, dtype=dtype) * scale
    assert torch.equal(shifted, target), f"Position shift tensor is incorrect along axis {axis}: {shifted} != {target}"
def test_face_rotation_shape():
    """
    Check that FACE_ROTATIONS holds one 6x6 matrix for each of the 3 axes.
    """
    shape_expected = (3, 6, 6)
    shape_observed = FACE_ROTATIONS.shape
    assert shape_expected == shape_observed, f"Face rotation tensor expected shape '{shape_expected}', got '{shape_observed}' instead"
@pytest.mark.parametrize(
    "axis, input, expected",
    [
        (0, (1, 0, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)),  # rotation about X axis: 0 (Up) -> 2 (Front)
        (1, (1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0)),  # rotation about Y axis: 0 (Up) -> 1 (Left)
        (2, (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)),  # rotation about Z axis: 1 (Left) -> 2 (Front)
    ],
)
def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Test that FACE_ROTATIONS behaves as expected.

    Each case is a one-hot face indicator. Multiplying it by the axis' 6x6
    matrix must yield the one-hot indicator of the face it is carried to.

    Note the row-vector convention (``v @ FACE_ROTATIONS[axis]``), in contrast
    to the column-vector convention (``POS_ROTATIONS[axis] @ v``) used for
    positions.
    """
    # Row vector on the left: this is how FACE_ROTATIONS is meant to be applied.
    out = torch.tensor(input, dtype=FACE_ROTATIONS.dtype) @ FACE_ROTATIONS[axis]
    exp = torch.tensor(expected, dtype=FACE_ROTATIONS.dtype)
    assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"
@pytest.mark.parametrize("size", [2, 3, 5, 20])
def test_build_actions_tensor_shape(size: int):
    """
    Check the shape of the tensor returned by "build_actions_tensor".

    Expected layout is ``(3, size, 2, 6 * size**2)``. This is presumably
    indexed by (axis, slice, inverse flag) and holds a permutation over the
    ``6 * size**2`` facelets; confirm against rubik.action.
    """
    num_facelets = 6 * size * size
    expected = (3, size, 2, num_facelets)
    observed = build_actions_tensor(size).shape
    assert expected == observed, (
        f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
    )
@pytest.mark.parametrize(
    "size, axis, slice_idx, inverse",
    [
        (2, 2, 1, 0),
        (3, 0, 1, 1),
        (5, 1, 4, 0),
    ],
)
def test_build_action_permutation(size: int, axis: int, slice_idx: int, inverse: int):
    """
    Test that "build_action_permutation" output has the expected length.

    A single action permutation must have ``6 * size**2`` entries (presumably
    one per facelet: 6 faces of ``size x size`` stickers).
    """
    expected = 6 * (size**2)
    # ``slice_idx`` (not ``slice``) avoids shadowing the builtin; it is passed positionally.
    observed = len(build_action_permutation(size, axis, slice_idx, inverse))
    assert expected == observed, (
        f"'build_action_permutation' output has incorrect length: expected length '{expected}', got '{observed}'"
    )
@pytest.mark.parametrize(
    "move, expected",
    [
        ("X1", (0, 1, 0)),
        ("X25i", (0, 25, 1)),
        ("Y0", (1, 0, 0)),
        ("Y5i", (1, 5, 1)),
        ("Z30", (2, 30, 0)),
        ("Z512ijk", (2, 512, 1)),
    ],
)
def test_parse_action_str(move: str, expected: tuple[int, int, int]):
    """
    Check that "parse_action_str" turns one move token into an integer triple.

    The letter X/Y/Z selects axis 0/1/2 and the number is the slice index.
    A trailing suffix starting with 'i' (e.g. "i" or "ijk") sets the last
    element to 1; otherwise it is 0. The last element is presumably the
    inverse flag.
    """
    parsed = parse_action_str(move)
    assert expected == parsed, (
        f"'parse_action_str' output is incorrect: expected '{expected}', got '{parsed}' instead"
    )
@pytest.mark.parametrize(
    "moves, expected",
    [
        [" X1 Y0 X25i Z512ijk Z30 Y5i ", [(0, 1, 0), (1, 0, 0), (0, 25, 1), (2, 512, 1), (2, 30, 0), (1, 5, 1)]],
    ],
)
def test_parse_actions_str(moves: str, expected: list[tuple[int, int, int]]):
    """
    Test that "parse_actions_str" behaves as expected.

    A whitespace-separated move string is split into tokens and each token is
    parsed in order. Leading, trailing and repeated spaces must be tolerated.
    The result is a list with one triple per move.
    """
    observed = parse_actions_str(moves)
    assert expected == observed, (
        f"'parse_actions_str' output is incorrect: expected '{expected}', got '{observed}' instead"
    )
@pytest.mark.parametrize(
    "num_moves, size, seed",
    [
        (1, 3, 0),
        (1, 20, 42),
        (256, 5, 21),
    ],
)
def test_sample_actions_str(num_moves: int, size: int, seed: int):
    """
    Check two properties of "sample_actions_str":

    - It is reproducible for a fixed seed.
    - Every whitespace-separated token it emits is accepted by "parse_actions_str".
    """
    first, second = (sample_actions_str(num_moves, size, seed) for _ in range(2))
    assert first == second, f"'sample_actions_str' is non-deterministic: {first} != {second}"
    # One parsed action per emitted token means nothing was dropped or merged.
    parsed = parse_actions_str(first)
    token_count = len(first.split())
    assert len(parsed) == token_count, "'sample_actions_str' output cannot be parsed correctly"
|