File size: 5,904 Bytes
ff5d990
 
 
 
 
 
 
8f96832
68f8c07
 
27be29d
53d3965
 
 
ff5d990
 
 
 
4145e1a
 
 
ff5d990
 
4145e1a
ff5d990
 
 
 
 
cb37bd4
 
 
 
 
 
 
 
 
ff5d990
 
 
8f96832
 
 
ff5d990
 
 
8f96832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f8c07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27be29d
68f8c07
 
 
 
53d3965
 
 
 
 
 
 
 
 
 
27be29d
53d3965
 
 
27be29d
 
53d3965
27be29d
53d3965
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import pytest
from typing import Iterable

import torch

from rubik.action import (
    POS_ROTATIONS,
    POS_SHIFTS,
    FACE_ROTATIONS,
    build_actions_tensor,
    build_action_permutation,
    parse_action_str,
    parse_actions_str,
    sample_actions_str,
)


def test_position_rotation_shape():
    """
    Check that the POS_ROTATIONS tensor has shape (3, 4, 4).
    """
    observed, expected = POS_ROTATIONS.shape, (3, 4, 4)
    shapes_match = expected == observed
    assert shapes_match, f"Position rotation tensor expected shape '{expected}', got '{observed}' instead"


@pytest.mark.parametrize(
    "axis, input, expected",
    [
        (0, (1, 1, 0, 0), (1, 1, 0, 0)),  # X -> X
        (0, (1, 0, 1, 0), (1, 0, 0, -1)),  # Y -> -Z
        (0, (1, 0, 0, 1), (1, 0, 1, 0)),  # Z -> Y
        (1, (1, 1, 0, 0), (1, 0, 0, 1)),  # X -> Z
        (1, (1, 0, 1, 0), (1, 0, 1, 0)),  # Y -> Y
        (1, (1, 0, 0, 1), (1, -1, 0, 0)),  # Z -> -X
        (2, (1, 1, 0, 0), (1, 0, -1, 0)),  # X -> -Y
        (2, (1, 0, 1, 0), (1, 1, 0, 0)),  # Y -> X
        (2, (1, 0, 0, 1), (1, 0, 0, 1)),  # Z -> Z
    ],
)
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Check POS_ROTATIONS[axis] against a table of 4-component input / output vectors.

    Each case multiplies the matrix selected by `axis` with the input vector and
    compares the product to the expected vector, both built with the dtype of
    POS_ROTATIONS.
    """
    dtype = POS_ROTATIONS.dtype
    vector = torch.tensor(input, dtype=dtype)
    exp = torch.tensor(expected, dtype=dtype)
    out = torch.matmul(POS_ROTATIONS[axis], vector)
    is_equal = torch.equal(out, exp)
    assert is_equal, f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"


@pytest.mark.parametrize(
    "axis, size, input, expected",
    [
        (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
        (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
        (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
    ],
)
def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
    """
    Check that POS_SHIFTS[axis] combined with POS_ROTATIONS[axis] gives the expected vector.

    The input vector is scaled by `size - 1`, multiplied by POS_ROTATIONS[axis], and
    then offset by POS_SHIFTS[axis] scaled by the same factor. The result must equal
    the expected vector scaled by `size - 1`.
    """
    dtype = POS_ROTATIONS.dtype
    scale = size - 1  # common factor applied to input, shift and expectation

    scaled_input = torch.tensor(input, dtype=dtype) * scale
    rot = torch.matmul(POS_ROTATIONS[axis], scaled_input)
    shift = POS_SHIFTS[axis] * scale

    out = rot + shift
    exp = torch.tensor(expected, dtype=dtype) * scale
    is_equal = torch.equal(out, exp)
    assert is_equal, f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"


def test_face_rotation_shape():
    """
    Check that the FACE_ROTATIONS tensor has shape (3, 6, 6).
    """
    observed, expected = FACE_ROTATIONS.shape, (3, 6, 6)
    shapes_match = expected == observed
    assert shapes_match, f"Face rotation tensor expected shape '{expected}', got '{observed}' instead"


@pytest.mark.parametrize(
    "axis, input, expected",
    [
        (0, (1, 0, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)),  # rotation about X axis: 0 (Up) -> 2 (Front)
        (1, (1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0)),  # rotation about Y axis: 0 (Up) -> 1 (Left)
        (2, (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)),  # rotation about Z axis: 1 (Left) -> 2 (Front)
    ],
)
def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Test that FACE_ROTATIONS behaves as expected.

    Each case is a one-hot 6-component vector selecting a face. It is applied as a
    row vector on the left of FACE_ROTATIONS[axis], and the product must equal the
    expected one-hot vector.
    """
    dtype = FACE_ROTATIONS.dtype
    # Row-vector convention: the input goes on the left of the matrix.
    out = torch.tensor(input, dtype=dtype) @ FACE_ROTATIONS[axis]
    exp = torch.tensor(expected, dtype=dtype)
    assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"


