JBAujogue commited on
Commit
df1c297
·
1 Parent(s): 8f96832

use sparse state representation

Browse files
Files changed (1) hide show
  1. src/rubik/cube.py +20 -18
src/rubik/cube.py CHANGED
@@ -1,13 +1,13 @@
1
- from dataclasses import dataclass
2
 
3
  import torch
 
4
 
5
  from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
6
  from rubik.display import stringify
7
  from rubik.tensor_utils import build_cube_tensor
8
 
9
 
10
- @dataclass
11
  class Cube:
12
  """
13
  A 4D tensor filled with colors. Dimensions have the following interpretation:
@@ -20,26 +20,28 @@ class Cube:
20
  the rest according to order given in "colors" attribute.
21
  """
22
 
23
- coordinates: torch.Tensor
24
- state: torch.Tensor
25
- actions: torch.Tensor
26
- history: list[list[int]]
27
- colors: list[str]
28
- size: int
29
-
30
- @classmethod
31
- def create(cls, colors: list[str], size: int) -> "Cube":
32
  """
33
  Create Cube from a given list of 6 colors and size.
34
  Example:
35
- cube = Cube.create(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
36
  """
37
  tensor = build_cube_tensor(colors, size)
38
- coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
39
- values = tensor.values()
40
- actions = build_actions_tensor(size)
41
- history: list[list[int]] = []
42
- return cls(coordinates, values, actions, history, colors, size)
 
 
 
 
 
 
 
 
 
 
43
 
44
  def reset_history(self) -> None:
45
  """
@@ -85,4 +87,4 @@ class Cube:
85
  """
86
  Compute a string representation of a cube.
87
  """
88
- return stringify(self)
 
1
+ from loguru import logger
2
 
3
  import torch
4
+ import torch.nn.functional as F
5
 
6
  from rubik.action import build_actions_tensor, parse_action_str, sample_actions_str
7
  from rubik.display import stringify
8
  from rubik.tensor_utils import build_cube_tensor
9
 
10
 
 
11
  class Cube:
12
  """
13
  A 4D tensor filled with colors. Dimensions have the following interpretation:
 
20
  the rest according to order given in "colors" attribute.
21
  """
22
 
23
+ def __init__(self, colors: list[str], size: int):
 
 
 
 
 
 
 
 
24
  """
25
  Create Cube from a given list of 6 colors and size.
26
  Example:
27
+ cube = Cube(['U', 'L', 'C', 'R', 'B', 'D'], size = 3)
28
  """
29
  tensor = build_cube_tensor(colors, size)
30
+ self.coordinates = tensor.indices().transpose(0, 1).to(torch.int8)
31
+ self.state = F.one_hot(tensor.values().long()).to(torch.int8)
32
+ self.actions = build_actions_tensor(size)
33
+ self.history: list[list[int]] = []
34
+ self.colors = colors
35
+ self.size = size
36
+
37
+ def to(self, device: str | torch.device) -> "Cube":
38
+ device = torch.device(device)
39
+ dtype = torch.int8 if device == torch.device("cpu") else torch.float32
40
+ self.coordinates = self.coordinates.to(device=device, dtype=dtype)
41
+ self.state = self.state.to(device=device, dtype=dtype)
42
+ self.actions = self.actions.to(device=device, dtype=dtype)
43
+ logger.info(f"Using device '{self.state.device}' and dtype '{dtype}'")
44
+ return self
45
 
46
  def reset_history(self) -> None:
47
  """
 
87
  """
88
  Compute a string representation of a cube.
89
  """
90
+ return stringify(self.state.argmax(dim=-1).to(device="cpu", dtype=torch.int8), self.colors, self.size)