import torch
import torch.nn as nn
import torch.nn.functional as F


class IntraGraphAttention(nn.Module):
    def __init__(self, d_node, d_edge, num_heads, negative_slope=0.2):
        super(IntraGraphAttention, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_edge % num_heads == 0, "d_edge must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_edge // num_heads

        # Per-head projections: Wn/We feed the node-update pass, Wn_2/We_2 feed the edge-update pass.
        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_edge, d_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_edge, d_edge)
        self.attn_linear = nn.Linear(self.d_k * 2 + self.d_edge_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_edge, d_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, node_representation, edge_representation):
        """Update node and edge features with intra-graph attention.

        Shapes: node_representation (B, L, d_node), edge_representation (B, L, L, d_edge).
        """
        B, L, d_node = node_representation.size()
        d_edge = edge_representation.size(-1)

        # Split node and edge features into heads.
        node_proj = self.Wn(node_representation).view(B, L, self.num_heads, self.d_k)
        edge_proj = self.We(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)

        # Node update: attend over the other nodes, conditioned on the connecting edges.
        new_node_representation = self.single_head_attention_node(node_proj, edge_proj)

        concatenated_node_rep = new_node_representation.view(B, L, -1)
        new_node_representation = self.out_proj_node(concatenated_node_rep)

        # Edge update: recombine the refreshed node features with the original edges.
        node_proj_2 = self.Wn_2(new_node_representation).view(B, L, self.num_heads, self.d_k)
        edge_proj_2 = self.We_2(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)

        new_edge_representation = self.single_head_attention_edge(node_proj_2, edge_proj_2)

        concatenated_edge_rep = new_edge_representation.view(B, L, L, -1)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        return new_node_representation, new_edge_representation

    def single_head_attention_node(self, node_representation, edge_representation):
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)

        # Build all (i, j) pairs of node features together with the connecting edge.
        hi = node_representation.unsqueeze(2)
        hj = node_representation.unsqueeze(1)

        hi_hj_concat = torch.cat([hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1),
                                  edge_representation], dim=-1)

        attention_scores = self.attn_linear(hi_hj_concat).squeeze(-1)

        # Mask the diagonal (self-attention) before the softmax over neighbours.
        mask = torch.eye(L, dtype=torch.bool, device=node_representation.device).unsqueeze(0).unsqueeze(-1)
        attention_scores.masked_fill_(mask, float('-inf'))

        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)

        # Aggregate neighbour messages per head.
        node_representation_Wh = self.Wh(node_representation)
        node_representation_Wh = node_representation_Wh.permute(0, 2, 1, 3)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), node_representation_Wh)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)

        # Residual connection around the attention update.
        new_node_representation = node_representation + self.leaky_relu(aggregated_features)

        return new_node_representation

    def single_head_attention_edge(self, node_representation, edge_representation):
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)

        hi = node_representation.unsqueeze(2)
        hj = node_representation.unsqueeze(1)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1)], dim=-1)

        new_edge_representation = self.edge_linear(hi_hj_concat)

        return new_edge_representation
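

# Illustrative shape check for IntraGraphAttention (a sketch, not part of the original module).
# The dimensions below (d_node=64, d_edge=32, num_heads=4, B=2, L=5) are assumptions chosen
# only so every feature size divides evenly across heads; any head-divisible sizes work.
def _demo_intra_graph_attention():
    layer = IntraGraphAttention(d_node=64, d_edge=32, num_heads=4)
    nodes = torch.randn(2, 5, 64)       # (B, L, d_node)
    edges = torch.randn(2, 5, 5, 32)    # (B, L, L, d_edge)
    new_nodes, new_edges = layer(nodes, edges)
    # Both outputs keep their input shapes.
    assert new_nodes.shape == (2, 5, 64)
    assert new_edges.shape == (2, 5, 5, 32)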


class DiffEmbeddingLayer(nn.Module):
    def __init__(self, d_node):
        super(DiffEmbeddingLayer, self).__init__()
        self.W_delta = nn.Linear(d_node, d_node)

    def forward(self, wt_node, mut_node):
        delta_h = mut_node - wt_node
        diff_vec = torch.relu(self.W_delta(delta_h))
        return diff_vec


class MIM(nn.Module):
    def __init__(self, d_node, d_edge, d_diff, num_heads, negative_slope=0.2):
        super(MIM, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_edge % num_heads == 0, "d_edge must be divisible by num_heads"
        assert d_diff % num_heads == 0, "d_diff must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_edge // num_heads
        self.d_diff_head = d_diff // num_heads

        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_edge, d_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_edge, d_edge)
        self.Wd = nn.Linear(d_diff, d_diff)
        self.Wd_2 = nn.Linear(d_diff, d_diff)
        self.attn_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + 2 * self.d_diff_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + 2 * self.d_diff_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_edge, d_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, node_representation, edge_representation, diff_vec):
        B, L, d_node = node_representation.size()
        d_edge = edge_representation.size(-1)

        node_proj = self.Wn(node_representation).view(B, L, self.num_heads, self.d_k)
        edge_proj = self.We(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)
        diff_proj = self.Wd(diff_vec).view(B, L, self.num_heads, self.d_diff_head)

        new_node_representation = self.single_head_attention_node(node_proj, edge_proj, diff_proj)

        concatenated_node_rep = new_node_representation.view(B, L, -1)
        new_node_representation = self.out_proj_node(concatenated_node_rep)

        node_proj_2 = self.Wn_2(new_node_representation).view(B, L, self.num_heads, self.d_k)
        edge_proj_2 = self.We_2(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)
        diff_proj_2 = self.Wd_2(diff_vec).view(B, L, self.num_heads, self.d_diff_head)

        new_edge_representation = self.single_head_attention_edge(node_proj_2, edge_proj_2, diff_proj_2)

        concatenated_edge_rep = new_edge_representation.view(B, L, L, -1)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        return new_node_representation, new_edge_representation

    def single_head_attention_node(self, node_representation, edge_representation, diff_vec):
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = node_representation.unsqueeze(2)
        hj = node_representation.unsqueeze(1)
        diff_i = diff_vec.unsqueeze(2)
        diff_j = diff_vec.unsqueeze(1)

        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L, -1, -1),
            hj.expand(-1, L, -1, -1, -1),
            edge_representation,
            diff_i.expand(-1, -1, L, -1, -1),
            diff_j.expand(-1, L, -1, -1, -1)
        ], dim=-1)

        attention_scores = self.attn_linear(hi_hj_concat).squeeze(-1)

        mask = torch.eye(L, dtype=torch.bool, device=node_representation.device).unsqueeze(0).unsqueeze(-1)
        attention_scores.masked_fill_(mask, float('-inf'))

        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)

        node_representation_Wh = self.Wh(node_representation)
        node_representation_Wh = node_representation_Wh.permute(0, 2, 1, 3)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), node_representation_Wh)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)

        new_node_representation = node_representation + self.leaky_relu(aggregated_features)

        return new_node_representation

    def single_head_attention_edge(self, node_representation, edge_representation, diff_vec):
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = node_representation.unsqueeze(2)
        hj = node_representation.unsqueeze(1)
        diff_i = diff_vec.unsqueeze(2)
        diff_j = diff_vec.unsqueeze(1)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1),
                                  diff_i.expand(-1, -1, L, -1, -1),
                                  diff_j.expand(-1, L, -1, -1, -1)], dim=-1)

        new_edge_representation = self.edge_linear(hi_hj_concat)

        return new_edge_representation
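

# Illustrative shape check wiring DiffEmbeddingLayer into MIM (a sketch, not part of the
# original module). All dimensions (d_node=64, d_edge=32, d_diff=64, num_heads=4, B=2, L=5)
# are assumptions chosen so every feature size divides evenly across heads.
def _demo_mim():
    diff_layer = DiffEmbeddingLayer(d_node=64)
    mim = MIM(d_node=64, d_edge=32, d_diff=64, num_heads=4)
    wt_nodes = torch.randn(2, 5, 64)     # wild-type node features (B, L, d_node)
    mut_nodes = torch.randn(2, 5, 64)    # mutant node features (B, L, d_node)
    edges = torch.randn(2, 5, 5, 32)     # pairwise edge features (B, L, L, d_edge)
    diff_vec = diff_layer(wt_nodes, mut_nodes)
    new_nodes, new_edges = mim(wt_nodes, edges, diff_vec)
    assert new_nodes.shape == (2, 5, 64)
    assert new_edges.shape == (2, 5, 5, 32)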


class CrossGraphAttention(nn.Module):
    def __init__(self, d_node, d_cross_edge, d_diff, num_heads, negative_slope=0.2):
        super(CrossGraphAttention, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_cross_edge % num_heads == 0, "d_cross_edge must be divisible by num_heads"
        assert d_diff % num_heads == 0, "d_diff must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_cross_edge // num_heads
        self.d_diff_head = d_diff // num_heads

        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_cross_edge, d_cross_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_cross_edge, d_cross_edge)
        self.Wd = nn.Linear(d_diff, d_diff)
        self.Wd_2 = nn.Linear(d_diff, d_diff)
        # The target side attends with the mutation-difference features; the binder side does not.
        self.attn_linear_target = nn.Linear(self.d_k * 2 + self.d_edge_head + self.d_diff_head, 1, bias=False)
        self.attn_linear_binder = nn.Linear(self.d_k * 2 + self.d_edge_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + self.d_diff_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_cross_edge, d_cross_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, target_representation, binder_representation, edge_representation, diff_vec):
        """Cross-graph attention between a target and a binder graph.

        Shapes: target (B, L1, d_node), binder (B, L2, d_node),
        edge_representation (B, L1, L2, d_cross_edge), diff_vec (B, L1, d_diff).
        """
        B, L1, d_node = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge = edge_representation.size(-1)

        # First pass: update the cross edges from the current node features.
        target_proj = self.Wn(target_representation).view(B, L1, self.num_heads, self.d_k)
        binder_proj = self.Wn(binder_representation).view(B, L2, self.num_heads, self.d_k)
        edge_proj = self.We(edge_representation).view(B, L1, L2, self.num_heads, self.d_edge_head)
        diff_proj = self.Wd(diff_vec).view(B, L1, self.num_heads, self.d_diff_head)

        new_edge_representation = self.single_head_attention_edge(target_proj, binder_proj, edge_proj, diff_proj)

        concatenated_edge_rep = new_edge_representation.view(B, L1, L2, -1)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        # Second pass: update both node sets using the refreshed cross edges.
        target_proj_2 = self.Wn_2(target_representation).view(B, L1, self.num_heads, self.d_k)
        binder_proj_2 = self.Wn_2(binder_representation).view(B, L2, self.num_heads, self.d_k)
        edge_proj_2 = self.We_2(new_edge_representation).view(B, L1, L2, self.num_heads, self.d_edge_head)
        diff_proj_2 = self.Wd_2(diff_vec).view(B, L1, self.num_heads, self.d_diff_head)

        new_target_representation = self.single_head_attention_target(target_proj_2, binder_proj_2, edge_proj_2, diff_proj_2)
        new_binder_representation = self.single_head_attention_binder(binder_proj_2, target_proj_2, edge_proj_2)

        concatenated_target_rep = new_target_representation.view(B, L1, -1)
        new_target_representation = self.out_proj_node(concatenated_target_rep)

        concatenated_binder_rep = new_binder_representation.view(B, L2, -1)
        new_binder_representation = self.out_proj_node(concatenated_binder_rep)

        return new_target_representation, new_binder_representation, new_edge_representation

    def single_head_attention_target(self, target_representation, binder_representation, edge_representation, diff_vec):
        B, L1, num_heads, d_k = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = target_representation.unsqueeze(2)
        hj = binder_representation.unsqueeze(1)
        diff_i = diff_vec.unsqueeze(2)

        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L2, -1, -1),
            hj.expand(-1, L1, -1, -1, -1),
            edge_representation,
            diff_i.expand(-1, -1, L2, -1, -1)
        ], dim=-1)

        attention_scores = self.attn_linear_target(hi_hj_concat).squeeze(-1)
        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)

        binder_representation_Wh = self.Wh(binder_representation)
        binder_representation_Wh = binder_representation_Wh.permute(0, 2, 1, 3)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), binder_representation_Wh)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)

        new_target_representation = target_representation + self.leaky_relu(aggregated_features)

        return new_target_representation

    def single_head_attention_binder(self, binder_representation, target_representation, edge_representation):
        # Arguments follow the call order in forward(): binder first, then target.
        B, L2, num_heads, d_k = binder_representation.size()
        L1 = target_representation.size(1)
        d_edge_head = edge_representation.size(-1)

        hi = binder_representation.unsqueeze(2)
        hj = target_representation.unsqueeze(1)
        # Transpose the cross edges so they are indexed (binder, target).
        edge_representation = edge_representation.transpose(1, 2)

        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L1, -1, -1),
            hj.expand(-1, L2, -1, -1, -1),
            edge_representation,
        ], dim=-1)

        attention_scores = self.attn_linear_binder(hi_hj_concat).squeeze(-1)
        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)

        target_representation_Wh = self.Wh(target_representation)
        target_representation_Wh = target_representation_Wh.permute(0, 2, 1, 3)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), target_representation_Wh)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)

        new_binder_representation = binder_representation + self.leaky_relu(aggregated_features)

        return new_binder_representation

    def single_head_attention_edge(self, target_representation, binder_representation, edge_representation, diff_vec):
        B, L1, num_heads, d_k = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = target_representation.unsqueeze(2)
        hj = binder_representation.unsqueeze(1)
        diff_i = diff_vec.unsqueeze(2)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L2, -1, -1),
                                  hj.expand(-1, L1, -1, -1, -1),
                                  diff_i.expand(-1, -1, L2, -1, -1)], dim=-1)

        new_edge_representation = self.edge_linear(hi_hj_concat)

        return new_edge_representation
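

# Minimal end-to-end smoke test (an illustrative sketch, not part of the original module).
# The sizes below are assumptions: target length L1=6, binder length L2=4, d_node=64,
# d_cross_edge=32, d_diff=64, num_heads=4; any head-divisible sizes would work.
if __name__ == "__main__":
    diff_layer = DiffEmbeddingLayer(d_node=64)
    cross_attn = CrossGraphAttention(d_node=64, d_cross_edge=32, d_diff=64, num_heads=4)

    target = torch.randn(2, 6, 64)          # target nodes (B, L1, d_node)
    binder = torch.randn(2, 4, 64)          # binder nodes (B, L2, d_node)
    cross_edges = torch.randn(2, 6, 4, 32)  # cross edges (B, L1, L2, d_cross_edge)
    mutant_target = torch.randn(2, 6, 64)   # mutant copy of the target nodes

    diff_vec = diff_layer(target, mutant_target)
    new_target, new_binder, new_edges = cross_attn(target, binder, cross_edges, diff_vec)

    print(new_target.shape)   # torch.Size([2, 6, 64])
    print(new_binder.shape)   # torch.Size([2, 4, 64])
    print(new_edges.shape)    # torch.Size([2, 6, 4, 32])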