@pytest.mark.parametrize("size", [2, 3, 5, 20])
def test_build_actions_tensor_shape(size: int):
    """
    Check that "build_actions_tensor(size)" returns a tensor of shape
    (3, size, 2, 6 * size**2).
    """
    actions = build_actions_tensor(size)
    last_dim = 6 * (size**2)
    expected = (3, size, 2, last_dim)
    observed = actions.shape
    shapes_match = expected == observed
    assert shapes_match, (
        f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
    )


@pytest.mark.parametrize(
    "size, axis, slice, inverse",
    [
        (2, 2, 1, 0),
        (3, 0, 1, 1),
        (5, 1, 4, 0),
    ],
)
def test_build_action_permutation(size: int, axis: int, slice: int, inverse: int):
    """
    Test that "build_action_permutation" output has expected length.

    For every parametrized (size, axis, slice, inverse) combination, the returned
    object must have length 6 * size**2.
    """
    expected = 6 * (size**2)
    observed = len(build_action_permutation(size, axis, slice, inverse))
    # The message names the function actually under test, so a failure points at the right place.
    assert expected == observed, (
        f"'build_action_permutation' output has incorrect length: expected length '{expected}', got '{observed}'"
    )


@pytest.mark.parametrize(
    "move, expected",
    [
        ["X1", (0, 1, 0)],
        ["X25i", (0, 25, 1)],
        ["Y0", (1, 0, 0)],
        ["Y5i", (1, 5, 1)],
        ["Z30", (2, 30, 0)],
        ["Z512ijk", (2, 512, 1)],
    ],
)
def test_parse_action_str(move: str, expected: tuple[int, int, int]):
    """
    Check that "parse_action_str" turns a single move string into the expected
    triple of integers.
    """
    observed = parse_action_str(move)
    is_match = expected == observed
    assert is_match, (
        f"'parse_action_str' output is incorrect: expected '{expected}', got '{observed}' instead"
    )


@pytest.mark.parametrize(
    "moves, expected",
    [
        ["  X1 Y0 X25i Z512ijk Z30 Y5i ", [(0, 1, 0), (1, 0, 0), (0, 25, 1), (2, 512, 1), (2, 30, 0), (1, 5, 1)]],
    ],
)
def test_parse_actions_str(moves: str, expected: list[tuple[int, int, int]]):
    """
    Test that "parse_actions_str" behaves as expected.

    The input is a whitespace-separated sequence of moves (with leading and trailing
    spaces); the output must equal the list of expected integer triples, in order.
    """
    observed = parse_actions_str(moves)
    assert expected == observed, (
        f"'parse_actions_str' output is incorrect: expected '{expected}', got '{observed}' instead"
    )


@pytest.mark.parametrize(
    "num_moves, size, seed",
    [
        [1, 3, 0],
        [1, 20, 42],
        [256, 5, 21],
    ],
)
def test_sample_actions_str(num_moves: int, size: int, seed: int):
    """
    Check that "sample_actions_str" is reproducible and yields parsable output.

    Two calls with identical arguments must return equal results, and parsing the
    result with "parse_actions_str" must produce one entry per whitespace-separated
    token.
    """
    # Same arguments twice: the outputs must be equal.
    moves_1, moves_2 = (sample_actions_str(num_moves, size, seed) for _ in range(2))
    assert moves_1 == moves_2, f"'sample_actions_str' is non-deterministic: {moves_1} != {moves_2}"

    # One parsed entry is expected per token of the sampled string.
    num_parsed = len(parse_actions_str(moves_1))
    num_tokens = len(moves_1.split())
    assert num_parsed == num_tokens, "'sample_actions_str' output cannot be parsed correctly"