Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import dgl | |
import dgl.function as fn | |
import numpy as np | |
""" | |
Graph Transformer Layer | |
""" | |
""" | |
Util functions | |
""" | |
def src_dot_dst(src_field, dst_field, out_field):
    """Build an edge function that stores ``src - dst`` under ``out_field``.

    Despite the name, no dot product is taken: for every edge the returned
    function subtracts the destination node's ``dst_field`` value from the
    source node's ``src_field`` value, element-wise.

    Args:
        src_field: key read from ``edges.src``.
        dst_field: key read from ``edges.dst``.
        out_field: key of the single entry in the returned dict.

    Returns:
        A callable taking an ``edges`` batch and returning
        ``{out_field: edges.src[src_field] - edges.dst[dst_field]}``.
    """

    def edge_difference(edges):
        source_values = edges.src[src_field]
        target_values = edges.dst[dst_field]
        return {out_field: source_values - target_values}

    return edge_difference
def scaled_exp(field, scale_constant):
    """Build an edge function that replaces ``field`` by a clamped exponential.

    The returned function divides ``edges.data[field]`` by ``scale_constant``,
    clamps the quotient to ``[-5, 5]`` and exponentiates it.

    Args:
        field: key read from ``edges.data``; also the key of the result.
        scale_constant: divisor applied before clamping.

    Returns:
        A callable taking an ``edges`` batch and returning
        ``{field: exp(clamp(edges.data[field] / scale_constant, -5, 5))}``.
    """

    def clamped_exponential(edges):
        scaled = edges.data[field] / scale_constant
        # Bound the exponent so the un-normalised softmax weights stay finite.
        bounded = scaled.clamp(-5, 5)
        return {field: torch.exp(bounded)}

    return clamped_exponential
"""
Multi-head attention layer
"""
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention over graph edges with per-channel weights.

    Node features are projected to queries, keys and values and reshaped to
    ``[num_nodes, num_heads, out_dim]``. For every edge the difference
    ``K_h[src] - Q_h[dst]`` is passed through ``M1 -> ReLU -> M2``, scaled by
    ``sqrt(out_dim)``, clamped and exponentiated (see ``scaled_exp``). The
    resulting per-channel weight multiplies the source node's value, and the
    weighted values are summed on the destination node and divided by the sum
    of the weights.

    Args:
        in_dim: size of the input node features.
        out_dim: size of each head's output.
        num_heads: number of attention heads.
        use_bias: whether the Q/K/V projections carry a bias term.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads

        # One projection per role; all heads share a single Linear and are
        # split apart by the reshape in forward().
        has_bias = bool(use_bias)
        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=has_bias)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=has_bias)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=has_bias)

        # Two-layer MLP applied to the per-edge key/query difference.
        self.M1 = nn.Linear(out_dim, out_dim, bias=False)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(out_dim, out_dim, bias=False)

    def propagate_attention(self, g):
        """Compute edge weights and aggregate weighted values onto the nodes.

        Reads ``g.ndata["Q_h"]``, ``g.ndata["K_h"]`` and ``g.ndata["V_h"]``;
        writes ``g.edata["vector"]`` (edge weights), ``g.ndata["wV"]``
        (weighted value sum) and ``g.ndata["z"]`` (weight sum).
        """
        # Per-edge relation between source key and destination query.
        g.apply_edges(src_dot_dst("K_h", "Q_h", "vector"))

        # Turn the relation into per-channel attention logits.
        g.edata["vector"] = self.M2(self.relu(self.M1(g.edata["vector"])))

        # Scale, clamp and exponentiate: un-normalised softmax weights.
        g.apply_edges(scaled_exp("vector", np.sqrt(self.out_dim)))

        eids = g.edges()
        # Vector attention modulates individual channels of the values.
        g.send_and_recv(
            eids, fn.u_mul_e("V_h", "vector", "V_h"), fn.sum("V_h", "wV")
        )
        # Sum of the weights per destination node, used as the normaliser.
        g.send_and_recv(
            eids, fn.copy_e("vector", "vector"), fn.sum("vector", "z")
        )

    @staticmethod
    def _normalize_by_weight_total(weighted_sum, weight_total):
        """Divide ``weighted_sum`` by ``weight_total`` where the total is > 0.

        Entries whose total is not strictly positive are returned unchanged.
        The computation is out-of-place, so neither input tensor is modified.
        """
        has_weight = weight_total > 0
        # Substitute 1 for non-positive totals so those entries pass through
        # untouched and no division by zero is ever evaluated.
        safe_total = torch.where(
            has_weight, weight_total, torch.ones_like(weight_total)
        )
        return weighted_sum / safe_total

    def forward(self, g, h):
        """Run multi-head attention on graph ``g`` with node features ``h``.

        Args:
            g: graph supporting ``ndata``/``edata``, ``apply_edges``,
                ``edges`` and ``send_and_recv``. It is annotated in place
                with ``Q_h``, ``K_h``, ``V_h``, ``wV``, ``z`` (nodes) and
                ``vector`` (edges).
            h: node features fed to the Q/K/V projections
                (assumed ``[num_nodes, in_dim]`` -- confirm with the caller).

        Returns:
            The weighted value sum per node, normalised by the weight sum
            wherever that sum is positive.
        """
        # Reshape into [num_nodes, num_heads, feat_dim] to get the
        # projections for multi-head attention.
        head_shape = (-1, self.num_heads, self.out_dim)
        g.ndata["Q_h"] = self.Q(h).view(*head_shape)
        g.ndata["K_h"] = self.K(h).view(*head_shape)
        g.ndata["V_h"] = self.V(h).view(*head_shape)

        self.propagate_attention(g)

        return self._normalize_by_weight_total(g.ndata["wV"], g.ndata["z"])
