JBAujogue commited on
Commit
905aff6
·
1 Parent(s): 5e3396a

coalesce permutation tensor

Browse files
Files changed (1) hide show
  1. src/rubik/tensor_utils.py +1 -1
src/rubik/tensor_utils.py CHANGED
@@ -28,4 +28,4 @@ def build_permutation_matrix(size: int, perm: str) -> torch.Tensor:
28
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
  indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
30
  values = torch.tensor([1] * size, dtype=torch.int16)
31
- return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16)
 
28
  perm_dict = {perm_list[i]: perm_list[i + 1] for i in range(len(perm))}
29
  indices = torch.tensor([list(range(size)), [(perm_dict.get(i, i)) for i in range(size)]], dtype=torch.int16)
30
  values = torch.tensor([1] * size, dtype=torch.int16)
31
+ return torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.int16).coalesce()