File size: 1,295 Bytes
6ae852e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_max_pool as gmp


class GAT(nn.Module):
    r"""
    Two-layer graph-attention encoder that maps a (batched) graph to one
    feature vector per graph.

    From `GraphDTA <https://doi.org/10.1093/bioinformatics/btaa921>`_ (Nguyen et al., 2020),
    based on `Graph Attention Network <https://arxiv.org/abs/1710.10903>`_ (Veličković et al., 2018).
    """

    def __init__(self, num_features: int, out_channels: int, dropout: float):
        super().__init__()

        # First attention layer runs 10 heads whose outputs are concatenated
        # (GATConv's default), so the second layer sees num_features * 10 channels.
        self.gcn1 = GATConv(num_features, num_features, heads=10, dropout=dropout)
        self.gcn2 = GATConv(num_features * 10, out_channels, dropout=dropout)
        self.fc_g1 = nn.Linear(out_channels, out_channels)
        self.relu = nn.ReLU()
        # Probability used for the input-dropout applied before each conv.
        self.dropout = dropout

    def forward(self, data):
        """Encode ``data`` (node features, edges, batch vector) into per-graph embeddings."""
        node_feats = data.x
        edge_index = data.edge_index
        batch = data.batch

        # Attention stack: dropout on the inputs of each conv layer.
        h = F.dropout(node_feats, p=self.dropout, training=self.training)
        h = F.elu(self.gcn1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.relu(self.gcn2(h, edge_index))

        # Collapse node features to a single vector per graph, then project.
        pooled = gmp(h, batch)  # global max pooling over each graph in the batch
        return self.relu(self.fc_g1(pooled))