class GraphTransformerLayer(nn.Module):
    """Graph transformer block: multi-head attention followed by an FFN.

    The attention output is flattened to ``out_dim`` features, passed through
    dropout and the output projection ``O``, then optionally a residual add,
    LayerNorm and BatchNorm. A two-layer feed-forward network
    (``out_dim -> 2 * out_dim -> out_dim``) follows, with the same optional
    residual/normalisation steps.

    Param:
        in_dim: size of the input node features.
        out_dim: size of the output node features; must be a multiple of
            ``num_heads`` because each head produces ``out_dim // num_heads``
            features.
        num_heads: number of attention heads.
        dropout: dropout probability used after attention and inside the FFN.
        layer_norm: apply ``nn.LayerNorm`` after each sub-block.
        batch_norm: apply ``nn.BatchNorm1d`` after each sub-block (applied
            after LayerNorm when both flags are set).
        residual: add each sub-block's input to its output. The first
            residual adds the layer input directly, so the input features
            must be shape-compatible with ``out_dim``.
        use_bias: forwarded to the attention layer's Q/K/V projections.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``out_dim``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        residual=False,
        use_bias=False,
    ):
        super().__init__()

        # Without this check the per-head width is silently truncated and the
        # flattening view in forward() either fails or mixes rows of
        # different nodes.
        if num_heads < 1 or out_dim % num_heads != 0:
            raise ValueError(
                "out_dim must be divisible by a positive num_heads "
                "(got out_dim={}, num_heads={})".format(out_dim, num_heads)
            )

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.attention = MultiHeadAttentionLayer(
            in_dim, out_dim // num_heads, num_heads, use_bias
        )
        self.O = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1 = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(out_dim)

        # FFN
        self.FFN_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2 = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(out_dim)

    def forward(self, g, h):
        """Apply attention and the FFN to node features ``h`` on graph ``g``.

        Args:
            g: graph handed to the attention layer.
            h: node features handed to the attention layer.

        Returns:
            Tensor with ``out_channels`` features per row.
        """
        h_in1 = h  # for first residual connection

        # Multi-head attention, heads concatenated along the feature axis.
        attn_out = self.attention(g, h)
        h = attn_out.view(-1, self.out_channels)

        h = F.dropout(h, self.dropout, training=self.training)
        h = self.O(h)

        if self.residual:
            h = h_in1 + h  # residual connection
        if self.layer_norm:
            h = self.layer_norm1(h)
        if self.batch_norm:
            h = self.batch_norm1(h)

        h_in2 = h  # for second residual connection

        # FFN
        h = self.FFN_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection
        if self.layer_norm:
            h = self.layer_norm2(h)
        if self.batch_norm:
            h = self.batch_norm2(h)

        return h

    def __repr__(self):
        """Return a one-line summary of the layer's configuration."""
        return "{}(in_channels={}, out_channels={}, heads={}, residual={})".format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_heads,
            self.residual,
        )