import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GraphConv
"""
GCN: Graph Convolutional Networks
Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017)
http://arxiv.org/abs/1609.02907
"""
# Sends a message of node feature h
# Equivalent to => return {'m': edges.src['h']}
msg = fn.copy_u("h", "m")
reduce = fn.mean("m", "h")
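# Together, update_all(msg, reduce) followed by apply_nodes(NodeApplyModule)
# computes  h_v <- W * mean_{u in N(v)} h_u + b  for every node v, i.e. a
# mean-aggregation variant of the GCN update (used when dgl_builtin=False).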
class NodeApplyModule(nn.Module):
    # Update node feature h_v with (Wh_v + b)
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, node):
        h = self.linear(node.data["h"])
        return {"h": h}
class GCNLayer(nn.Module):
    """
    Params: in_dim, out_dim, activation, dropout, batch_norm,
            residual (default False), dgl_builtin (default True)
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        activation,
        dropout,
        batch_norm,
        residual=False,
        dgl_builtin=True,
    ):
        super().__init__()
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.batch_norm = batch_norm
        self.residual = residual
        self.dgl_builtin = dgl_builtin

        if in_dim != out_dim:
            self.residual = False

        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        if not self.dgl_builtin:
            self.apply_mod = NodeApplyModule(in_dim, out_dim)
        elif dgl.__version__ < "0.5":
            self.conv = GraphConv(in_dim, out_dim, bias=False)
        else:
            self.conv = GraphConv(
                in_dim, out_dim, allow_zero_in_degree=True, bias=False
            )
        self.sc_act = nn.ReLU()

    def forward(self, g, feature):
        h_in = feature  # to be used for residual connection

        if not self.dgl_builtin:
            g.ndata["h"] = feature
            g.update_all(msg, reduce)
            g.apply_nodes(func=self.apply_mod)
            h = g.ndata["h"]  # result of graph convolution
        else:
            h = self.conv(g, feature)

        if self.batch_norm:
            h = self.batchnorm_h(h)  # batch normalization

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection
            h = self.sc_act(h)

        h = self.dropout(h)
        return h

    def __repr__(self):
        return "{}(in_channels={}, out_channels={}, residual={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.residual
        )
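

if __name__ == "__main__":
    # Minimal usage sketch (not part of the layer itself): one GCNLayer
    # forward pass on a toy 4-node cycle graph. The feature size, dropout
    # rate and self-loop step below are illustrative assumptions.
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g = dgl.add_self_loop(g)  # add A_hat = A + I self-loops explicitly
    feats = torch.randn(g.num_nodes(), 8)

    layer = GCNLayer(
        in_dim=8,
        out_dim=8,
        activation=F.relu,
        dropout=0.1,
        batch_norm=True,
        residual=True,
    )
    out = layer(g, feats)
    print(out.shape)  # expected: torch.Size([4, 8])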