File size: 2,021 Bytes
ff5d990
 
 
 
 
 
 
8f96832
ff5d990
 
 
 
4145e1a
 
 
ff5d990
 
4145e1a
ff5d990
 
 
 
 
cb37bd4
 
 
 
 
 
 
 
 
ff5d990
 
 
8f96832
 
 
ff5d990
 
 
8f96832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import Iterable, Sequence

import pytest
import torch

from rubik.action import (
    POS_ROTATIONS,
    POS_SHIFTS,
)


def test_position_rotation_shape():
    """
    Check that POS_ROTATIONS stacks three 4x4 matrices, one per rotation axis.
    """
    wanted_shape = (3, 4, 4)
    actual_shape = POS_ROTATIONS.shape
    assert wanted_shape == actual_shape, (
        f"Position rotation tensor expected shape '{wanted_shape}', got '{actual_shape}' instead"
    )


@pytest.mark.parametrize(
    "axis, position, expected",
    [
        (0, (1, 1, 0, 0), (1, 1, 0, 0)),  # X -> X
        (0, (1, 0, 1, 0), (1, 0, 0, -1)),  # Y -> -Z
        (0, (1, 0, 0, 1), (1, 0, 1, 0)),  # Z -> Y
        (1, (1, 1, 0, 0), (1, 0, 0, 1)),  # X -> Z
        (1, (1, 0, 1, 0), (1, 0, 1, 0)),  # Y -> Y
        (1, (1, 0, 0, 1), (1, -1, 0, 0)),  # Z -> -X
        (2, (1, 1, 0, 0), (1, 0, -1, 0)),  # X -> -Y
        (2, (1, 0, 1, 0), (1, 1, 0, 0)),  # Y -> X
        (2, (1, 0, 0, 1), (1, 0, 0, 1)),  # Z -> Z
    ],
)
def test_position_rotation(axis: int, position: Sequence[int], expected: Sequence[int]):
    """
    Test that POS_ROTATIONS maps each basis direction as expected.

    Each case multiplies the 4-element vector ``position`` by the matrix
    ``POS_ROTATIONS[axis]`` and requires an exact match with ``expected``.
    Components 1-3 carry the X/Y/Z direction named in the trailing comment of
    each case; component 0 is 1 in every case (presumably a homogeneous
    coordinate — confirm against rubik.action).

    Args:
        axis: Index of the rotation matrix in POS_ROTATIONS (0, 1 or 2).
        position: 4-element input vector. It is passed to ``torch.tensor``,
            so it must be a sequence (e.g. a tuple), not a one-shot iterator.
        expected: 4-element vector that the product must equal exactly.
    """
    dtype = POS_ROTATIONS.dtype
    out = POS_ROTATIONS[axis] @ torch.tensor(position, dtype=dtype)
    exp = torch.tensor(expected, dtype=dtype)
    assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"


@pytest.mark.parametrize(
    "axis, size, position, expected",
    [
        (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
        (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
        (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
    ],
)
def test_position_shift(axis: int, size: int, position: Sequence[int], expected: Sequence[int]):
    """
    Test that POS_SHIFTS offsets a rotated position onto the expected one.

    ``position`` and ``expected`` are given in units of ``size - 1`` and are
    scaled by that factor before use (presumably ``size - 1`` is the largest
    coordinate on a ``size``-wide cube — TODO confirm against rubik.action).
    The scaled position is multiplied by ``POS_ROTATIONS[axis]`` and then
    offset by ``POS_SHIFTS[axis] * (size - 1)``. The result must equal the
    scaled ``expected`` exactly.

    Args:
        axis: Index into POS_ROTATIONS / POS_SHIFTS (0, 1 or 2).
        size: Cube edge length used to derive the scale factor ``size - 1``.
        position: 4-element input vector. It is passed to ``torch.tensor``,
            so it must be a sequence (e.g. a tuple), not a one-shot iterator.
        expected: 4-element vector, in the same units as ``position``.
    """
    scale = size - 1
    dtype = POS_ROTATIONS.dtype
    rot = POS_ROTATIONS[axis] @ (torch.tensor(position, dtype=dtype) * scale)
    out = rot + (POS_SHIFTS[axis] * scale)
    exp = torch.tensor(expected, dtype=dtype) * scale
    assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"