Spaces:
Sleeping
Sleeping
File size: 2,021 Bytes
ff5d990 8f96832 ff5d990 4145e1a ff5d990 4145e1a ff5d990 cb37bd4 ff5d990 8f96832 ff5d990 8f96832 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import pytest
from typing import Iterable
import torch
from rubik.action import (
POS_ROTATIONS,
POS_SHIFTS,
)
def test_position_rotation_shape():
    """
    Check the dimensions of the POS_ROTATIONS tensor.

    The tensor must be a stack of three 4x4 matrices, one per rotation axis.
    """
    observed = POS_ROTATIONS.shape
    expected = (3, 4, 4)
    message = f"Position rotation tensor expected shape '{expected}', got '{observed}' instead"
    assert expected == observed, message
@pytest.mark.parametrize(
    "axis, point, expected",
    [
        # Origin: the leading component (presumably a homogeneous coordinate) is 1
        # and x = y = z = 0. Translation is supplied separately by POS_SHIFTS, so a
        # rotation matrix must leave the origin where it is.
        (0, (1, 0, 0, 0), (1, 0, 0, 0)),
        (1, (1, 0, 0, 0), (1, 0, 0, 0)),
        (2, (1, 0, 0, 0), (1, 0, 0, 0)),
        # Basis vectors. Each mapping below is a -90 degree (right-hand rule)
        # quarter turn about the given axis.
        (0, (1, 1, 0, 0), (1, 1, 0, 0)),  # X -> X
        (0, (1, 0, 1, 0), (1, 0, 0, -1)),  # Y -> -Z
        (0, (1, 0, 0, 1), (1, 0, 1, 0)),  # Z -> Y
        (1, (1, 1, 0, 0), (1, 0, 0, 1)),  # X -> Z
        (1, (1, 0, 1, 0), (1, 0, 1, 0)),  # Y -> Y
        (1, (1, 0, 0, 1), (1, -1, 0, 0)),  # Z -> -X
        (2, (1, 1, 0, 0), (1, 0, -1, 0)),  # X -> -Y
        (2, (1, 0, 1, 0), (1, 1, 0, 0)),  # Y -> X
        (2, (1, 0, 0, 1), (1, 0, 0, 1)),  # Z -> Z
    ],
)
def test_position_rotation(axis: int, point: Iterable[int], expected: Iterable[int]):
    """
    Test that POS_ROTATIONS maps positions as expected.

    Args:
        axis: Index into POS_ROTATIONS (0, 1, 2 for X, Y, Z).
        point: 4-component input vector; the leading component is always 1 here.
        expected: 4-component vector the rotation must produce.
    """
    out = POS_ROTATIONS[axis] @ torch.tensor(point, dtype=POS_ROTATIONS.dtype)
    exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
    assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
@pytest.mark.parametrize(
    "axis, size, point, expected",
    [
        # Far corner (size-1, size-1, size-1).
        (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
        (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
        (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
        # Origin corner: a pure rotation leaves it in place, so the result is the
        # shift vector alone. This pins down which coordinate POS_SHIFTS[axis] moves.
        (0, 3, (1, 0, 0, 0), (1, 0, 0, 1)),
        (1, 3, (1, 0, 0, 0), (1, 1, 0, 0)),
        (2, 3, (1, 0, 0, 0), (1, 0, 1, 0)),
    ],
)
def test_position_shift(axis: int, size: int, point: Iterable[int], expected: Iterable[int]):
    """
    Test that POS_SHIFTS, applied after POS_ROTATIONS, behaves as expected.

    The input and expected vectors are given in units of (size - 1), i.e. a
    component of 1 means the far edge of a cube with `size` positions per side.

    Args:
        axis: Index into POS_ROTATIONS / POS_SHIFTS (0, 1, 2 for X, Y, Z).
        size: Cube side length; coordinates are scaled by ``size - 1``.
        point: 4-component input vector in units of ``size - 1``.
        expected: 4-component result in units of ``size - 1``.
    """
    scale = size - 1
    # NOTE(review): the whole vector, including the leading component, is scaled,
    # so both sides are linear in `scale`. Adding more sizes would add no coverage.
    position = torch.tensor(point, dtype=POS_ROTATIONS.dtype) * scale
    out = (POS_ROTATIONS[axis] @ position) + (POS_SHIFTS[axis] * scale)
    exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * scale
    assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
|