# jetclustering/src/layers/gcn_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GraphConv
"""
GCN: Graph Convolutional Networks
Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017)
http://arxiv.org/abs/1609.02907
"""
# Sends a message of node feature h
# Equivalent to => return {'m': edges.src['h']}
msg = fn.copy_u("h", "m")
reduce = fn.mean("m", "h")
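# For reference, the two built-ins above correspond to these user-defined
# message/reduce functions (illustrative sketch only):
#   msg    = lambda edges: {"m": edges.src["h"]}
#   reduce = lambda nodes: {"h": nodes.mailbox["m"].mean(dim=1)}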


class NodeApplyModule(nn.Module):
    """Update node feature h_v with a linear transform: h_v <- W h_v + b."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, node):
        h = self.linear(node.data["h"])
        return {"h": h}


class GCNLayer(nn.Module):
    """
    Single GCN layer.

    Param: [in_dim, out_dim]
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        activation,
        dropout,
        batch_norm,
        residual=False,
        dgl_builtin=True,
    ):
        super().__init__()
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.batch_norm = batch_norm
        self.residual = residual
        self.dgl_builtin = dgl_builtin

        # A residual connection is only valid when input and output widths match.
        if in_dim != out_dim:
            self.residual = False

        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        if not self.dgl_builtin:
            # Manual message passing: mean-aggregate neighbour features, then
            # apply a learned linear transform per node.
            self.apply_mod = NodeApplyModule(in_dim, out_dim)
        elif dgl.__version__ < "0.5":
            # Older DGL (< 0.5) has no allow_zero_in_degree argument.
            self.conv = GraphConv(in_dim, out_dim, bias=False)
        else:
            # Newer DGL raises on zero-in-degree nodes unless explicitly allowed.
            self.conv = GraphConv(
                in_dim, out_dim, allow_zero_in_degree=True, bias=False
            )

        # Activation applied after the shortcut (residual) addition.
        self.sc_act = nn.ReLU()

    def forward(self, g, feature):
        h_in = feature  # to be used for residual connection

        if not self.dgl_builtin:
            g.ndata["h"] = feature
            g.update_all(msg, reduce)
            g.apply_nodes(func=self.apply_mod)
            h = g.ndata["h"]  # result of graph convolution
        else:
            h = self.conv(g, feature)

        if self.batch_norm:
            h = self.batchnorm_h(h)  # batch normalization

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection
            h = self.sc_act(h)

        h = self.dropout(h)
        return h

    def __repr__(self):
        return "{}(in_channels={}, out_channels={}, residual={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.residual
        )
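

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The toy graph, feature width and hyperparameters below are assumed values
# chosen just to exercise one forward pass of GCNLayer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_nodes, in_dim, out_dim = 6, 8, 16

    # Small directed chain graph; node 0 has no incoming edge, which the
    # builtin GraphConv tolerates because allow_zero_in_degree=True.
    g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]), num_nodes=num_nodes)
    feat = torch.randn(num_nodes, in_dim)

    layer = GCNLayer(
        in_dim=in_dim,
        out_dim=out_dim,
        activation=F.relu,
        dropout=0.1,
        batch_norm=True,
        residual=False,  # would be disabled anyway since in_dim != out_dim
        dgl_builtin=True,
    )
    out = layer(g, feat)
    print(layer, out.shape)  # expected: (num_nodes, out_dim)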