"""Graph Transformer layer with attention over a kNN graph.

The kNN graph is built in a learned 3-dimensional coordinate space. The
module contains small DGL edge-UDF factories, the per-edge message function
``RelativePositionMessage``, the attention block ``MultiHeadAttentionLayer``
and the full ``GraphTransformerLayer`` (attention + feed-forward network).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import numpy as np

from src.layers.GravNetConv3 import WeirdBatchNorm, knn_per_graph


# ---------------------------------------------------------------------------
# Util functions
# ---------------------------------------------------------------------------


def src_dot_dst(src_field, dst_field, out_field):
    """Return an edge UDF computing a dot product between node fields.

    The returned function multiplies ``edges.src[src_field]`` with
    ``edges.dst[dst_field]`` element-wise, sums over the last axis (keeping
    that axis with size 1) and stores the result under ``out_field``.
    """

    def func(edges):
        return {
            out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(
                -1, keepdim=True
            )
        }

    return func


def scaled_exp(field, scale_constant):
    """Return an edge UDF replacing ``field`` with ``exp(field / scale)``.

    The scaled value is clamped to ``[-5, 5]`` before exponentiation.
    """

    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def src_dot_dst2(src_field, dst_field, out_field):
    """Return an edge UDF storing ``src[src_field] - dst[dst_field]``.

    Despite the name, this is an element-wise difference, not a dot product.
    """

    def func(edges):
        return {out_field: (edges.src[src_field] - edges.dst[dst_field])}

    return func


# ---------------------------------------------------------------------------
# Single Attention Head
# ---------------------------------------------------------------------------


class RelativePositionMessage(nn.Module):
    """Edge message: source values weighted by attention score and distance.

    Reads the node fields ``"G_h"``, ``"K_h"``, ``"Q_h"`` and ``"V_h"`` from
    the edge batch and emits the message field ``"V1_h"``.

    Args:
        out_dim: Per-head feature size; ``sqrt(out_dim)`` scales both the
            dot-product score and the distance before exponentiation.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.out_dim = out_dim

    def forward(self, edges):
        """Compute the weighted value message for every edge.

        Returns:
            ``{"V1_h": weight * V_h_src}`` where ``weight`` is
            ``exp(clamp(K_src . Q_dst / sqrt(d))) * exp(clamp(-dist / sqrt(d)))``
            and ``dist`` is the Euclidean distance between the ``"G_h"``
            coordinates of the two end points.
        """
        scale = np.sqrt(self.out_dim)

        # Negative Euclidean distance in the learned coordinate space; the
        # epsilon keeps the sqrt gradient finite for coincident points.
        offset = edges.src["G_h"] - edges.dst["G_h"]
        neg_dist = -torch.sqrt(offset.pow(2).sum(-1) + 1e-6)
        distance = torch.exp((neg_dist / scale).clamp(-5, 5))

        score = (edges.src["K_h"] * edges.dst["Q_h"]).sum(-1, keepdim=True)
        score_e = torch.exp((score / scale).clamp(-5, 5))

        # `score_e` keeps a trailing singleton axis while `distance` does
        # not; aligning them with unsqueeze keeps the head axis intact.
        # The previous `view(-1, 1, 1)` flattened edges and heads together,
        # which only broadcast against V_h when there was a single head.
        weight = score_e * distance.unsqueeze(-1)
        return {"V1_h": weight * edges.src["V_h"]}


