JBAujogue commited on
Commit
d756230
·
1 Parent(s): 7583934

rename internal variables during construction of action tensor

Browse files
Files changed (1) hide show
  1. src/rubik/moves.py +7 -8
src/rubik/moves.py CHANGED
@@ -135,18 +135,17 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
135
  inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
136
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
 
138
- extract_to_coordinates = dict(enumerate(indices.tolist()))
139
- coordinates_to_extract = {ind: i for i, ind in extract_to_coordinates.items()}
140
 
141
- extract_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
142
- global_perm = {
143
- i: (i if i not in coordinates_to_extract else extract_to_coordinates[extract_perm[coordinates_to_extract[i]]])
144
- for i in range(length)
145
  }
146
  perm_indices = torch.tensor(
147
- [[axis] * length, [slice] * length, [inverse] * length, list(global_perm.keys()), list(global_perm.values())],
148
  dtype=INT8,
149
  )
150
- perm_values = torch.tensor([1] * length)
151
  perm_size = (3, size, 2, length, length)
152
  return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)
 
135
  inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
136
  outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
137
 
138
+ local_to_total = dict(enumerate(indices.tolist()))
139
+ total_to_local = {ind: i for i, ind in local_to_total.items()}
140
 
141
+ local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
142
+ total_perm = {
143
+ i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
 
144
  }
145
  perm_indices = torch.tensor(
146
+ [[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
147
  dtype=INT8,
148
  )
149
+ perm_values = torch.tensor([1] * length, dtype=INT8)
150
  perm_size = (3, size, 2, length, length)
151
  return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)