Spaces:
Sleeping
Sleeping
| import pytest | |
| from typing import Iterable | |
| import torch | |
| from rubik.action import ( | |
| POS_ROTATIONS, | |
| POS_SHIFTS, | |
| FACE_ROTATIONS, | |
| build_actions_tensor, | |
| build_action_tensor, | |
| parse_action_str, | |
| parse_actions_str, | |
| sample_actions_str, | |
| ) | |
def test_position_rotation_shape():
    """
    Check the dimensions of the stacked position rotation matrices.

    POS_ROTATIONS is indexed by axis elsewhere in this module, so the leading
    dimension (3) holds one 4x4 matrix per axis.
    """
    wanted_shape = (3, 4, 4)
    actual_shape = POS_ROTATIONS.shape
    assert wanted_shape == actual_shape, (
        f"Position rotation tensor expected shape '{wanted_shape}', got '{actual_shape}' instead"
    )
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Left-multiply a position vector by the rotation matrix for ``axis`` and
    require an exact match (``torch.equal``) with ``expected``.

    Both ``input`` and ``expected`` are converted to tensors using
    POS_ROTATIONS' dtype, so the comparison is dtype-consistent.
    """
    # NOTE(review): the arguments are presumably supplied by a
    # pytest.mark.parametrize decorator not visible in this copy -- confirm.
    rotation = POS_ROTATIONS[axis]
    element_type = POS_ROTATIONS.dtype
    vector = torch.tensor(input, dtype=element_type)
    rotated = rotation @ vector
    target = torch.tensor(expected, dtype=element_type)
    assert torch.equal(rotated, target), f"Position rotation tensor is incorrect along axis {axis}: {rotated} != {target}"
def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
    """
    Check that rotating then shifting a scaled position lands on the expected point.

    ``input`` is scaled by ``size - 1`` and rotated with the matrix for ``axis``.
    ``POS_SHIFTS[axis]``, also scaled by ``size - 1``, is then added. The result
    must exactly equal ``expected`` scaled by the same factor.
    """
    # NOTE(review): the arguments are presumably supplied by a
    # pytest.mark.parametrize decorator not visible in this copy -- confirm.
    rotation = POS_ROTATIONS[axis]
    element_type = POS_ROTATIONS.dtype
    factor = size - 1
    scaled_point = torch.tensor(input, dtype=element_type) * factor
    moved = rotation @ scaled_point + POS_SHIFTS[axis] * factor
    target = torch.tensor(expected, dtype=element_type) * factor
    assert torch.equal(moved, target), f"Position shift tensor is incorrect along axis {axis}: {moved} != {target}"
def test_face_rotation_shape():
    """
    Check the dimensions of the stacked face rotation matrices.

    FACE_ROTATIONS is indexed by axis elsewhere in this module, so the leading
    dimension (3) holds one 6x6 matrix per axis.
    """
    wanted_shape = (3, 6, 6)
    actual_shape = FACE_ROTATIONS.shape
    assert wanted_shape == actual_shape, (
        f"Face rotation tensor expected shape '{wanted_shape}', got '{actual_shape}' instead"
    )
def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
    """
    Test that FACE_ROTATIONS behaves as expected.

    ``input`` is converted to a tensor with FACE_ROTATIONS' dtype and
    right-multiplied by the matrix for ``axis`` (``vector @ FACE_ROTATIONS[axis]``).
    This is the opposite side from the position tests, which compute
    ``matrix @ vector``. The product must exactly equal ``expected`` (``torch.equal``).
    """
    # NOTE(review): the arguments are presumably supplied by a
    # pytest.mark.parametrize decorator not visible in this copy -- confirm.
    out = torch.tensor(input, dtype=FACE_ROTATIONS.dtype) @ FACE_ROTATIONS[axis]
    exp = torch.tensor(expected, dtype=FACE_ROTATIONS.dtype)
    assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"
def test_build_actions_tensor_shape(size: int):
    """
    Check the shape of the tensor returned by "build_actions_tensor".

    The result is a ``(3, size, 2)`` grid of square matrices, each of side
    ``6 * size**2``.
    """
    side = 6 * (size**2)
    expected = (3, size, 2, side, side)
    observed = build_actions_tensor(size).shape
    assert expected == observed, f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
def test_build_action_tensor_shape(size: int, axis: int, slice: int, inverse: int):
    """
    Test that "build_action_tensor" output has expected shape.

    A single action, selected by one ``(axis, slice, inverse)`` choice, is expected
    to be one square matrix of side ``6 * size**2``. That is the trailing
    ``(N, N)`` part of the "build_actions_tensor" shape ``(3, size, 2, N, N)``,
    without the leading action grid.
    """
    # NOTE(review): this previously reused the full 5-D "build_actions_tensor"
    # shape (a copy-paste slip). Confirm against rubik.action that a single
    # action tensor is 2-D.
    side = 6 * (size**2)
    expected = (side, side)
    observed = build_action_tensor(size, axis, slice, inverse).shape
    assert expected == observed, (
        f"'build_action_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
    )
def test_parse_action_str(move: str, expected: tuple[int, int, int]):
    """
    Check that "parse_action_str" turns a single move string into ``expected``.
    """
    result = parse_action_str(move)
    assert expected == result, f"'parse_action_str' output is incorrect: expected '{expected}', got '{result}' instead"
def test_parse_actions_str(moves: str, expected: list[tuple[int, int, int]]):
    """
    Test that "parse_actions_str" behaves as expected.

    ``moves`` holds several moves. "parse_actions_str" returns one parsed
    action per move, so ``expected`` is a sequence of integer triples rather
    than a single triple.
    """
    # NOTE(review): ``list`` is assumed as the concrete container returned by
    # parse_actions_str -- confirm, since ``==`` will not match a tuple of tuples.
    observed = parse_actions_str(moves)
    assert expected == observed, (
        f"'parse_actions_str' output is incorrect: expected '{expected}', got '{observed}' instead"
    )
def test_sample_actions_str(num_moves: int, size: int, seed: int):
    """
    Two calls with identical arguments must yield identical move strings.

    The sampled string must also parse into as many actions as it has
    whitespace-separated tokens.
    """
    first, second = (sample_actions_str(num_moves, size, seed) for _ in range(2))
    assert first == second, f"'sample_actions_str' is non-deterministic: {first} != {second}"
    actions = parse_actions_str(first)
    token_count = len(first.split())
    assert len(actions) == token_count, "'sample_actions_str' output cannot be parsed correctly"