Spaces:
Sleeping
Sleeping
rename internal variables during construction of action tensor
Browse files- src/rubik/moves.py +7 -8
src/rubik/moves.py
CHANGED
@@ -135,18 +135,17 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
135 |
inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
|
136 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
137 |
|
138 |
-
|
139 |
-
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
i: (i if i not in
|
144 |
-
for i in range(length)
|
145 |
}
|
146 |
perm_indices = torch.tensor(
|
147 |
-
[[axis] * length, [slice] * length, [inverse] * length, list(
|
148 |
dtype=INT8,
|
149 |
)
|
150 |
-
perm_values = torch.tensor([1] * length)
|
151 |
perm_size = (3, size, 2, length, length)
|
152 |
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)
|
|
|
135 |
inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
|
136 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
137 |
|
138 |
+
local_to_total = dict(enumerate(indices.tolist()))
|
139 |
+
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
140 |
|
141 |
+
local_perm = {i: inputs.index(outputs[i]) for i in range(len(inputs))}
|
142 |
+
total_perm = {
|
143 |
+
i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
|
|
|
144 |
}
|
145 |
perm_indices = torch.tensor(
|
146 |
+
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
147 |
dtype=INT8,
|
148 |
)
|
149 |
+
perm_values = torch.tensor([1] * length, dtype=INT8)
|
150 |
perm_size = (3, size, 2, length, length)
|
151 |
return torch.sparse_coo_tensor(indices=perm_indices, values=perm_values, size=perm_size, dtype=INT8)
|