JBAujogue committed on
Commit
7011a7d
·
1 Parent(s): 68f8c07

change base dtype from int8 to int16 to avoid overflow when cube size is >= 7

Browse files
src/rubik/action.py CHANGED
@@ -15,7 +15,7 @@ POS_ROTATIONS = torch.stack(
15
  [0, 0, 0, 1],
16
  [0, 0, -1, 0],
17
  ],
18
- dtype=torch.int8,
19
  ),
20
  # rot about Y: X -> Z
21
  torch.tensor(
@@ -25,7 +25,7 @@ POS_ROTATIONS = torch.stack(
25
  [0, 0, 1, 0],
26
  [0, 1, 0, 0],
27
  ],
28
- dtype=torch.int8,
29
  ),
30
  # rot about Z: Y -> X
31
  torch.tensor(
@@ -35,7 +35,7 @@ POS_ROTATIONS = torch.stack(
35
  [0, -1, 0, 0],
36
  [0, 0, 0, 1],
37
  ],
38
- dtype=torch.int8,
39
  ),
40
  ]
41
  )
@@ -46,7 +46,7 @@ POS_SHIFTS = torch.tensor(
46
  [0, 1, 0, 0],
47
  [0, 0, 1, 0],
48
  ],
49
- dtype=torch.int8,
50
  )
51
 
52
 
@@ -74,7 +74,7 @@ def build_actions_tensor(size: int) -> torch.Tensor:
74
  for inverse in range(2)
75
  ],
76
  dim=0,
77
- ).sum(dim=0, dtype=torch.int8)
78
 
79
 
80
  def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
@@ -87,7 +87,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
87
  length = 6 * (size**2)
88
 
89
  # extract faces impacted by the move
90
- indices = tensor.indices().to(dtype=torch.int8) # size = (4, length)
91
  changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
92
  extract = indices[:, changes] # size = (4, n)
93
 
@@ -97,7 +97,7 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
97
  rotated = rotated + offsets # size = (4, n)
98
 
99
  # apply face rotation
100
- rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int8) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
101
 
102
  # from this point on, convert rotation into a position-based permutation of colors
103
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
@@ -116,11 +116,11 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
116
  # convert permutation dict into sparse tensor
117
  perm_indices = torch.tensor(
118
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
119
- dtype=torch.int8,
120
  )
121
- perm_values = torch.tensor([1] * length, dtype=torch.int8)
122
  perm_size = (3, size, 2, length, length)
123
- return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int8)
124
 
125
 
126
  def parse_action_str(move: str) -> tuple[int, ...]:
 
15
  [0, 0, 0, 1],
16
  [0, 0, -1, 0],
17
  ],
18
+ dtype=torch.int16,
19
  ),
20
  # rot about Y: X -> Z
21
  torch.tensor(
 
25
  [0, 0, 1, 0],
26
  [0, 1, 0, 0],
27
  ],
28
+ dtype=torch.int16,
29
  ),
30
  # rot about Z: Y -> X
31
  torch.tensor(
 
35
  [0, -1, 0, 0],
36
  [0, 0, 0, 1],
37
  ],
38
+ dtype=torch.int16,
39
  ),
40
  ]
41
  )
 
46
  [0, 1, 0, 0],
47
  [0, 0, 1, 0],
48
  ],
49
+ dtype=torch.int16,
50
  )
51
 
52
 
 
74
  for inverse in range(2)
75
  ],
76
  dim=0,
77
+ ).sum(dim=0, dtype=torch.int16)
78
 
79
 
80
  def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch.Tensor:
 
87
  length = 6 * (size**2)
88
 
89
  # extract faces impacted by the move
90
+ indices = tensor.indices().to(dtype=torch.int16) # size = (4, length)
91
  changes = (indices[axis + 1] == slice).nonzero().reshape(-1) # size = (n,), n < length
92
  extract = indices[:, changes] # size = (4, n)
93
 
 
97
  rotated = rotated + offsets # size = (4, n)
98
 
99
  # apply face rotation
100
+ rotated[0] = (F.one_hot(rotated[0].long(), num_classes=6).to(torch.int16) @ FACE_ROTATIONS[axis]).argmax(dim=-1)
101
 
102
  # from this point on, convert rotation into a position-based permutation of colors
103
  (inputs, outputs) = (rotated, extract) if bool(inverse) else (extract, rotated)
 
116
  # convert permutation dict into sparse tensor