class MultiHeadAttentionLayer(nn.Module):
    """Attention over a kNN graph built from learned 3-D coordinates.

    Args:
        n_neigh: Neighbour count forwarded to ``knn_per_graph``.
        in_dim: Size of the input node features.
        out_dim: Feature size per head.
        num_heads: Number of attention heads.
        use_bias: Whether the Q/K/V projections carry a bias term.
    """

    def __init__(self, n_neigh, in_dim, out_dim, num_heads, use_bias):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.n_neigh = n_neigh

        bias = bool(use_bias)
        # G used to be created only in the `use_bias=False` branch, so
        # `forward` raised AttributeError when `use_bias=True`. It is now
        # always created. It stays bias-free because only coordinate
        # differences are used in the message function, where a constant
        # offset cancels. Creation order (G, K, Q, V) matches the original
        # bias-free branch so seeded initialisation is unchanged there.
        self.G = nn.Linear(in_dim, 3 * num_heads, bias=False)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.RelativePositionMessage = RelativePositionMessage(out_dim)

    def propagate_attention(self, g):
        """Run message passing on ``g`` and fill ``ndata["wV"]`` / ``ndata["z"]``.

        ``wV`` is the sum of the ``RelativePositionMessage`` outputs over
        incoming edges; ``z`` is the sum of the exponentiated scores.
        """
        # Compute attention score
        g.apply_edges(src_dot_dst("K_h", "Q_h", "score"))
        g.apply_edges(scaled_exp("score", np.sqrt(self.out_dim)))

        eids = g.edges()
        g.send_and_recv(eids, self.RelativePositionMessage, fn.sum("V1_h", "wV"))
        # NOTE(review): the normaliser `z` sums only the score term, while
        # `wV` is additionally weighted by the distance term. Kept as in the
        # original; confirm this asymmetry is intended.
        g.send_and_recv(eids, fn.copy_e("score", "score"), fn.sum("score", "z"))

    def forward(self, g, h):
        """Compute attention output for node features ``h`` on graph ``g``.

        Side effect: stores the projections on ``g.ndata`` under ``"K_h"``,
        ``"Q_h"``, ``"G_h"`` and ``"V_h"``.

        Returns:
            Tensor viewed as ``(num_nodes, num_heads, out_dim)``: ``wV``
            divided by ``z`` wherever ``z > 0``, and ``wV`` unchanged elsewhere.
        """
        heads, dim = self.num_heads, self.out_dim
        g.ndata["K_h"] = self.K(h).view(-1, heads, dim)
        g.ndata["Q_h"] = self.Q(h).view(-1, heads, dim)
        g.ndata["G_h"] = self.G(h).view(-1, heads, 3)
        g.ndata["V_h"] = self.V(h).view(-1, heads, dim)

        # NOTE(review): flattening to (-1, 3) yields one coordinate row per
        # node only when num_heads == 1; how `knn_per_graph` should treat
        # several heads is not visible from this file - confirm before
        # using num_heads > 1.
        gu = knn_per_graph(g, g.ndata["G_h"].view(-1, 3), self.n_neigh)
        for key in ("K_h", "V_h", "Q_h", "G_h"):
            gu.ndata[key] = g.ndata[key]

        self.propagate_attention(gu)

        # Normalise by z where it is positive. Substituting 1 for
        # non-positive entries leaves those rows untouched, avoids the
        # in-place masked write on graph data, and broadcasts over the
        # feature axis for any number of heads.
        z = gu.ndata["z"]
        safe_z = torch.where(z > 0, z, torch.ones_like(z))
        return gu.ndata["wV"] / safe_z


class GraphTransformerLayer(nn.Module):
    """Transformer block: kNN attention followed by a feed-forward network.

    Args:
        neigh: Neighbour count for the kNN graph built inside the attention.
        in_dim: Size of the input node features.
        out_dim: Size of the output node features; the attention uses
            ``out_dim // num_heads`` features per head.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after the attention and inside
            the feed-forward network.
        layer_norm: Apply ``nn.LayerNorm`` after each sub-block.
        batch_norm: Apply ``nn.BatchNorm1d`` after each sub-block.
        residual: Add skip connections around both sub-blocks. The first one
            adds the raw input to an ``out_dim``-wide tensor, so ``in_dim``
            must be broadcast-compatible with ``out_dim``.
        use_bias: Forwarded to ``MultiHeadAttentionLayer``.
    """

    def __init__(
        self,
        neigh,
        in_dim,
        out_dim,
        num_heads,
        dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        residual=False,
        use_bias=False,
    ):
        super().__init__()
        # Width of the embedding produced by `pre_gravnet` and fed to the
        # attention block.
        self.d_shape = 32
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.neigh = neigh

        self.attention = MultiHeadAttentionLayer(
            self.neigh, self.d_shape, out_dim // num_heads, num_heads, use_bias
        )
        self.O = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1 = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(out_dim)

        # FFN
        self.FFN_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2 = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(out_dim)

        self.pre_gravnet = nn.Sequential(
            nn.Linear(self.in_channels, self.d_shape),  #! Dense 1
            nn.ELU(),
            nn.Linear(self.d_shape, self.d_shape),  #! Dense 2
            nn.ELU(),
        )

    def forward(self, g, h):
        """Apply the layer to node features ``h`` of graph ``g``.

        Returns:
            Tensor with ``out_dim`` features per row.
        """
        h_in1 = h  # for first residual connection
        h = self.pre_gravnet(h)

        # multi-head attention out, heads concatenated along the feature axis
        attn_out = self.attention(g, h)
        h = attn_out.view(-1, self.out_channels)

        h = F.dropout(h, self.dropout, training=self.training)
        h = self.O(h)

        if self.residual:
            h = h_in1 + h  # residual connection
        if self.layer_norm:
            h = self.layer_norm1(h)
        if self.batch_norm:
            h = self.batch_norm1(h)

        h_in2 = h  # for second residual connection

        # FFN
        h = self.FFN_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection
        if self.layer_norm:
            h = self.layer_norm2(h)
        if self.batch_norm:
            h = self.batch_norm2(h)

        return h

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, heads={self.num_heads}, "
            f"residual={self.residual})"
        )


