JBAujogue commited on
Commit
f14bbd0
·
1 Parent(s): 5457ec8

improve cube printing

Browse files
Files changed (1) hide show
  1. src/rubik/cube.py +32 -25
src/rubik/cube.py CHANGED
@@ -14,7 +14,7 @@ class Cube:
14
  - Color (from 0 to 6, 0 being the "dark" color, the rest according to order given in "colors" attribute).
15
  """
16
 
17
- tensor: torch.Tensor
18
  colors: list[str]
19
  size: int
20
 
@@ -30,29 +30,38 @@ class Cube:
30
 
31
  # build tensor filled with 0's, and fill the faces with 1's
32
  n = size - 1
33
- tensor = torch.zeros([size, size, size, 6, 7], dtype=torch.int8)
34
- tensor[:, :, n, 0, 1] = 1 # up
35
- tensor[0, :, :, 1, 2] = 1 # left
36
- tensor[:, 0, :, 2, 3] = 1 # front
37
- tensor[n, :, :, 3, 4] = 1 # right
38
- tensor[:, n, :, 4, 5] = 1 # back
39
- tensor[:, :, 0, 5, 6] = 1 # down
40
- return cls(tensor, colors, size)
41
 
42
- def to_grid(self) -> list[list[list[str]]]:
 
 
 
 
 
 
 
 
43
  """
44
  Convert Cube into a 3D grid representation.
45
  """
46
  n = self.size - 1
 
47
  grid = [
48
- self.tensor[:, :, n, 0, :].argmax(dim=-1), # up
49
- self.tensor[0, :, :, 1, :].argmax(dim=-1), # left
50
- self.tensor[:, 0, :, 2, :].argmax(dim=-1), # front
51
- self.tensor[n, :, :, 3, :].argmax(dim=-1), # right
52
- self.tensor[:, n, :, 4, :].argmax(dim=-1), # back
53
- self.tensor[:, :, 0, 5, :].argmax(dim=-1), # down
54
  ]
55
- return [[[self.colors[i - 1] for i in row] for row in face.tolist()] for face in grid]
56
 
57
  def __str__(self):
58
  """
@@ -70,11 +79,9 @@ class Cube:
70
  # DDD
71
  # DDD
72
  """
73
- grid = self.to_grid()
74
- void = " " * self.size
75
- top = "\n".join(" ".join([void, "".join(row), void, void]) for row in grid[0])
76
- middle = "\n".join(
77
- " ".join("".join(grid[face_i][row_i]) for face_i in range(1, 5)) for row_i in range(self.size)
78
- )
79
- bottom = "\n".join(" ".join((void, "".join(row), void, void)) for row in grid[-1])
80
- return "\n".join([top, middle, bottom])
 
14
  - Color (from 0 to 6, 0 being the "dark" color, the rest according to order given in "colors" attribute).
15
  """
16
 
17
+ state: torch.Tensor
18
  colors: list[str]
19
  size: int
20
 
 
30
 
31
  # build tensor filled with 0's, and fill the faces with 1's
32
  n = size - 1
33
+ state = torch.zeros([size, size, size, 6, 7], dtype=torch.int8)
34
+ state[:, :, n, 0, 1] = 1 # up
35
+ state[0, :, :, 1, 2] = 1 # left
36
+ state[:, 0, :, 2, 3] = 1 # front
37
+ state[n, :, :, 3, 4] = 1 # right
38
+ state[:, n, :, 4, 5] = 1 # back
39
+ state[:, :, 0, 5, 6] = 1 # down
40
+ return cls(state, colors, size)
41
 
42
+ @staticmethod
43
+ def pad_colors(colors: list[str]) -> list[str]:
44
+ """
45
+ Pad color names to strings of equal length.
46
+ """
47
+ max_len = max(len(c) for c in colors)
48
+ return [c + " " * (max_len - len(c)) for c in colors]
49
+
50
+ def to_grid(self, pad_colors: bool = False) -> list[list[list[str]]]:
51
  """
52
  Convert Cube into a 3D grid representation.
53
  """
54
  n = self.size - 1
55
+ colors = self.pad_colors(self.colors) if pad_colors else self.colors
56
  grid = [
57
+ self.state[:, :, n, 0, :].argmax(dim=-1), # up
58
+ self.state[0, :, :, 1, :].argmax(dim=-1), # left
59
+ self.state[:, 0, :, 2, :].argmax(dim=-1), # front
60
+ self.state[n, :, :, 3, :].argmax(dim=-1), # right
61
+ self.state[:, n, :, 4, :].argmax(dim=-1), # back
62
+ self.state[:, :, 0, 5, :].argmax(dim=-1), # down
63
  ]
64
+ return [[[colors[i - 1] for i in row] for row in face.tolist()] for face in grid]
65
 
66
  def __str__(self):
67
  """
 
79
  # DDD
80
  # DDD
81
  """
82
+ grid = self.to_grid(pad_colors=True)
83
+ void = " " * max(len(c) for c in self.colors) * self.size
84
+ l1 = "\n".join(" ".join([void, "".join(row), void, void]) for row in grid[0])
85
+ l2 = "\n".join(" ".join("".join(grid[face_i][row_i]) for face_i in range(1, 5)) for row_i in range(self.size))
86
+ l3 = "\n".join(" ".join((void, "".join(row), void, void)) for row in grid[-1])
87
+ return "\n".join([l1, l2, l3])