117
  perm_indices = torch.tensor(
118
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
119
+ dtype=torch.int16,
120
  )
121
+ perm_values = torch.tensor([1] * length, dtype=torch.int16)
122
  perm_size = (3, size, 2, length, length)
123
+ return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
124
 
125
 
126
  def parse_action_str(move: str) -> tuple[int, ...]:
src/rubik/cube.py CHANGED
@@ -28,8 +28,8 @@ class Cube:
28
  cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
29
  """
30
  tensor = build_cube_tensor(colors, size)
31
- self.coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
32
- self.state = F.one_hot(tensor.values().long()).to(torch.int8)
33
  self.actions = build_actions_tensor(size)
34
  self.history: list[list[int]] = []
35
  self.colors = colors
@@ -37,7 +37,7 @@ class Cube:
37
 
38
  def to(self, device: str | torch.device) -> "Cube":
39
  device = torch.device(device)
40
- dtype = torch.int8 if device == torch.device("cpu") else torch.float32
41
  self.coordinates = self.coordinates.to(device=device, dtype=dtype)
42
  self.state = self.state.to(device=device, dtype=dtype)
43
  self.actions = self.actions.to(device=device, dtype=dtype)
@@ -84,7 +84,7 @@ class Cube:
84
  """
85
  actions = parse_actions_str(moves)
86
  tensors = [self.actions[*action].to(torch.float32) for action in actions]
87
- result = reduce(lambda A, B: A @ B, tensors).to(torch.int8)
88
  return dict(result.indices().transpose(0, 1).tolist())
89
 
90
  def solve(self, policy: str) -> None:
@@ -97,5 +97,5 @@ class Cube:
97
  """
98
  Compute a string representation of a cube.
99
  """
100
- state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int8)
101
  return stringify(state, self.colors, self.size)
 
28
  cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
29
  """
30
  tensor = build_cube_tensor(colors, size)
31
+ self.coordinates = tensor.indices().transpose(0, 1).to(torch.int16)
32
+ self.state = F.one_hot(tensor.values().long()).to(torch.int16)
33
  self.actions = build_actions_tensor(size)
34
  self.history: list[list[int]] = []
35
  self.colors = colors
 
37
 
38
  def to(self, device: str | torch.device) -> "Cube":
39
  device = torch.device(device)
40
+ dtype = torch.int16 if device == torch.device("cpu") else torch.float32
41
  self.coordinates = self.coordinates.to(device=device, dtype=dtype)
42
  self.state = self.state.to(device=device, dtype=dtype)
43
  self.actions = self.actions.to(device=device, dtype=dtype)
 
84
  """
85
  actions = parse_actions_str(moves)
86
  tensors = [self.actions[*action].to(torch.float32) for action in actions]
87
+ result = reduce(lambda A, B: A @ B, tensors).to(torch.int16)
88
  return dict(result.indices().transpose(0, 1).tolist())
89
 
90
  def solve(self, policy: str) -> None:
 
97
  """
98
  Compute a string representation of a cube.
99
  """
100
+ state = self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int16)
101
  return stringify(state, self.colors, self.size)
src/rubik/tensor_utils.py CHANGED
@@ -10,7 +10,7 @@ def build_cube_tensor(colors: list[str], size: int) -> torch.Tensor:
10
 
11
  # build dense tensor filled with colors
12
  n = size - 1
13
- tensor = torch.zeros([6, size, size, size], dtype=torch.int8)
14
  tensor[0, :, :, n] = 1 # up
15
  tensor[1, 0, :, :] = 2 # left
16
  tensor[2, :, n, :] = 3 # front
@@ -26,6 +26,6 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
26
  """
27
  perm_list = [int(p) for p in (perm + perm[0])]
28
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
- indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int8)
30
- values = torch.tensor([1] * size, dtype=torch.int8)
31
- return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int8)
 
10
 
11
  # build dense tensor filled with colors
12
  n = size - 1
13
+ tensor = torch.zeros([6, size, size, size], dtype=torch.int16)
14
  tensor[0, :, :, n] = 1 # up
15
  tensor[1, 0, :, :] = 2 # left
16
  tensor[2, :, n, :] = 3 # front
 
26
  """
27
  perm_list = [int(p) for p in (perm + perm[0])]
28
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
+ indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
30
+ values = torch.tensor([1] * size, dtype=torch.int16)
31
+ return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16)