JBAujogue commited on
Commit
68f8c07
·
1 Parent(s): b24b865

add unit tests

Browse files
Files changed (1) hide show
  1. tests/unit/test_action.py +40 -0
tests/unit/test_action.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  from rubik.action import (
7
  POS_ROTATIONS,
8
  POS_SHIFTS,
 
 
9
  )
10
 
11
 
@@ -57,3 +59,41 @@ def test_position_shift(axis: int, size: int, input: Iterable[int], expected: It
57
  out = rot + (POS_SHIFTS[axis] * (size - 1))
58
  exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
59
  assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from rubik.action import (
7
  POS_ROTATIONS,
8
  POS_SHIFTS,
9
+ FACE_ROTATIONS,
10
+ build_actions_tensor,
11
  )
12
 
13
 
 
59
  out = rot + (POS_SHIFTS[axis] * (size - 1))
60
  exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
61
  assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"
62
+
63
+
64
+ def test_face_rotation_shape():
65
+ """
66
+ Test that FACE_ROTATIONS has expected shape.
67
+ """
68
+ expected = (3, 6, 6)
69
+ observed = FACE_ROTATIONS.shape
70
+ assert expected == observed, f"Face rotation tensor expected shape '{expected}', got '{observed}' instead"
71
+
72
+
73
+ @pytest.mark.parametrize(
74
+ "axis, input, expected",
75
+ [
76
+ (0, (1, 0, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about X axis: 0 (Up) -> 2 (Front)
77
+ (1, (1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0)), # rotation about Y axis: 0 (Up) -> 1 (Left)
78
+ (2, (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0)), # rotation about Z axis: 1 (Left) -> 2 (Front)
79
+ ],
80
+ )
81
+ def test_face_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
82
+ """
83
+ Test that POS_ROTATIONS behaves as expected.
84
+ """
85
+ out = torch.tensor(input, dtype=FACE_ROTATIONS.dtype) @ FACE_ROTATIONS[axis]
86
+ exp = torch.tensor(expected, dtype=FACE_ROTATIONS.dtype)
87
+ assert torch.equal(out, exp), f"Face rotation tensor is incorrect along axis {axis}: {out} != {exp}"
88
+
89
+
90
+ @pytest.mark.parametrize("size", [2, 3, 5, 20])
91
+ def test_build_actions_tensor_shape(size: int):
92
+ """
93
+ Test that "build_actions_tensor" output has expected shape.
94
+ """
95
+ expected = (3, size, 2, 6 * (size**2), 6 * (size**2))
96
+ observed = build_actions_tensor(size).shape
97
+ assert expected == observed, (
98
+ f"'build_actions_tensor' output has incorrect shape: expected shape '{expected}', got '{observed}' instead"
99
+ )