mac9087 commited on
Commit
04bf072
·
verified ·
1 Parent(s): 95289e5

Create network_utils.py

Browse files
Files changed (1) hide show
  1. tsr/models/network_utils.py +124 -0
tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
+ class TriplaneUpsampleNetwork(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ in_channels: int
15
+ out_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.upsample = nn.ConvTranspose2d(
21
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22
+ )
23
+
24
+ def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25
+ triplanes_up = rearrange(
26
+ self.upsample(
27
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28
+ ),
29
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30
+ Np=3,
31
+ )
32
+ return triplanes_up
33
+
34
+
35
+ class NeRFMLP(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ in_channels: int
39
+ n_neurons: int
40
+ n_hidden_layers: int
41
+ activation: str = "relu"
42
+ bias: bool = True
43
+ weight_init: Optional[str] = "kaiming_uniform"
44
+ bias_init: Optional[str] = None
45
+
46
+ cfg: Config
47
+
48
+ def configure(self) -> None:
49
+ layers = [
50
+ self.make_linear(
51
+ self.cfg.in_channels,
52
+ self.cfg.n_neurons,
53
+ bias=self.cfg.bias,
54
+ weight_init=self.cfg.weight_init,
55
+ bias_init=self.cfg.bias_init,
56
+ ),
57
+ self.make_activation(self.cfg.activation),
58
+ ]
59
+ for i in range(self.cfg.n_hidden_layers - 1):
60
+ layers += [
61
+ self.make_linear(
62
+ self.cfg.n_neurons,
63
+ self.cfg.n_neurons,
64
+ bias=self.cfg.bias,
65
+ weight_init=self.cfg.weight_init,
66
+ bias_init=self.cfg.bias_init,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ layers += [
71
+ self.make_linear(
72
+ self.cfg.n_neurons,
73
+ 4, # density 1 + features 3
74
+ bias=self.cfg.bias,
75
+ weight_init=self.cfg.weight_init,
76
+ bias_init=self.cfg.bias_init,
77
+ )
78
+ ]
79
+ self.layers = nn.Sequential(*layers)
80
+
81
+ def make_linear(
82
+ self,
83
+ dim_in,
84
+ dim_out,
85
+ bias=True,
86
+ weight_init=None,
87
+ bias_init=None,
88
+ ):
89
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
90
+
91
+ if weight_init is None:
92
+ pass
93
+ elif weight_init == "kaiming_uniform":
94
+ torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ if bias:
99
+ if bias_init is None:
100
+ pass
101
+ elif bias_init == "zero":
102
+ torch.nn.init.zeros_(layer.bias)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return layer
107
+
108
+ def make_activation(self, activation):
109
+ if activation == "relu":
110
+ return nn.ReLU(inplace=True)
111
+ elif activation == "silu":
112
+ return nn.SiLU(inplace=True)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ def forward(self, x):
117
+ inp_shape = x.shape[:-1]
118
+ x = x.reshape(-1, x.shape[-1])
119
+
120
+ features = self.layers(x)
121
+ features = features.reshape(*inp_shape, -1)
122
+ out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123
+
124
+ return out