Spaces:
Sleeping
Sleeping
add unit test
Browse files- tests/unit/test_action.py +22 -0
tests/unit/test_action.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
|
6 |
from rubik.action import (
|
7 |
POS_ROTATIONS,
|
|
|
8 |
)
|
9 |
|
10 |
|
@@ -29,6 +30,27 @@ def test_position_rotation_shape():
|
|
29 |
],
|
30 |
)
|
31 |
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
|
|
|
|
|
|
|
32 |
out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype)
|
33 |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
|
34 |
assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from rubik.action import (
|
7 |
POS_ROTATIONS,
|
8 |
+
POS_SHIFTS,
|
9 |
)
|
10 |
|
11 |
|
|
|
30 |
],
|
31 |
)
|
32 |
def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
|
33 |
+
"""
|
34 |
+
Test that POS_ROTATIONS behaves as expected.
|
35 |
+
"""
|
36 |
out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype)
|
37 |
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
|
38 |
assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
|
39 |
+
|
40 |
+
|
41 |
+
@pytest.mark.parametrize(
|
42 |
+
"axis, size, input, expected",
|
43 |
+
[
|
44 |
+
(0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
|
45 |
+
(1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
|
46 |
+
(2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
|
47 |
+
],
|
48 |
+
)
|
49 |
+
def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
|
50 |
+
"""
|
51 |
+
Test that POS_SHIFTS behaves as expected.
|
52 |
+
"""
|
53 |
+
rot = POS_ROTATIONS[axis] @ (torch.tensor(input, dtype=POS_ROTATIONS.dtype) * (size - 1))
|
54 |
+
out = rot + (POS_SHIFTS[axis] * (size - 1))
|
55 |
+
exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
|
56 |
+
assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
|