Spaces:
Running
Running
from dataclasses import dataclass | |
import torch | |
@dataclass
class Cube:
    """
    A 5D tensor filled with 0 or 1. Dimensions have the following interpretation:
    - X coordinate (from 0 to self.size - 1, from Left to Right).
    - Y coordinate (from 0 to self.size - 1, from Front to Back).
    - Z coordinate (from 0 to self.size - 1, from Down to Up).
    - Face (from 0 to 5, with 0 = "Up", 1 = "Left", 2 = "Front", 3 = "Right", 4 = "Back", 5 = "Down").
    - Color (from 0 to 6, 0 being the "dark" color, the rest according to order given in "colors" attribute).

    Instances are typically built with ``Cube.from_default``.
    """

    # Shape (size, size, size, 6, 7) as built by from_default, indexed as
    # described in the class docstring.
    tensor: torch.Tensor
    # The 6 color symbols; color index k (k >= 1) is named colors[k - 1].
    colors: list[str]
    # Edge length: number of entries along each of the X, Y and Z axes.
    size: int

    @classmethod
    def from_default(cls, colors: list[str], size: int) -> "Cube":
        """
        Create a Cube whose six outer faces are each painted one uniform color.

        Face k (0 = Up, ..., 5 = Down) receives color index k + 1, i.e. the
        symbol colors[k].

        Example:
            cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)

        Args:
            colors: exactly 6 distinct color symbols, in face order
                Up, Left, Front, Right, Back, Down.
            size: edge length of the cube; an integer greater than 1.

        Raises:
            AssertionError: if ``colors`` is not exactly 6 distinct symbols,
                or ``size`` is not an integer greater than 1.
        """
        assert (num := len(set(colors))) == 6, f"Expected 6 distinct colors, got {num}"
        # Also require exactly 6 entries: a longer list containing repeats has
        # 6 distinct values, but its extra entries would be silently ignored.
        assert (count := len(colors)) == 6, f"Expected exactly 6 colors, got {count}"
        assert isinstance(size, int) and size > 1, (
            f"Expected integer size greater than 1, got {size}"
        )
        # Build a tensor filled with 0's (every cell "dark"), then set the
        # outermost layer of each face to that face's color. n is the index
        # of the last layer along a spatial axis.
        n = size - 1
        tensor = torch.zeros([size, size, size, 6, 7], dtype=torch.int8)
        tensor[:, :, n, 0, 1] = 1  # up
        tensor[0, :, :, 1, 2] = 1  # left
        tensor[:, 0, :, 2, 3] = 1  # front
        tensor[n, :, :, 3, 4] = 1  # right
        tensor[:, n, :, 4, 5] = 1  # back
        tensor[:, :, 0, 5, 6] = 1  # down
        return cls(tensor, colors, size)

    def to_grid(self) -> list[list[list[str]]]:
        """
        Convert Cube into a 3D grid representation.

        Returns:
            A list of 6 faces in face order (Up, Left, Front, Right, Back,
            Down). Each face is a list of rows of color symbols, read straight
            from that face's outermost layer with no flip or transpose: Up and
            Down are indexed [x][y], Left and Right [y][z], Front and Back
            [x][z].
        """
        n = self.size - 1
        # argmax over the color axis returns the index of the first maximal
        # entry: the painted color for a one-hot cell, 0 for an all-zero cell.
        grid = [
            self.tensor[:, :, n, 0, :].argmax(dim=-1),  # up
            self.tensor[0, :, :, 1, :].argmax(dim=-1),  # left
            self.tensor[:, 0, :, 2, :].argmax(dim=-1),  # front
            self.tensor[n, :, :, 3, :].argmax(dim=-1),  # right
            self.tensor[:, n, :, 4, :].argmax(dim=-1),  # back
            self.tensor[:, :, 0, 5, :].argmax(dim=-1),  # down
        ]
        # NOTE(review): color index 0 ("dark") would map to self.colors[-1]
        # (the last color) here; presumably outer faces are never dark —
        # confirm.
        return [
            [[self.colors[i - 1] for i in row] for row in face.tolist()]
            for face in grid
        ]

    def __str__(self):
        """
        Compute a string representation of a cube.

        Faces are laid out as a cross: Up on top, then Left, Front, Right and
        Back side by side, then Down. Each line of the Up and Down bands is
        padded with spaces to the width of the middle band.

        Example:
            cube = Cube.from_default(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
            print(cube)
            #     UUU
            #     UUU
            #     UUU
            # LLL CCC RRR BBB
            # LLL CCC RRR BBB
            # LLL CCC RRR BBB
            #     DDD
            #     DDD
            #     DDD
        """
        grid = self.to_grid()
        void = " " * self.size
        top = "\n".join(" ".join((void, "".join(row), void, void)) for row in grid[0])
        middle = "\n".join(
            " ".join("".join(grid[face_i][row_i]) for face_i in range(1, 5))
            for row_i in range(self.size)
        )
        bottom = "\n".join(
            " ".join((void, "".join(row), void, void)) for row in grid[-1]
        )
        return "\n".join([top, middle, bottom])