mac9087 commited on
Commit
b910820
·
verified ·
1 Parent(s): 9745daf

Create triplane.py

Browse files
Files changed (1) hide show
  1. tsr/models/tokenizers/triplane.py +45 -0
tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
+ class Triplane1DTokenizer(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ plane_size: int
15
+ num_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.embeddings = nn.Parameter(
21
+ torch.randn(
22
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23
+ dtype=torch.float32,
24
+ )
25
+ * 1
26
+ / math.sqrt(self.cfg.num_channels)
27
+ )
28
+
29
+ def forward(self, batch_size: int) -> torch.Tensor:
30
+ return rearrange(
31
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33
+ )
34
+
35
+ def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ batch_size, Ct, Nt = tokens.shape
37
+ assert Nt == self.cfg.plane_size**2 * 3
38
+ assert Ct == self.cfg.num_channels
39
+ return rearrange(
40
+ tokens,
41
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42
+ Np=3,
43
+ Hp=self.cfg.plane_size,
44
+ Wp=self.cfg.plane_size,
45
+ )