JBAujogue commited on
Commit
8f96832
·
1 Parent(s): cb37bd4

add unit test

Browse files
Files changed (1) hide show
  1. tests/unit/test_action.py +22 -0
tests/unit/test_action.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
 
6
  from rubik.action import (
7
  POS_ROTATIONS,
 
8
  )
9
 
10
 
@@ -29,6 +30,27 @@ def test_position_rotation_shape():
29
  ],
30
  )
31
  def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
 
 
 
32
  out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype)
33
  exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
34
  assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  from rubik.action import (
7
  POS_ROTATIONS,
8
+ POS_SHIFTS,
9
  )
10
 
11
 
 
30
  ],
31
  )
32
  def test_position_rotation(axis: int, input: Iterable[int], expected: Iterable[int]):
33
+ """
34
+ Test that POS_ROTATIONS behaves as expected.
35
+ """
36
  out = POS_ROTATIONS[axis] @ torch.tensor(input, dtype=POS_ROTATIONS.dtype)
37
  exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype)
38
  assert torch.equal(out, exp), f"Position rotation tensor is incorrect along axis {axis}: {out} != {exp}"
39
+
40
+
41
+ @pytest.mark.parametrize(
42
+ "axis, size, input, expected",
43
+ [
44
+ (0, 3, (1, 1, 1, 1), (1, 1, 1, 0)),
45
+ (1, 3, (1, 1, 1, 1), (1, 0, 1, 1)),
46
+ (2, 3, (1, 1, 1, 1), (1, 1, 0, 1)),
47
+ ],
48
+ )
49
+ def test_position_shift(axis: int, size: int, input: Iterable[int], expected: Iterable[int]):
50
+ """
51
+ Test that POS_SHIFTS behaves as expected.
52
+ """
53
+ rot = POS_ROTATIONS[axis] @ (torch.tensor(input, dtype=POS_ROTATIONS.dtype) * (size - 1))
54
+ out = rot + (POS_SHIFTS[axis] * (size - 1))
55
+ exp = torch.tensor(expected, dtype=POS_ROTATIONS.dtype) * (size - 1)
56
+ assert torch.equal(out, exp), f"Position shift tensor is incorrect along axis {axis}: {out} != {exp}"