File size: 812 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
    MLP Layer used after graph vector representation
"""

class MLPReadout(nn.Module):
    """MLP head: ``L`` hidden layers (Linear -> ReLU -> Dropout) then a Linear output.

    The layers are stored in ``self.FC_layers`` in application order, so the
    module holds ``L + 1`` linear layers in total.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the returned tensor.
        L: Number of hidden layers. ``0`` yields a single
            ``Linear(input_dim, output_dim)`` with no activation or dropout.
        hidden_dim: Width of every hidden layer (keyword-only).
        dropout: Drop probability applied after each hidden activation
            (keyword-only).

    Raises:
        ValueError: If ``L`` is negative.
    """

    def __init__(self, input_dim: int, output_dim: int, L: int = 2, *,
                 hidden_dim: int = 20, dropout: float = 0.1) -> None:
        super().__init__()
        if L < 0:
            raise ValueError(f"L (number of hidden layers) must be >= 0, got {L}")

        # Layer widths from input to output; consecutive pairs define each
        # Linear. Building from this list keeps L == 0 consistent with
        # forward(): exactly one layer mapping input_dim -> output_dim.
        widths = [input_dim] + [hidden_dim] * L + [output_dim]

        # Layers are created first-to-last, so parameter initialisation
        # consumes the RNG in the same order as before for L >= 1.
        self.FC_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out, bias=True)
             for n_in, n_out in zip(widths, widths[1:])]
        )
        self.drop_out = nn.Dropout(dropout)
        self.L = L

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hidden layers with ReLU and dropout, then the output layer.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``output_dim``. No activation is applied to
            the output layer.
        """
        *hidden_layers, output_layer = self.FC_layers
        y = x
        for layer in hidden_layers:
            y = self.drop_out(F.relu(layer(y)))
        return output_layer(y)