# ---------------------------------------------------------------------------
# Archived experimental code, kept commented out for reference.
# ---------------------------------------------------------------------------

#         if torch.sum(torch.isnan(g.edata["vector"])) > 0:
#             print("VECTOR ALREADY NAN HERE")
#             0 / 0
#         e_data_m1 = self.M1(g.edata["vector"])
#         e_data_m1 = self.relu(e_data_m1)
#         e_data_m1 = self.M2(e_data_m1)
#         print("e_data_m1", e_data_m1[0:2])
#         g.edata["vector"] = e_data_m1
#         print("wV", g.ndata["wV"][0:2])
#         g.send_and_recv(eids, fn.copy_e("vector", "vector"), fn.sum("vector", "z"))
#         print("z", g.ndata["z"][0:2])
#         if torch.sum(torch.isnan(g.ndata["z"])) > 0:
#             0 / 0


# class MultiHeadAttentionLayer2(nn.Module):
#     def __init__(self, n_neigh, in_dim, out_dim, num_heads, use_bias):
#         super().__init__()
#         self.out_dim = out_dim
#         self.num_heads = num_heads
#         self.n_neigh = n_neigh
#         if use_bias:
#             self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
#             self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
#             self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
#         else:
#             self.K = nn.Linear(in_dim, 3 * num_heads, bias=False)
#             self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
#         self.M1 = nn.Linear(3, out_dim, bias=False)
#         self.relu = nn.ReLU()
#         self.M2 = nn.Linear(out_dim, out_dim, bias=False)
#
#     def propagate_attention(self, g):
#         # Compute attention score
#         g.apply_edges(src_dot_dst2("K_h", "K_h", "vector"))  # , edges)
#         # if torch.sum(torch.isnan(g.edata["vector"])) > 0:
#         #     print("VECTOR ALREADY NAN HERE")
#         #     0 / 0
#         e_data_m1 = self.M1(g.edata["vector"])
#         e_data_m1 = self.relu(e_data_m1)
#         e_data_m1 = self.M2(e_data_m1)
#         g.edata["vector"] = e_data_m1
#         g.apply_edges(scaled_exp("vector", np.sqrt(self.out_dim)))
#         # if torch.sum(torch.isnan(g.edata["vector"])) > 0:
#         #     print(g.edata["vector"])
#         # Send weighted values to target nodes
#         eids = g.edges()
#         # vector attention to modulate individual channels
#         g.send_and_recv(eids, fn.u_mul_e("V_h", "vector", "V_h"), fn.sum("V_h", "wV"))
#         # print("wV", g.ndata["wV"][0:2])
#         g.send_and_recv(eids, fn.copy_e("vector", "vector"), fn.sum("vector", "z"))
#         # print("z", g.ndata["z"][0:2])
#         # if torch.sum(torch.isnan(g.ndata["z"])) > 0:
#         #     0 / 0
#
#     def forward(self, g, h):
#         K_h = self.K(h)
#         V_h = self.V(h)
#         g.ndata["K_h"] = K_h.view(-1, self.num_heads, 3)
#         g.ndata["V_h"] = V_h.view(-1, self.num_heads, self.out_dim)
#         # print("q_h", Q_h[0:2])
#         # print("K_h", K_h[0:2])
#         # print("V_h", V_h[0:2])
#         s_l = g.ndata["K_h"]
#         gu = knn_per_graph(g, s_l.view(-1, 3), self.n_neigh)
#         gu.ndata["K_h"] = g.ndata["K_h"]
#         gu.ndata["V_h"] = g.ndata["V_h"]
#         self.propagate_attention(gu)
#         # print(gu.ndata["z"].shape)
#         # gu.ndata["z"] = gu.ndata["z"].view(-1, 1, 1).tile((1, 1, self.out_dim))
#         mask_empty = gu.ndata["z"] > 0
#         head_out = gu.ndata["wV"]
#         # print(head_out.shape, gu.ndata["z"].shape)
#         head_out[mask_empty] = head_out[mask_empty] / (gu.ndata["z"][mask_empty])
#         # g.ndata["z"] = g.ndata["z"][:, :, 0].view(
#         #     g.ndata["wV"].shape[0], self.num_heads, 1
#         # )
#         # print("head_out", head_out[0:2])
#         # if torch.sum(torch.isnan(head_out)) > 0:
#         #     print("head_out ALREADY NAN HERE")
#         #     0 / 0
#         return head_out