JBAujogue commited on
Commit
3e056bc
·
1 Parent(s): 7011a7d

correct move parsing function for slices greater than 9

Browse files
Files changed (1) hide show
  1. src/rubik/action.py +8 -3
src/rubik/action.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
@@ -123,17 +125,20 @@ def build_action_tensor(size: int, axis: int, slice: int, inverse: int) -> torch
123
  return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
124
 
125
 
126
- def parse_action_str(move: str) -> tuple[int, ...]:
127
  """
128
  Convert the name of an action into a triple (axis, slice, inverse).
129
  Examples:
130
  'X1' -> (0, 1, 0)
131
  'X2i' -> (0, 2, 1)
132
  """
133
- return ("XYZ".index(move[0]), int(move[1]), int(len(move) >= 3))
 
 
 
134
 
135
 
136
- def parse_actions_str(moves: str) -> list[tuple[int, ...]]:
137
  """
138
  Convert a sequence of actions in a string into a list of triples (axis, slice, inverse).
139
  Examples:
 
1
+ import re
2
+
3
  import numpy as np
4
  import torch
5
  import torch.nn.functional as F
 
125
  return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=torch.int16)
126
 
127
 
128
+ def parse_action_str(move: str) -> tuple[int, int, int]:
129
  """
130
  Convert the name of an action into a triple (axis, slice, inverse).
131
  Examples:
132
  'X1' -> (0, 1, 0)
133
  'X2i' -> (0, 2, 1)
134
  """
135
+ axis = "XYZ".index(move[0])
136
+ slice = int(re.findall(r"^\d+", move[1:])[0])
137
+ inverse = int(len(move) > (1 + len(str(slice))))
138
+ return (axis, slice, inverse)
139
 
140
 
141
+ def parse_actions_str(moves: str) -> list[tuple[int, int, int]]:
142
  """
143
  Convert a sequence of actions in a string into a list of triples (axis, slice, inverse).
144
  Examples: