import torch
import torch.nn as nn


class VanillaMLP(nn.Module):
    """A plain fully connected MLP with configurable depth, width, and activations."""

    def __init__(self, input_dim, output_dim, out_activation,
                 n_hidden_layers=4, n_neurons=64, activation="ReLU"):
        super().__init__()
        self.n_neurons = n_neurons
        self.n_hidden_layers = n_hidden_layers
        self.activation = activation
        self.out_activation = out_activation

        # Input layer followed by its activation.
        layers = [
            self.make_linear(input_dim, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        # Hidden layers, each followed by its activation.
        for _ in range(self.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.n_neurons, self.n_neurons, is_first=False, is_last=False
                ),
                self.make_activation(),
            ]

        # Output layer (the output activation is appended below).
        layers += [
            self.make_linear(self.n_neurons, output_dim, is_first=False, is_last=True)
        ]

        # Output activation applied after the final linear layer.
        if self.out_activation == "sigmoid":
            layers += [nn.Sigmoid()]
        elif self.out_activation == "tanh":
            layers += [nn.Tanh()]
        elif self.out_activation == "hardtanh":
            layers += [nn.Hardtanh()]
        elif self.out_activation == "GELU":
            layers += [nn.GELU()]
        elif self.out_activation == "RELU":
            layers += [nn.ReLU()]
        else:
            raise NotImplementedError(
                f"Unsupported out_activation: {self.out_activation}"
            )

        self.layers = nn.Sequential(*layers)

    def forward(self, x, split_size=100000):
        # NOTE: split_size is accepted but currently unused; the whole batch is
        # evaluated in a single pass.
        # Autocast is disabled so the MLP always runs in full (fp32) precision,
        # even inside a mixed-precision (AMP) region.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.layers(x)
        return out

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        # is_first / is_last are accepted for interface compatibility but are not
        # used: every layer is a plain bias-free linear layer.
        layer = nn.Linear(dim_in, dim_out, bias=False)
        return layer

    def make_activation(self):
        # Hidden-layer activation; only ReLU and GELU are supported.
        if self.activation == "ReLU":
            return nn.ReLU(inplace=True)
        elif self.activation == "GELU":
            return nn.GELU()
        else:
            raise NotImplementedError(f"Unsupported activation: {self.activation}")
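

# A minimal usage sketch (not part of the original module); the dimensions and
# activation choices below are illustrative assumptions only.
if __name__ == "__main__":
    # Map 3-D inputs to 3-D outputs squashed into (0, 1) by the sigmoid head.
    mlp = VanillaMLP(input_dim=3, output_dim=3, out_activation="sigmoid")
    x = torch.rand(8, 3)   # small batch of dummy inputs
    y = mlp(x)             # shape (8, 3), values in (0, 1)
    print(y.shape, float(y.min()), float(y.max()))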