Spaces:
Sleeping
Sleeping
import pytest | |
from typing import Iterable | |
import torch | |
from rubik.action import ( | |
POS_ROTATIONS, | |
POS_SHIFTS, | |
) | |
def test_position_rotation_shape(): | |
""" | |
Test that POS_ROTATIONS has expected shape. | |
""" | |
expected = (3, 4, 4) | |
observed = POS_ROTATIONS.shape | |
assert expected == observed, f"Position rotation tensor expected shape '{expected}', got '{observed}' instead" | |
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]): | |
""" | |
Test that POS_ROTATIONS behaves as expected. | |
""" | |
out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype) | |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) | |
assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}" | |
def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]): | |
""" | |
Test that POS_SHIFTS behaves as expected. | |
""" | |
rot = POS_ROTATIONS[axis] @ (torch.tensor(input, dtype=POS_ROTATIONS.dtype) * (size - 1)) | |
out = rot + (POS_SHIFTS[axis] * (size - 1)) | |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1) | |
assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}" | |