JBAujogue commited on
Commit
958b135
·
1 Parent(s): 280d9eb

change display function params

Browse files
Files changed (1) hide show
  1. src/rubik/display.py +8 -5
src/rubik/display.py CHANGED
@@ -1,13 +1,16 @@
1
- def stringify(cube) -> str:
 
 
 
2
  """
3
  Compute a string representation of a cube.
4
  """
5
- colors = pad_colors(cube.colors)
6
- faces = cube.state.reshape(6, cube.size, cube.size).transpose(1, 2)
7
  faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
8
- space = " " * max(len(c) for c in cube.colors) * cube.size
9
  l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
10
- l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(cube.size))
11
  l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
12
  return "\n".join([l1, l2, l3])
13
 
 
1
+ import torch
2
+
3
+
4
+ def stringify(state: torch.Tensor, colors: list[str], size: int) -> str:
5
  """
6
  Compute a string representation of a cube.
7
  """
8
+ colors = pad_colors(colors)
9
+ faces = state.reshape(6, size, size).transpose(1, 2)
10
  faces = [[[colors[i - 1] for i in row] for row in face.tolist()] for face in faces]
11
+ space = " " * max(len(c) for c in colors) * size
12
  l1 = "\n".join(" ".join([space, "".join(row), space, space]) for row in faces[0])
13
+ l2 = "\n".join(" ".join("".join(face[i]) for face in faces[1:5]) for i in range(size))
14
  l3 = "\n".join(" ".join((space, "".join(row), space, space)) for row in faces[-1])
15
  return "\n".join([l1, l2, l3])
16