# jetclustering/src/layers/graph_transformer_edge_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import numpy as np
"""
Graph Transformer Layer with edge features
"""
"""
Util functions
"""
def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field])}
    return func
def scaling(field, scale_constant):
    def func(edges):
        return {field: ((edges.data[field]) / scale_constant)}
    return func
# Improving implicit attention scores with explicit edge features, if available
def imp_exp_attn(implicit_attn, explicit_edge):
    """
    implicit_attn: the implicit attention scores, i.e. the output of the K-Q dot product
    explicit_edge: the projected explicit edge features
    """
    def func(edges):
        return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
    return func
# To copy edge features to be passed to FFN_e
def out_edge_features(edge_feat):
    def func(edges):
        return {'e_out': edges.data[edge_feat]}
    return func
def exp(field):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
    return func
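# Note: exp() above, together with the division by 'z' in
# MultiHeadAttentionLayer.forward(), implements a softmax of the attention
# scores over each node's incoming edges; the clamp bounds the exponent for
# numerical stability. For intuition only: recent DGL releases ship an
# equivalent fused operator, dgl.nn.functional.edge_softmax, which is not used
# here.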
"""
Single Attention Head
"""
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads, use_bias):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.proj_e = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
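    # Note: 'out_dim' here is the per-head width; the linear projections above
    # produce out_dim * num_heads features, which forward() reshapes into
    # [num_nodes (or num_edges), num_heads, out_dim].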
    def propagate_attention(self, g):
        # Compute raw attention scores from the K and Q projections
        g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))
        # Scale by the square root of the per-head dimension
        g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
        # Use available edge features to modify the scores
        g.apply_edges(imp_exp_attn('score', 'proj_e'))
        # Copy edge features as e_out to be passed to FFN_e
        g.apply_edges(out_edge_features('score'))
        # Softmax numerator: exponentiate the clamped scores; normalization by
        # 'z' happens in forward()
        g.apply_edges(exp('score'))
        # Send weighted values to target nodes
        eids = g.edges()
        g.send_and_recv(eids, fn.src_mul_edge('V_h', 'score', 'V_h'), fn.sum('V_h', 'wV'))
        g.send_and_recv(eids, fn.copy_edge('score', 'score'), fn.sum('score', 'z'))
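        # Note: fn.src_mul_edge and fn.copy_edge are legacy DGL built-in names;
        # newer DGL versions expose the same operations as fn.u_mul_e and
        # fn.copy_e.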
    def forward(self, g, h, e):
        Q_h = self.Q(h)
        K_h = self.K(h)
        V_h = self.V(h)
        proj_e = self.proj_e(e)
        # Reshaping into [num_nodes, num_heads, feat_dim] to
        # get projections for multi-head attention
        g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
        g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
        g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
        g.edata['proj_e'] = proj_e.view(-1, self.num_heads, self.out_dim)
        self.propagate_attention(g)
        h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6))  # adding eps to all values here
        e_out = g.edata['e_out']
        return h_out, e_out
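    # forward() returns per-head tensors: h_out has shape
    # [num_nodes, num_heads, out_dim] and e_out has shape
    # [num_edges, num_heads, out_dim]; GraphTransformerLayer below reshapes
    # them back to num_heads * out_dim features.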
class GraphTransformerLayer(nn.Module):
    """
    Graph Transformer layer that updates node features h and edge features e
    in parallel: multi-head attention with edge features, followed by separate
    feed-forward networks for h and e, each with a residual connection and
    layer/batch normalization.
    """
    def __init__(self, in_dim, out_dim, num_heads, dropout=0.0, layer_norm=False, batch_norm=True, residual=True, use_bias=False):
        super().__init__()
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.attention = MultiHeadAttentionLayer(in_dim, out_dim // num_heads, num_heads, use_bias)
        self.O_h = nn.Linear(out_dim, out_dim)
        self.O_e = nn.Linear(out_dim, out_dim)
        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)
            self.layer_norm1_e = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm1_h = nn.BatchNorm1d(out_dim)
            self.batch_norm1_e = nn.BatchNorm1d(out_dim)
        # FFN for h
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim)
        # FFN for e
        self.FFN_e_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_e_layer2 = nn.Linear(out_dim * 2, out_dim)
        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)
            self.layer_norm2_e = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim)
            self.batch_norm2_e = nn.BatchNorm1d(out_dim)
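    # Note: out_dim must be divisible by num_heads, since attention runs with a
    # per-head width of out_dim // num_heads and forward() reshapes the
    # concatenated heads back to out_dim. The FFNs follow the usual Transformer
    # recipe: expand to 2 * out_dim, ReLU, dropout, project back to out_dim.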
    def forward(self, g, h, e):
        h_in1 = h  # for first residual connection
        e_in1 = e  # for first residual connection
        # multi-head attention out
        h_attn_out, e_attn_out = self.attention(g, h, e)
        h = h_attn_out.view(-1, self.out_channels)
        e = e_attn_out.view(-1, self.out_channels)
        # h = F.dropout(h, self.dropout, training=self.training)
        # e = F.dropout(e, self.dropout, training=self.training)
        h = self.O_h(h)
        e = self.O_e(e)
        if self.residual:
            h = h_in1 + h  # residual connection
            e = e_in1 + e  # residual connection
        if self.layer_norm:
            h = self.layer_norm1_h(h)
            e = self.layer_norm1_e(e)
        if self.batch_norm:
            h = self.batch_norm1_h(h)
            e = self.batch_norm1_e(e)
        h_in2 = h  # for second residual connection
        e_in2 = e  # for second residual connection
        # FFN for h
        h = self.FFN_h_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_h_layer2(h)
        # FFN for e
        e = self.FFN_e_layer1(e)
        e = F.relu(e)
        e = F.dropout(e, self.dropout, training=self.training)
        e = self.FFN_e_layer2(e)
        if self.residual:
            h = h_in2 + h  # residual connection
            e = e_in2 + e  # residual connection
        if self.layer_norm:
            h = self.layer_norm2_h(h)
            e = self.layer_norm2_e(e)
        if self.batch_norm:
            h = self.batch_norm2_h(h)
            e = self.batch_norm2_e(e)
        return h, e
    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_heads, self.residual)
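# A minimal smoke test: a sketch only, assuming PyTorch and a DGL build whose
# message-function names match those used above (fn.src_mul_edge / fn.copy_edge;
# on newer DGL swap in fn.u_mul_e / fn.copy_e). The graph, feature widths and
# hyper-parameters below are illustrative and not taken from the surrounding
# project.
if __name__ == "__main__":
    hidden_dim, num_heads = 64, 8
    # Small directed toy graph: 4 nodes, 5 edges.
    g = dgl.graph(([0, 1, 2, 3, 0], [1, 2, 3, 0, 2]))
    h = torch.randn(g.num_nodes(), hidden_dim)   # node features
    e = torch.randn(g.num_edges(), hidden_dim)   # edge features
    layer = GraphTransformerLayer(hidden_dim, hidden_dim, num_heads, dropout=0.1)
    h_out, e_out = layer(g, h, e)
    print(h_out.shape, e_out.shape)  # expected: [4, 64] and [5, 64]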