# jetclustering / src/layers/graph_transformer_layer_pc1.py
# Last commit e75a247 by gregorkrzmanc (commit message: "."), file size 11.5 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import numpy as np
from src.layers.GravNetConv3 import WeirdBatchNorm, knn_per_graph
"""
Graph Transformer Layer
"""
"""
Util functions
"""
def src_dot_dst(src_field, dst_field, out_field):
    """Build an edge function storing a per-edge dot product.

    The returned callable reads ``edges.src[src_field]`` and
    ``edges.dst[dst_field]``, multiplies them element-wise and sums over the
    last axis (keeping it as size 1). It returns the result under
    ``out_field``, in the dict form expected by ``g.apply_edges``.
    """

    def _edge_dot(edges):
        product = edges.src[src_field] * edges.dst[dst_field]
        return {out_field: product.sum(dim=-1, keepdim=True)}

    return _edge_dot
def scaled_exp(field, scale_constant):
    """Build an edge function replacing ``edges.data[field]`` by its scaled exp.

    The value is divided by ``scale_constant`` and clamped to [-5, 5] before
    exponentiation. The result is written back under the same key.
    """

    def _edge_exp(edges):
        scaled = edges.data[field] / scale_constant
        # clamp for softmax numerical stability: exp can neither overflow nor
        # underflow to exactly zero
        return {field: scaled.clamp(-5, 5).exp()}

    return _edge_exp
def src_dot_dst2(src_field, dst_field, out_field):
    """Build an edge function storing the source-minus-destination difference.

    Despite the name, no dot product is taken. The returned callable yields
    ``{out_field: edges.src[src_field] - edges.dst[dst_field]}``.
    """

    def _edge_diff(edges):
        src_val = edges.src[src_field]
        dst_val = edges.dst[dst_field]
        return {out_field: src_val - dst_val}

    return _edge_diff
"""
Single Attention Head
"""
class RelativePositionMessage(nn.Module):
    """Message function combining key-query attention with a distance kernel.

    For every edge and every head it computes

    * ``score_e  = exp(clamp(<K_src, Q_dst> / sqrt(out_dim), -5, 5))``
    * ``distance = exp(clamp(-||G_src - G_dst|| / sqrt(out_dim), -5, 5))``

    and sends ``score_e * distance * V_src`` under the key ``"V1_h"``.
    """

    def __init__(self, out_dim):
        super(RelativePositionMessage, self).__init__()
        # Per-head feature size; sqrt(out_dim) is the temperature of both kernels.
        self.out_dim = out_dim

    def forward(self, edges):
        """Return ``{"V1_h": weight * edges.src["V_h"]}`` for a batch of edges.

        Expects per-edge tensors laid out as ``(E, num_heads, ...)``:
        ``G_h`` ends in 3, and ``K_h``/``Q_h``/``V_h`` end in ``out_dim``. This is
        the layout ``MultiHeadAttentionLayer.forward`` writes.
        """
        scale = np.sqrt(self.out_dim)
        # Negative Euclidean distance in the learned 3-d space; the epsilon
        # keeps sqrt differentiable when both endpoints coincide.
        sq_dist = (edges.src["G_h"] - edges.dst["G_h"]).pow(2).sum(-1)
        dist = -torch.sqrt(sq_dist + 1e-6)  # (E, H)
        distance = torch.exp((dist / scale).clamp(-5, 5))
        score = (edges.src["K_h"] * edges.dst["Q_h"]).sum(-1, keepdim=True)  # (E, H, 1)
        score_e = torch.exp((score / scale).clamp(-5, 5))
        # Keep the head axis and broadcast the (E, H, 1) weight over V's feature
        # axis. Flattening heads into the edge axis only worked for one head.
        weight = score_e * distance.unsqueeze(-1)
        return {"V1_h": weight * edges.src["V_h"]}


