"""
GCN: Graph Convolutional Networks
Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph
Convolutional Networks (ICLR 2017)
http://arxiv.org/abs/1609.02907
"""
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.pytorch import GraphConv

# Message function: copy the source node's feature "h" onto each edge as "m".
# Equivalent to => return {'m': edges.src['h']}
msg = fn.copy_u("h", "m")
# Reduce function: average the incoming "m" messages into the node feature "h".
reduce = fn.mean("m", "h")


def _version_tuple(version):
    """Return the leading numeric release components of a version string.

    Examples: ``"0.4.3post2"`` -> ``(0, 4, 3)``, ``"1.1.2+cu118"`` -> ``(1, 1, 2)``.
    Parsing stops at the first dotted component that does not start with a
    digit, or right after a component that carries a non-numeric suffix.
    Tuples compare numerically, unlike raw strings (``"0.10" < "0.5"`` is True).
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            # A suffix such as "post2" or "+cu118" ends the release segment.
            break
    return tuple(parts)


class NodeApplyModule(nn.Module):
    """Update node feature h_v with (W h_v + b) using a single nn.Linear."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, node):
        # Reads the aggregated "h" written by the reduce step and returns the
        # projected value under the same key, so DGL overwrites node data "h".
        h = self.linear(node.data["h"])
        return {"h": h}


class GCNLayer(nn.Module):
    """One graph-convolution layer with optional batch norm and residual.

    Param: [in_dim, out_dim]

    Args:
        in_dim: size of the input node features.
        out_dim: size of the output node features.
        activation: callable applied after the (optional) batch norm; skipped
            when falsy (e.g. ``None``).
        dropout: drop probability passed to ``nn.Dropout``.
        batch_norm: if truthy, apply ``nn.BatchNorm1d(out_dim)`` to the
            convolution output.
        residual: add the layer input to the output. Forced to ``False`` when
            ``in_dim != out_dim`` because the shapes would not match.
        dgl_builtin: if truthy, use DGL's ``GraphConv`` (no bias); otherwise
            use mean aggregation (``msg``/``reduce``) followed by
            ``NodeApplyModule``.

    Forward order: convolution -> batch norm -> activation -> residual add
    -> ReLU (always applied) -> dropout.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        activation,
        dropout,
        batch_norm,
        residual=False,
        dgl_builtin=True,
    ):
        super().__init__()
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.batch_norm = batch_norm
        self.residual = residual
        self.dgl_builtin = dgl_builtin

        # A residual sum is only well-defined when input/output widths agree.
        if in_dim != out_dim:
            self.residual = False

        # Created unconditionally (even when batch_norm is falsy), so the
        # state_dict layout does not depend on the batch_norm flag.
        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        if not self.dgl_builtin:
            self.apply_mod = NodeApplyModule(in_dim, out_dim)
        elif _version_tuple(dgl.__version__) < (0, 5):
            # Older DGL releases do not accept the allow_zero_in_degree kwarg.
            self.conv = GraphConv(in_dim, out_dim, bias=False)
        else:
            self.conv = GraphConv(
                in_dim, out_dim, allow_zero_in_degree=True, bias=False
            )
        self.sc_act = nn.ReLU()

    def forward(self, g, feature):
        h_in = feature  # to be used for residual connection

        if not self.dgl_builtin:
            # NOTE: this path writes g.ndata["h"] on the caller's graph and
            # leaves the layer output stored there after the call.
            g.ndata["h"] = feature
            g.update_all(msg, reduce)
            g.apply_nodes(func=self.apply_mod)
            h = g.ndata["h"]  # result of graph convolution
        else:
            h = self.conv(g, feature)

        if self.batch_norm:
            h = self.batchnorm_h(h)  # batch normalization

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection

        h = self.sc_act(h)
        h = self.dropout(h)
        return h

    def __repr__(self):
        return "{}(in_channels={}, out_channels={}, residual={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.residual
        )