Spaces:
Sleeping
Sleeping
use sparse state representation
Browse files- src/rubik/cube.py +20 -18
src/rubik/cube.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
from
|
2 |
|
3 |
import torch
|
|
|
4 |
|
5 |
from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
|
6 |
from rubik.display import stringify
|
7 |
from rubik.tensor_utils import build_cube_tensor
|
8 |
|
9 |
|
10 |
-
@dataclass
|
11 |
class Cube:
|
12 |
"""
|
13 |
A 4D tensor filled with colors. Dimensions have the following interpretation:
|
@@ -20,26 +20,28 @@ class Cube:
|
|
20 |
the rest according to order given in "colors" attribute.
|
21 |
"""
|
22 |
|
23 |
-
|
24 |
-
state: torch.Tensor
|
25 |
-
actions: torch.Tensor
|
26 |
-
history: list[list[int]]
|
27 |
-
colors: list[str]
|
28 |
-
size: int
|
29 |
-
|
30 |
-
@classmethod
|
31 |
-
def create(cls, colors: list[str], size: int) -> "Cube":
|
32 |
"""
|
33 |
Create Cube from a given list of 6 colors and size.
|
34 |
Example:
|
35 |
-
cube = Cube
|
36 |
"""
|
37 |
tensor = build_cube_tensor(colors, size)
|
38 |
-
coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
|
39 |
-
|
40 |
-
actions = build_actions_tensor(size)
|
41 |
-
history: list[list[int]] = []
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def reset_history(self) -> None:
|
45 |
"""
|
@@ -85,4 +87,4 @@ class Cube:
|
|
85 |
"""
|
86 |
Compute a string representation of a cube.
|
87 |
"""
|
88 |
-
return stringify(self)
|
|
|
1 |
+
from loguru import logger
|
2 |
|
3 |
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
|
6 |
from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
|
7 |
from rubik.display import stringify
|
8 |
from rubik.tensor_utils import build_cube_tensor
|
9 |
|
10 |
|
|
|
11 |
class Cube:
|
12 |
"""
|
13 |
A 4D tensor filled with colors. Dimensions have the following interpretation:
|
|
|
20 |
the rest according to order given in "colors" attribute.
|
21 |
"""
|
22 |
|
23 |
+
def __init__(self, colors: list[str], size: int):
    """
    Initialise a cube of the given size from a list of 6 colors.

    The cube is kept in a sparse form: `coordinates` holds one row of
    indices per stored entry and `state` holds the matching one-hot
    encoded value of that entry.

    Example:
        cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
    """
    self.colors = colors
    self.size = size
    self.history: list[list[int]] = []

    sparse_cube = build_cube_tensor(colors, size)

    # indices() is laid out as (dims, entries); transpose so that each row
    # describes a single stored entry.
    positions = sparse_cube.indices().transpose(0, 1)
    self.coordinates = positions.to(torch.int8)

    # one_hot only accepts integer (long) class ids, hence the cast before
    # encoding and the cast back to the compact int8 storage afterwards.
    color_ids = sparse_cube.values().long()
    self.state = F.one_hot(color_ids).to(torch.int8)

    self.actions = build_actions_tensor(size)
|
36 |
+
|
37 |
+
def to(self, device: str | torch.device) -> "Cube":
    """
    Move the cube's tensors (coordinates, state, actions) to `device`, in place.

    The tensors are cast to int8 when the target is a CPU device and to
    float32 for any other device type.
    NOTE(review): float32 is presumably needed because the operations used
    on accelerators do not support int8 -- confirm.

    Args:
        device: target device, given as a string (e.g. "cpu", "cuda:0") or
            as a torch.device.

    Returns:
        The cube itself, so that calls can be chained.
    """
    device = torch.device(device)
    # Decide on the device *type*: torch.device("cpu:0") does not compare
    # equal to torch.device("cpu"), so comparing whole device objects would
    # wrongly select float32 for an indexed CPU device.
    dtype = torch.int8 if device.type == "cpu" else torch.float32
    self.coordinates = self.coordinates.to(device=device, dtype=dtype)
    self.state = self.state.to(device=device, dtype=dtype)
    self.actions = self.actions.to(device=device, dtype=dtype)
    logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
    return self
|
45 |
|
46 |
def reset_history(self) -> None:
|
47 |
"""
|
|
|
87 |
"""
|
88 |
Compute a string representation of a cube.
|
89 |
"""
|
90 |
+
return stringify(self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int8), self.colors, self.size)
|