JBAujogue commited on
Commit
5e5c08a
·
1 Parent(s): d756230

add comments

Browse files
Files changed (1) hide show
  1. src/rubik/moves.py +3 -0
src/rubik/moves.py CHANGED
@@ -135,6 +135,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
135
  inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
136
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
 
 
138
  local_to_total = dict(enumerate(indices.tolist()))
139
  total_to_local = {ind: i for i, ind in local_to_total.items()}
140
 
@@ -142,6 +143,8 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
142
  total_perm = {
143
  i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
144
  }
 
 
145
  perm_indices = torch.tensor(
146
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
147
  dtype=INT8,
 
135
  inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
136
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
 
138
+ # compute position-based permutation of colors equivalent to rotation converting inputs into outputs
139
  local_to_total = dict(enumerate(indices.tolist()))
140
  total_to_local = {ind: i for i, ind in local_to_total.items()}
141
 
 
143
  total_perm = {
144
  i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
145
  }
146
+
147
+ # convert permutation dict into sparse tensor
148
  perm_indices = torch.tensor(
149
  [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
150
  dtype=INT8,