Spaces:
Sleeping
Sleeping
File size: 4,428 Bytes
da3e0de 438ca79 da3e0de f12b6ac da3e0de f12b6ac 27be29d f12b6ac da3e0de 438ca79 ad6d9bc f12b6ac 438ca79 ad6d9bc f12b6ac 438ca79 ad6d9bc f12b6ac 438ca79 f12b6ac 438ca79 ad6d9bc f12b6ac 438ca79 ad6d9bc f12b6ac 438ca79 f12b6ac 438ca79 27be29d ad6d9bc 27be29d ad6d9bc f12b6ac 438ca79 27be29d 438ca79 27be29d 438ca79 27be29d 438ca79 f12b6ac ad6d9bc f12b6ac 438ca79 f12b6ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import pytest
import torch
from rubik.cube import Cube
class TestCube:
    """
    A testing class for the Cube class.
    """

    @pytest.mark.parametrize("size", [3, 5, 10, 25])
    def test__init__(self, size: int):
        """
        Test that the __init__ method produces the expected attributes.

        For a cube of side ``size``, ``state`` must be a flat tensor of
        ``6 * size**2`` entries, ``actions`` must have shape
        ``(3, size, 2, len(state))`` and ``history`` must start empty.
        """
        cube = Cube(size)
        assert cube.state.shape == (6 * (size**2),), f"'state' has incorrect shape {cube.state.shape}"
        assert cube.actions.shape == (3, size, 2, cube.state.shape[0]), (
            f"'actions' has incorrect shape {cube.actions.shape}"
        )
        assert len(cube.history) == 0, "'history' field should be empty"

    @pytest.mark.parametrize("device", ["cpu"])
    def test_to(self, device: str | torch.device):
        """
        Test that the .to method behaves as expected.

        The returned cube must keep the same state values as the source cube
        and hold its state on the requested device.
        """
        cube = Cube(3)
        cube_2 = cube.to(device)
        # Compare on CPU so the value check stays valid if non-CPU devices
        # are ever added to the parametrization.
        assert torch.equal(cube.state.cpu(), cube_2.state.cpu()), (
            "cube has different state after calling 'to' method"
        )
        assert cube_2.state.device.type == torch.device(device).type, (
            f"'state' is on {cube_2.state.device} instead of {device}"
        )

    def test_reset_history(self):
        """
        Test that the .reset_history method empties the move history.
        """
        cube = Cube(3)
        cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
        cube.reset_history()
        assert cube.history == [], "method 'reset_history' does not flush content"

    @pytest.mark.parametrize("num_moves, seed", [[50, 42]])
    def test_shuffle(self, num_moves: int, seed: int):
        """
        Test that the .scramble method behaves as expected.

        Scrambling must change the state while leaving the move history empty.
        """
        cube = Cube(3)
        initial_state = cube.state.clone()
        cube.scramble(num_moves, seed)
        assert cube.history == [], "method 'scramble' does not flush content"
        assert not torch.equal(initial_state, cube.state), "method 'scramble' does not change state"

    @pytest.mark.parametrize(
        "moves",
        [
            "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i",
            # Join with a space: a plain `str * 2` would glue the last and
            # first moves together into a single malformed token ("Z0iX2").
            " ".join(["X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i"] * 2),
        ],
    )
    def test_rotate(self, moves: str):
        """
        Test that the .rotate method records history and changes the state.
        """
        cube = Cube(3)
        initial_state = cube.state.clone()
        cube.rotate(moves)
        assert cube.history != [], "method 'rotate' does not update history"
        assert not torch.equal(initial_state, cube.state), "method 'rotate' does not change state"

    @pytest.mark.parametrize(
        "axis, slice_, inverse",
        [
            [0, 2, 0],
            [1, 1, 1],
            [2, 0, 0],
        ],
    )
    def test_rotate_once(self, axis: int, slice_: int, inverse: int):
        """
        Test that the .rotate_once method behaves as expected.

        A single move must be appended to history as an
        ``(axis, slice, inverse)`` tuple and must change the state.
        """
        cube = Cube(3)
        initial_state = cube.state.clone()
        cube.rotate_once(axis, slice_, inverse)
        assert cube.history == [(axis, slice_, inverse)], "method 'rotate_once' does not update history"
        assert not torch.equal(initial_state, cube.state), "method 'rotate_once' does not change state"

    @pytest.mark.parametrize(
        "moves",
        [
            "X2 X1i Y1i",
            "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i " * 2,
        ],
    )
    def test_compose_moves(self, moves: str):
        """
        Test that the .compose_moves method behaves as expected.

        Gathering the initial state with the index tensor returned by
        ``compose_moves`` must give the same state as applying ``rotate``.
        """
        cube = Cube(3)
        # Apply the moves via the index tensor returned by 'compose_moves';
        # torch.gather allocates a new output tensor, so no clone is needed.
        changes = cube.compose_moves(moves)
        expected = torch.gather(cube.state, 0, changes)
        # Apply the same moves via the optimized 'rotate' method.
        cube.rotate(moves)
        observed = cube.state
        # Assert the two are identical.
        assert torch.equal(expected, observed), "method 'compose_moves' does not behave correctly"

    def test__str__len(self):
        """
        Test that the __str__ method returns a non-empty representation.
        """
        cube = Cube(3)
        text = str(cube)
        assert text, "__str__ method returns an empty representation"

    @pytest.mark.parametrize("size", [3, 5, 8, 10])
    def test__str__content(self, size: int):
        """
        Test that every line of the string representation has the same length.
        """
        cube = Cube(size=size)
        text = str(cube)
        # split("\n") (not splitlines) keeps a trailing empty line, so a
        # trailing newline is still reported as a length mismatch.
        lens = {len(line) for line in text.split("\n")}
        assert len(lens) == 1, f"'stringify' lines have variable length: {lens}"
|