JBAujogue commited on
Commit
7a1d759
·
1 Parent(s): 905aff6

add unit tests

Browse files
Files changed (1) hide show
  1. tests/unit/test_tensor_utils.py +38 -0
tests/unit/test_tensor_utils.py CHANGED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ from rubik.tensor_utils import build_cube_tensor, build_permutation_matrix
6
+
7
+
8
+ @pytest.mark.parametrize("size", [2, 3, 5, 20])
9
+ def test_build_cube_tensor(size: int):
10
+ """
11
+ Test that build_cube_tensor behaves as expected.
12
+ """
13
+ tensor = build_cube_tensor(colors=["U", "L", "C", "R", "B", "D"], size=size)
14
+ facets = tensor.to_dense().to(dtype=torch.int8) != 0
15
+ x_sums = facets.sum(dim=(0, 2, 3)).tolist()
16
+ y_sums = facets.sum(dim=(0, 1, 3)).tolist()
17
+ z_sums = facets.sum(dim=(0, 1, 2)).tolist()
18
+ expected = [(size**2) + (4 * size)] + [4 * size] * (size - 2) + [(size**2) + (4 * size)]
19
+ assert x_sums == expected, (
20
+ f"'build_cube_tensor' has incorrect sum along X axis: expected '{expected}', got '{x_sums}'"
21
+ )
22
+ assert y_sums == expected, (
23
+ f"'build_cube_tensor' has incorrect sum along Y axis: expected '{expected}', got '{y_sums}'"
24
+ )
25
+ assert z_sums == expected, (
26
+ f"'build_cube_tensor' has incorrect sum along Z axis: expected '{expected}', got '{z_sums}'"
27
+ )
28
+
29
+
30
+ @pytest.mark.parametrize("size, perm", [[2, "01"], [3, "210"], [6, "2345"]])
31
+ def test_build_permutation_matrix(size: int, perm: str):
32
+ """
33
+ Test that build_permutation_matrix behaves as expected.
34
+ """
35
+ matrix = build_permutation_matrix(size, perm)
36
+ mapping = dict(matrix.indices().transpose(0, 1).tolist())
37
+ for i, j in zip(perm, perm[1:] + perm[0]):
38
+ assert mapping[int(i)] == int(j), f"'build_permutation_matrix' outputs has wrong behavior: {perm}, {mapping}"