JBAujogue commited on
Commit
7583934
·
1 Parent(s): 569ff0b

fix translation vector which was specific to cubes of size 3

Browse files
Files changed (1) hide show
  1. src/rubik/moves.py +4 -4
src/rubik/moves.py CHANGED
@@ -43,9 +43,9 @@ POS_ROTATIONS = torch.stack(
43
 
44
  POS_SHIFTS = torch.tensor(
45
  [
46
- [0, 0, 0, 2],
47
- [0, 2, 0, 0],
48
- [0, 0, 2, 0],
49
  ],
50
  dtype=INT8,
51
  )
@@ -124,7 +124,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
124
 
125
  # apply coordinate rotation
126
  rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
127
- offsets = POS_SHIFTS[axis].repeat(extract.shape[-1], 1).transpose(0, 1) # size = (4, n)
128
  rotated = rotated + offsets # size = (4, n)
129
 
130
  # apply face rotation
 
43
 
44
  POS_SHIFTS = torch.tensor(
45
  [
46
+ [0, 0, 0, 1],
47
+ [0, 1, 0, 0],
48
+ [0, 0, 1, 0],
49
  ],
50
  dtype=INT8,
51
  )
 
124
 
125
  # apply coordinate rotation
126
  rotated = POS_ROTATIONS[axis] @ extract # size = (4, n)
127
+ offsets = (POS_SHIFTS[axis] * (size - 1)).repeat(extract.shape[-1], 1).transpose(0, 1) # size = (4, n)
128
  rotated = rotated + offsets # size = (4, n)
129
 
130
  # apply face rotation