class MultiHeadAttentionLayer(nn.Module):
    """Attention over a kNN graph built in a learned 3-d coordinate space.

    Node features are projected to keys, queries and values (``out_dim`` per
    head) and to coordinates ``G`` (3 per head). ``knn_per_graph`` builds a
    graph with ``n_neigh`` neighbours from those coordinates, and attention
    messages are aggregated along its edges.

    NOTE(review): the coordinates are flattened with ``view(-1, 3)`` before
    the kNN search. That gives one point per node only when
    ``num_heads == 1``, so multi-head use looks unsupported — confirm.
    """

    def __init__(self, n_neigh, in_dim, out_dim, num_heads, use_bias):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.n_neigh = n_neigh
        # ``forward`` needs G, K, Q, V and the message module whatever
        # ``use_bias`` is. Before this fix, G and the message module existed
        # only for use_bias=False, so use_bias=True crashed. The creation order
        # (G, K, Q, V) matches the bias-free branch, so seeded init is unchanged.
        bias = bool(use_bias)
        self.G = nn.Linear(in_dim, 3 * num_heads, bias=bias)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=bias)
        self.RelativePositionMessage = RelativePositionMessage(out_dim)

    def propagate_attention(self, g):
        """Compute attention scores on ``g`` and aggregate messages per node.

        Writes three features:

        * ``g.edata["score"]``: exp of the clamped, scaled key-query dot product.
        * ``g.ndata["wV"]``: sum of ``RelativePositionMessage`` outputs over the
          edges arriving at each node.
        * ``g.ndata["z"]``: sum of the arriving ``score`` values.

        NOTE(review): ``z`` sums only the key-query factor, while ``wV`` is
        also weighted by the distance kernel. The normalised weights therefore
        do not sum to 1 — confirm this is intended.
        """
        g.apply_edges(src_dot_dst("K_h", "Q_h", "score"))
        g.apply_edges(scaled_exp("score", np.sqrt(self.out_dim)))
        eids = g.edges()
        g.send_and_recv(eids, self.RelativePositionMessage, fn.sum("V1_h", "wV"))
        g.send_and_recv(eids, fn.copy_e("score", "score"), fn.sum("score", "z"))

    def forward(self, g, h):
        """Run attention for node features ``h`` and return per-head outputs.

        Returns ``wV / z`` from the kNN graph, shaped
        ``(nodes, num_heads, out_dim)``. Entries where ``z`` is not positive
        are left unnormalised.
        Side effect: writes ``K_h``, ``Q_h``, ``G_h`` and ``V_h`` into ``g.ndata``.
        """
        heads, dim = self.num_heads, self.out_dim
        g.ndata["K_h"] = self.K(h).view(-1, heads, dim)
        g.ndata["Q_h"] = self.Q(h).view(-1, heads, dim)
        g.ndata["G_h"] = self.G(h).view(-1, heads, 3)
        g.ndata["V_h"] = self.V(h).view(-1, heads, dim)
        # Attention graph built from the learned coordinates (see class NOTE).
        gu = knn_per_graph(g, g.ndata["G_h"].view(-1, 3), self.n_neigh)
        for key in ("K_h", "V_h", "Q_h", "G_h"):
            gu.ndata[key] = g.ndata[key]
        self.propagate_attention(gu)
        head_out = gu.ndata["wV"]  # (nodes, H, out_dim)
        # Per-head normaliser broadcast over the feature axis. z == 0 entries
        # (presumably nodes that received no messages) are skipped to avoid
        # dividing by zero.
        z = gu.ndata["z"].expand_as(head_out)
        has_msg = z > 0
        head_out[has_msg] = head_out[has_msg] / z[has_msg]
        return head_out
class GraphTransformerLayer(nn.Module):
    """Transformer-style block whose attention runs on a kNN graph.

    Pipeline:

    1. ``pre_gravnet``: two ELU-activated linear layers, ``in_dim -> 32 -> 32``.
    2. ``MultiHeadAttentionLayer``, then dropout, then linear ``O``.
    3. Optional residual, LayerNorm and BatchNorm.
    4. Two-layer ReLU feed-forward network of hidden width ``2 * out_dim``.
    5. Optional residual, LayerNorm and BatchNorm.

    Param:
        neigh: neighbour count handed to the attention layer.
        in_dim: width of the input node features.
        out_dim: width of the output node features. Each head gets
            ``out_dim // num_heads`` features and the attention output is
            viewed back to ``out_dim``, so ``out_dim`` should be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability, applied after attention and inside the FFN.
        layer_norm: apply ``nn.LayerNorm`` after each sub-block.
        batch_norm: apply ``nn.BatchNorm1d`` after each sub-block.
        residual: add skip connections. The first skip adds the raw input to
            an ``out_dim``-wide tensor, so it needs ``in_dim == out_dim``.
        use_bias: forwarded to the attention layer's projections.
    """

    def __init__(
        self,
        neigh,
        in_dim,
        out_dim,
        num_heads,
        dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        residual=False,
        use_bias=False,
    ):
        super().__init__()
        # Plain configuration attributes.
        self.in_channels, self.out_channels = in_dim, out_dim
        self.num_heads = num_heads
        self.neigh = neigh
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.d_shape = 32  # width of the embedding fed to the attention layer

        # Submodules are registered in a fixed order so that seeded weight
        # initialisation and parameters() ordering stay reproducible.
        per_head = out_dim // num_heads
        self.attention = MultiHeadAttentionLayer(
            neigh, self.d_shape, per_head, num_heads, use_bias
        )
        self.O = nn.Linear(out_dim, out_dim)
        if layer_norm:
            self.layer_norm1 = nn.LayerNorm(out_dim)
        if batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(out_dim)
        ffn_width = 2 * out_dim
        self.FFN_layer1 = nn.Linear(out_dim, ffn_width)
        self.FFN_layer2 = nn.Linear(ffn_width, out_dim)
        if layer_norm:
            self.layer_norm2 = nn.LayerNorm(out_dim)
        if batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(out_dim)
        self.pre_gravnet = nn.Sequential(
            nn.Linear(in_dim, self.d_shape),  #! Dense 1
            nn.ELU(),
            nn.Linear(self.d_shape, self.d_shape),  #! Dense 2
            nn.ELU(),
        )

    def _residual_and_norm(self, skip, x, stage):
        """Optionally add ``skip`` to ``x``, then apply stage ``stage``'s norms."""
        if self.residual:
            x = skip + x
        if self.layer_norm:
            x = getattr(self, f"layer_norm{stage}")(x)
        if self.batch_norm:
            x = getattr(self, f"batch_norm{stage}")(x)
        return x

    def forward(self, g, h):
        """Apply the block to node features ``h``; ``g`` goes to the attention layer.

        Returns a tensor with ``out_dim`` features per row.
        """
        skip_attn = h
        x = self.attention(g, self.pre_gravnet(h)).view(-1, self.out_channels)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self._residual_and_norm(skip_attn, self.O(x), 1)

        # Position-wise feed-forward sub-block.
        skip_ffn = x
        x = F.relu(self.FFN_layer1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.FFN_layer2(x)
        return self._residual_and_norm(skip_ffn, x, 2)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, heads={self.num_heads}, "
            f"residual={self.residual})"
        )
