JBAujogue commited on
Commit
438ca79
·
1 Parent(s): da3e0de

add unit tests, achieving 100% coverage

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. src/rubik/cube.py +8 -8
  3. tests/unit/test_cube.py +77 -0
README.md CHANGED
@@ -40,7 +40,7 @@ print(cube.history)
40
  ### Perform basic moves
41
 
42
  ```python
43
- # shuffle the cube using 1000 random moves
44
  cube.shuffle(num_moves=1000, seed=0)
45
 
46
  # rotate it in some way
 
40
  ### Perform basic moves
41
 
42
  ```python
43
+ # shuffle the cube using 1000 random moves (random shuffling resets the history)
44
  cube.shuffle(num_moves=1000, seed=0)
45
 
46
  # rotate it in some way
src/rubik/cube.py CHANGED
@@ -37,7 +37,13 @@ class Cube:
37
 
38
  def to(self, device: str | torch.device) -> "Cube":
39
  device = torch.device(device)
40
- dtype = torch.int16 if device == torch.device("cpu") else torch.float32
 
 
 
 
 
 
41
  self.coordinates = self.coordinates.to(device=device, dtype=dtype)
42
  self.state = self.state.to(device=device, dtype=dtype)
43
  self.actions = self.actions.to(device=device, dtype=dtype)
@@ -84,15 +90,9 @@ class Cube:
84
  """
85
  actions = parse_actions_str(moves)
86
  tensors = [self.actions[*action].to(torch.float32) for action in actions]
87
- result = reduce(lambda A, B: A @ B, tensors).to(torch.int16)
88
  return dict(result.indices().transpose(0, 1).tolist())
89
 
90
- def solve(self, policy: str) -> None:
91
- """
92
- Apply the specified solving policy to the cube.
93
- """
94
- raise NotImplementedError
95
-
96
  def __str__(self):
97
  """
98
  Compute a string representation of a cube.
 
37
 
38
  def to(self, device: str | torch.device) -> "Cube":
39
  device = torch.device(device)
40
+ dtype = (
41
+ self.state.dtype
42
+ if self.state.device == device
43
+ else torch.int16
44
+ if device == torch.device("cpu")
45
+ else torch.float32
46
+ )
47
  self.coordinates = self.coordinates.to(device=device, dtype=dtype)
48
  self.state = self.state.to(device=device, dtype=dtype)
49
  self.actions = self.actions.to(device=device, dtype=dtype)
 
90
  """
91
  actions = parse_actions_str(moves)
92
  tensors = [self.actions[*action].to(torch.float32) for action in actions]
93
+ result = reduce(lambda A, B: B @ A, tensors).to(torch.int16).coalesce()
94
  return dict(result.indices().transpose(0, 1).tolist())
95
 
 
 
 
 
 
 
96
  def __str__(self):
97
  """
98
  Compute a string representation of a cube.
tests/unit/test_cube.py CHANGED
@@ -1,5 +1,6 @@
1
  import pytest
2
 
 
3
 
4
  from rubik.cube import Cube
5
 
@@ -27,3 +28,79 @@ class TestCube:
27
  )
28
  assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
29
  assert len(cube.history) == 0, "'history' field should be empty"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
2
 
3
+ import torch
4
 
5
  from rubik.cube import Cube
6
 
 
28
  )
29
  assert cube.state.shape == (6 * (size**2), 7), f"'state' has incorrect shape {cube.state.shape}"
30
  assert len(cube.history) == 0, "'history' field should be empty"
31
+
32
+ @pytest.mark.parametrize("device", ["cpu"])
33
+ def test_to(self, device: str | torch.device):
34
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
35
+ cube_2 = cube.to(device)
36
+ assert torch.equal(cube.state, cube_2.state), "cube has different state after calling 'to' method"
37
+
38
+ def test_reset_history(self):
39
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
40
+ cube.rotate("X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i")
41
+ cube.reset_history()
42
+ assert cube.history == [], "method 'reset_history' does not flush content"
43
+
44
+ @pytest.mark.parametrize("num_moves, seed", [[50, 42]])
45
+ def test_shuffle(self, num_moves: int, seed: int):
46
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
47
+ cube_state = cube.state.clone()
48
+ cube.shuffle(num_moves, seed)
49
+ assert cube.history == [], "method 'shuffle' does not flush content"
50
+ assert not torch.equal(cube_state, cube.state), "method 'shuffle' does not change state"
51
+
52
+ @pytest.mark.parametrize(
53
+ "moves",
54
+ [
55
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i",
56
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i" * 2,
57
+ ],
58
+ )
59
+ def test_rotate(self, moves: str):
60
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
61
+ cube_state = cube.state.clone()
62
+ cube.rotate(moves)
63
+ assert cube.history != [], "method 'rotate' does not update history"
64
+ assert not torch.equal(cube_state, cube.state), "method 'rotate' does not change state"
65
+
66
+ @pytest.mark.parametrize(
67
+ "axis, slice, inverse",
68
+ [
69
+ [0, 2, 0],
70
+ [1, 1, 1],
71
+ [2, 0, 0],
72
+ ],
73
+ )
74
+ def test_rotate_once(self, axis: int, slice: int, inverse: int):
75
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
76
+ cube_state = cube.state.clone()
77
+ cube.rotate_once(axis, slice, inverse)
78
+ assert cube.history == [[axis, slice, inverse]], "method 'rotate_once' does not update history"
79
+ assert not torch.equal(cube_state, cube.state), "method 'rotate_once' does not change state"
80
+
81
+ @pytest.mark.parametrize(
82
+ "moves",
83
+ [
84
+ "X2 X1i Y1i",
85
+ "X2 X1i Y1i Z1i Y0 Z0i X2 X1i Y1i Z1i Y0 Z0i " * 2,
86
+ ],
87
+ )
88
+ def test_compute_changes(self, moves: str):
89
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
90
+ facets = cube.state.argmax(dim=-1).to(torch.int16).tolist()
91
+ changes = cube.compute_changes(moves)
92
+
93
+ # apply changes induced by moves using the permutation dict returned by 'compute_changes'
94
+ expected = [facets[changes.get(i, i)] for i in range(len(facets))]
95
+
96
+ # apply changes induced by moves using the optimized 'rotate' method
97
+ cube.rotate(moves)
98
+ observed = cube.state.argmax(dim=-1).to(torch.int16).tolist()
99
+
100
+ # assert the tow are identical
101
+ assert expected == observed, "method 'compute_changes' does not behave correctly: "
102
+
103
+ def test__str__(self):
104
+ cube = Cube(colors=["U", "L", "C", "R", "B", "D"], size=3)
105
+ repr = str(cube)
106
+ assert len(repr), "__str__ method returns an empty representation"