File size: 1,495 Bytes
7a1d759
 
 
 
f12b6ac
7a1d759
 
 
 
 
 
 
f12b6ac
7a1d759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import pytest

import torch

from rubik.state import build_cube_tensor, build_permutation_matrix


@pytest.mark.parametrize("size", [2, 3, 5, 20])
def test_build_cube_tensor(size: int):
    """
    Test that build_cube_tensor behaves as expected.
    """
    tensor = build_cube_tensor(size)
    facets = tensor.to_dense().to(dtype=torch.int8) != 0
    x_sums = facets.sum(dim=(0, 2, 3)).tolist()
    y_sums = facets.sum(dim=(0, 1, 3)).tolist()
    z_sums = facets.sum(dim=(0, 1, 2)).tolist()
    expected = [(size**2) + (4 * size)] + [4 * size] * (size - 2) + [(size**2) + (4 * size)]
    assert x_sums == expected, (
        f"'build_cube_tensor' has incorrect sum along X axis: expected '{expected}', got '{x_sums}'"
    )
    assert y_sums == expected, (
        f"'build_cube_tensor' has incorrect sum along Y axis: expected '{expected}', got '{y_sums}'"
    )
    assert z_sums == expected, (
        f"'build_cube_tensor' has incorrect sum along Z axis: expected '{expected}', got '{z_sums}'"
    )


@pytest.mark.parametrize("size, perm", [[2, "01"], [3, "210"], [6, "2345"]])
def test_build_permutation_matrix(size: int, perm: str):
    """
    Test that build_permutation_matrix behaves as expected.
    """
    matrix = build_permutation_matrix(size, perm)
    mapping = dict(matrix.indices().transpose(0, 1).tolist())
    for i, j in zip(perm, perm[1:] + perm[0]):
        assert mapping[int(i)] == int(j), f"'build_permutation_matrix' outputs has wrong behavior: {perm}, {mapping}"