# if torch.sum(torch.isnan(g.edata["vector"])) > 0:
# print("VECTOR ALREADY NAN HERE")
# 0 / 0
# e_data_m1 = self.M1(g.edata["vector"])
# e_data_m1 = self.relu(e_data_m1)
# e_data_m1 = self.M2(e_data_m1)
# print("e_data_m1", e_data_m1[0:2])
# g.edata["vector"] = e_data_m1
# print("wV", g.ndata["wV"][0:2])
# g.send_and_recv(eids, fn.copy_e("vector", "vector"), fn.sum("vector", "z"))
# print("z", g.ndata["z"][0:2])
# if torch.sum(torch.isnan(g.ndata["z"])) > 0:
# 0 / 0
# class MultiHeadAttentionLayer2(nn.Module):
# def __init__(self, n_neigh, in_dim, out_dim, num_heads, use_bias):
# super().__init__()
# self.out_dim = out_dim
# self.num_heads = num_heads
# self.n_neigh = n_neigh
# if use_bias:
# self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
# self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
# self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
# else:
# self.K = nn.Linear(in_dim, 3 * num_heads, bias=False)
# self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)
# self.M1 = nn.Linear(3, out_dim, bias=False)
# self.relu = nn.ReLU()
# self.M2 = nn.Linear(out_dim, out_dim, bias=False)
# def propagate_attention(self, g):
# # Compute attention score
# g.apply_edges(src_dot_dst2("K_h", "K_h", "vector")) # , edges)
# # if torch.sum(torch.isnan(g.edata["vector"])) > 0:
# # print("VECTOR ALREADY NAN HERE")
# # 0 / 0
# e_data_m1 = self.M1(g.edata["vector"])
# e_data_m1 = self.relu(e_data_m1)
# e_data_m1 = self.M2(e_data_m1)
# g.edata["vector"] = e_data_m1
# g.apply_edges(scaled_exp("vector", np.sqrt(self.out_dim)))
# # if torch.sum(torch.isnan(g.edata["vector"])) > 0:
# # print(g.edata["vector"])
# # Send weighted values to target nodes
# eids = g.edges()
# # vector attention to modulate individual channels
# g.send_and_recv(eids, fn.u_mul_e("V_h", "vector", "V_h"), fn.sum("V_h", "wV"))
# # print("wV", g.ndata["wV"][0:2])
# g.send_and_recv(eids, fn.copy_e("vector", "vector"), fn.sum("vector", "z"))
# # print("z", g.ndata["z"][0:2])
# # if torch.sum(torch.isnan(g.ndata["z"])) > 0:
# # 0 / 0
# def forward(self, g, h):
# K_h = self.K(h)
# V_h = self.V(h)
# g.ndata["K_h"] = K_h.view(-1, self.num_heads, 3)
# g.ndata["V_h"] = V_h.view(-1, self.num_heads, self.out_dim)
# # print("q_h", Q_h[0:2])
# # print("K_h", K_h[0:2])
# # print("V_h", V_h[0:2])
# s_l = g.ndata["K_h"]
# gu = knn_per_graph(g, s_l.view(-1, 3), self.n_neigh)
# gu.ndata["K_h"] = g.ndata["K_h"]
# gu.ndata["V_h"] = g.ndata["V_h"]
# self.propagate_attention(gu)
# # print(gu.ndata["z"].shape)
# # gu.ndata["z"] = gu.ndata["z"].view(-1, 1, 1).tile((1, 1, self.out_dim))
# mask_empty = gu.ndata["z"] > 0
# head_out = gu.ndata["wV"]
# # print(head_out.shape, gu.ndata["z"].shape)
# head_out[mask_empty] = head_out[mask_empty] / (gu.ndata["z"][mask_empty])
# # g.ndata["z"] = g.ndata["z"][:, :, 0].view(
# # g.ndata["wV"].shape[0], self.num_heads, 1
# # )
# # print("head_out", head_out[0:2])
# # if torch.sum(torch.isnan(head_out)) > 0:
# # print("head_out ALREADY NAN HERE")
# # 0 / 0
# return head_out