Spaces:
Sleeping
Sleeping
add unit tests
Browse files- tests/unit/test_action.py +40 -0
tests/unit/test_action.py
CHANGED
@@ -6,6 +6,8 @@ import torch
|
|
6 |
from rubik.action import (
|
7 |
POS_ROTATIONS,
|
8 |
POS_SHIFTS,
|
|
|
|
|
9 |
)
|
10 |
|
11 |
|
@@ -57,3 +59,41 @@ def test_position_shift(axis: int, size: int, input: Iterable[int], expected: It
|
|
57 |
out = rot + (POS_SHIFTS[axis] * (size - 1))
|
58 |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
|
59 |
assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from rubik.action import (
|
7 |
POS_ROTATIONS,
|
8 |
POS_SHIFTS,
|
9 |
+
FACE_ROTATIONS,
|
10 |
+
build_actions_tensor,
|
11 |
)
|
12 |
|
13 |
|
|
|
59 |
out = rot + (POS_SHIFTS[axis] * (size - 1))
|
60 |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
|
61 |
assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
|
62 |
+
|
63 |
+
|
64 |
+
def test_face_rotation_shape():
|
65 |
+
"""
|
66 |
+
Test that FACE_ROTATIONS has expected shape.
|
67 |
+
"""
|
68 |
+
expected = (3, 6, 6)
|
69 |
+
observed = FACE_ROTATIONS.shape
|
70 |
+
assert expected == observed, f"Face rotation tensor expected shape '{expected}', got '{observed}' instead"
|
71 |
+
|
72 |
+
|
73 |
+
@pytest.mark.parametrize(
|
74 |
+
"axis, input, expected",
|
75 |
+
[
|
76 |
+
(0, (1, 0, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about X axis: 0 (Up) -> 2 (Front)
|
77 |
+
(1, (1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0)), # rotation about Y axis: 0 (Up) -> 1 (Left)
|
78 |
+
(2, (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about Z axis: 1 (Left) -> 2 (Front)
|
79 |
+
],
|
80 |
+
)
|
81 |
+
def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
|
82 |
+
"""
|
83 |
+
Test that POS_ROTATIONS behaves as expected.
|
84 |
+
"""
|
85 |
+
out = torch.tensor(input, dtype=FACE_ROTATIONS.dtype) @ FACE_ROTATIONS[axis]
|
86 |
+
exp = torch.tensor(expected, dtype=FACE_ROTATIONS.dtype)
|
87 |
+
assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"
|
88 |
+
|
89 |
+
|
90 |
+
@pytest.mark.parametrize("size", [2, 3, 5, 20])
|
91 |
+
def test_build_actions_tensor_shape(size: int):
|
92 |
+
"""
|
93 |
+
Test that "build_actions_tensor" output has expected shape.
|
94 |
+
"""
|
95 |
+
expected = (3, size, 2, 6 * (size**2), 6 * (size**2))
|
96 |
+
observed = build_actions_tensor(size).shape
|
97 |
+
assert expected == observed, (
|
98 |
+
f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
|
99 |
+
)
|