# Source: jetclustering/src/layers/mlp_readout_layer.py
# Repository page metadata: author gregorkrzmanc, revision e75a247, 812 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
"""
MLP Layer used after graph vector representation
"""
class MLPReadout(nn.Module):
    """Multi-layer perceptron readout head.

    The network is a stack of ``L`` hidden ``nn.Linear`` layers, each followed
    by ReLU and dropout, and a final ``nn.Linear`` projection to
    ``output_dim`` with no activation applied afterwards.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the returned tensor.
        L: Number of hidden layers. ``L=0`` yields a single linear map from
            ``input_dim`` to ``output_dim``.
        hidden_dim: Width of every hidden layer (keyword-only, default 20).
        dropout: Dropout probability applied after each hidden activation
            (keyword-only, default 0.1).

    Raises:
        ValueError: If ``L`` is negative.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        L: int = 2,
        *,
        hidden_dim: int = 20,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        if L < 0:
            raise ValueError(f"L must be a non-negative integer, got {L}")

        # Layer widths from input to output; consecutive pairs define the
        # Linear layers. Creating them in input -> output order keeps the
        # parameter initialisation order stable for a fixed RNG seed.
        widths = [input_dim] + [hidden_dim] * L + [output_dim]
        # Attribute names are kept as-is so existing state_dict keys
        # ("FC_layers.<i>.weight", ...) continue to load.
        self.FC_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out, bias=True)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.drop_out = nn.Dropout(dropout)
        self.L = L

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hidden layers (Linear -> ReLU -> dropout), then project.

        Args:
            x: Input tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_dim``.
        """
        y = x
        # Every layer except the last one is a hidden layer.
        for hidden_layer in self.FC_layers[:-1]:
            y = self.drop_out(F.relu(hidden_layer(y)))
        # The final projection is left without activation or dropout.
        return self.FC_layers[-1](y)