AlienChen committed on
Commit 8f2863d · verified · 1 Parent(s): 0cd961f

Create modules_vec.py

Files changed (1)
  1. models/modules_vec.py +387 -0
models/modules_vec.py ADDED
@@ -0,0 +1,387 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class IntraGraphAttention(nn.Module):
    def __init__(self, d_node, d_edge, num_heads, negative_slope=0.2):
        super(IntraGraphAttention, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_edge % num_heads == 0, "d_edge must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_edge // num_heads

        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_edge, d_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_edge, d_edge)
        self.attn_linear = nn.Linear(self.d_k * 2 + self.d_edge_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_edge, d_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, node_representation, edge_representation):
        # node_representation: (B, L, d_node)
        # edge_representation: (B, L, L, d_edge)
        B, L, d_node = node_representation.size()
        d_edge = edge_representation.size(-1)

        # Multi-head projection
        node_proj = self.Wn(node_representation).view(B, L, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
        edge_proj = self.We(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)  # (B, L, L, num_heads, d_edge_head)

        # Node representation update
        new_node_representation = self.single_head_attention_node(node_proj, edge_proj)

        concatenated_node_rep = new_node_representation.view(B, L, -1)  # (B, L, num_heads * d_k)
        new_node_representation = self.out_proj_node(concatenated_node_rep)

        # Edge representation update
        node_proj_2 = self.Wn_2(new_node_representation).view(B, L, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
        edge_proj_2 = self.We_2(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)  # (B, L, L, num_heads, d_edge_head)

        new_edge_representation = self.single_head_attention_edge(node_proj_2, edge_proj_2)

        concatenated_edge_rep = new_edge_representation.view(B, L, L, -1)  # (B, L, L, num_heads * d_edge_head)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        return new_node_representation, new_edge_representation

    def single_head_attention_node(self, node_representation, edge_representation):
        # Per-head attention over all node pairs, conditioned on the edge features
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)

        hi = node_representation.unsqueeze(2)  # (B, L, 1, num_heads, d_k)
        hj = node_representation.unsqueeze(1)  # (B, 1, L, num_heads, d_k)

        hi_hj_concat = torch.cat([hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1),
                                  edge_representation], dim=-1)  # (B, L, L, num_heads, 2*d_k + d_edge_head)

        attention_scores = self.attn_linear(hi_hj_concat).squeeze(-1)  # (B, L, L, num_heads)

        # Mask the diagonal (self-attention) by setting it to a large negative value
        mask = torch.eye(L).bool().unsqueeze(0).unsqueeze(-1).to(node_representation.device)  # (1, L, L, 1)
        attention_scores.masked_fill_(mask, float('-inf'))

        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)  # (B, L, L, num_heads)

        # Aggregate neighbour features along the L dimension
        node_representation_Wh = self.Wh(node_representation)  # (B, L, num_heads, d_k)
        node_representation_Wh = node_representation_Wh.permute(0, 2, 1, 3)  # (B, num_heads, L, d_k)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), node_representation_Wh)  # (B, num_heads, L, d_k)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)  # (B, L, num_heads, d_k)

        new_node_representation = node_representation + self.leaky_relu(aggregated_features)  # (B, L, num_heads, d_k)

        return new_node_representation

    def single_head_attention_edge(self, node_representation, edge_representation):
        # Update edge representation from the endpoint nodes and the current edge features
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)

        hi = node_representation.unsqueeze(2)  # (B, L, 1, num_heads, d_k)
        hj = node_representation.unsqueeze(1)  # (B, 1, L, num_heads, d_k)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1)], dim=-1)  # (B, L, L, num_heads, 2*d_k + d_edge_head)

        new_edge_representation = self.edge_linear(hi_hj_concat)  # (B, L, L, num_heads, d_edge_head)

        return new_edge_representation

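# Example usage of IntraGraphAttention (illustrative sketch; the batch size and
# feature dimensions below are assumed for the example, not taken from the code):
#     layer = IntraGraphAttention(d_node=64, d_edge=32, num_heads=4)
#     nodes = torch.randn(2, 8, 64)       # (B, L, d_node)
#     edges = torch.randn(2, 8, 8, 32)    # (B, L, L, d_edge)
#     nodes, edges = layer(nodes, edges)  # shapes preserved: (2, 8, 64), (2, 8, 8, 32)
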
class DiffEmbeddingLayer(nn.Module):
    def __init__(self, d_node):
        super(DiffEmbeddingLayer, self).__init__()
        self.W_delta = nn.Linear(d_node, d_node)

    def forward(self, wt_node, mut_node):
        # Embed the per-position difference between mutant and wild-type node features
        delta_h = mut_node - wt_node  # (B, L, d_node)
        diff_vec = torch.relu(self.W_delta(delta_h))  # (B, L, d_node)
        return diff_vec


class MIM(nn.Module):
    def __init__(self, d_node, d_edge, d_diff, num_heads, negative_slope=0.2):
        super(MIM, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_edge % num_heads == 0, "d_edge must be divisible by num_heads"
        assert d_diff % num_heads == 0, "d_diff must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_edge // num_heads
        self.d_diff_head = d_diff // num_heads

        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_edge, d_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_edge, d_edge)
        self.Wd = nn.Linear(d_diff, d_diff)
        self.Wd_2 = nn.Linear(d_diff, d_diff)
        self.attn_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + 2 * self.d_diff_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + 2 * self.d_diff_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_edge, d_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, node_representation, edge_representation, diff_vec):
        # node_representation: (B, L, d_node)
        # edge_representation: (B, L, L, d_edge)
        # diff_vec: (B, L, d_diff)
        B, L, d_node = node_representation.size()
        d_edge = edge_representation.size(-1)

        # Multi-head projection
        node_proj = self.Wn(node_representation).view(B, L, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
        edge_proj = self.We(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)  # (B, L, L, num_heads, d_edge_head)
        diff_proj = self.Wd(diff_vec).view(B, L, self.num_heads, self.d_diff_head)  # (B, L, num_heads, d_diff_head)

        # Node representation update
        new_node_representation = self.single_head_attention_node(node_proj, edge_proj, diff_proj)

        concatenated_node_rep = new_node_representation.view(B, L, -1)  # (B, L, num_heads * d_k)
        new_node_representation = self.out_proj_node(concatenated_node_rep)

        # Edge representation update
        node_proj_2 = self.Wn_2(new_node_representation).view(B, L, self.num_heads, self.d_k)  # (B, L, num_heads, d_k)
        edge_proj_2 = self.We_2(edge_representation).view(B, L, L, self.num_heads, self.d_edge_head)  # (B, L, L, num_heads, d_edge_head)
        diff_proj_2 = self.Wd_2(diff_vec).view(B, L, self.num_heads, self.d_diff_head)  # (B, L, num_heads, d_diff_head)

        new_edge_representation = self.single_head_attention_edge(node_proj_2, edge_proj_2, diff_proj_2)

        concatenated_edge_rep = new_edge_representation.view(B, L, L, -1)  # (B, L, L, num_heads * d_edge_head)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        return new_node_representation, new_edge_representation

    def single_head_attention_node(self, node_representation, edge_representation, diff_vec):
        # Update node representation, conditioning the attention on the difference vectors
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = node_representation.unsqueeze(2)  # (B, L, 1, num_heads, d_k)
        hj = node_representation.unsqueeze(1)  # (B, 1, L, num_heads, d_k)
        diff_i = diff_vec.unsqueeze(2)  # (B, L, 1, num_heads, d_diff_head)
        diff_j = diff_vec.unsqueeze(1)  # (B, 1, L, num_heads, d_diff_head)

        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L, -1, -1),
            hj.expand(-1, L, -1, -1, -1),
            edge_representation,
            diff_i.expand(-1, -1, L, -1, -1),
            diff_j.expand(-1, L, -1, -1, -1)
        ], dim=-1)  # (B, L, L, num_heads, 2*d_k + d_edge_head + 2*d_diff_head)

        attention_scores = self.attn_linear(hi_hj_concat).squeeze(-1)  # (B, L, L, num_heads)

        # Mask the diagonal (self-attention) by setting it to a large negative value
        mask = torch.eye(L).bool().unsqueeze(0).unsqueeze(-1).to(node_representation.device)  # (1, L, L, 1)
        attention_scores.masked_fill_(mask, float('-inf'))

        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)  # (B, L, L, num_heads)

        # Aggregate neighbour features along the L dimension
        node_representation_Wh = self.Wh(node_representation)  # (B, L, num_heads, d_k)
        node_representation_Wh = node_representation_Wh.permute(0, 2, 1, 3)  # (B, num_heads, L, d_k)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), node_representation_Wh)  # (B, num_heads, L, d_k)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)  # (B, L, num_heads, d_k)

        new_node_representation = node_representation + self.leaky_relu(aggregated_features)  # (B, L, num_heads, d_k)

        return new_node_representation

    def single_head_attention_edge(self, node_representation, edge_representation, diff_vec):
        # Update edge representation from the endpoint nodes, edge features, and difference vectors
        B, L, num_heads, d_k = node_representation.size()
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = node_representation.unsqueeze(2)  # (B, L, 1, num_heads, d_k)
        hj = node_representation.unsqueeze(1)  # (B, 1, L, num_heads, d_k)
        diff_i = diff_vec.unsqueeze(2)  # (B, L, 1, num_heads, d_diff_head)
        diff_j = diff_vec.unsqueeze(1)  # (B, 1, L, num_heads, d_diff_head)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L, -1, -1),
                                  hj.expand(-1, L, -1, -1, -1),
                                  diff_i.expand(-1, -1, L, -1, -1),
                                  diff_j.expand(-1, L, -1, -1, -1)], dim=-1)  # (B, L, L, num_heads, 2*d_k + d_edge_head + 2*d_diff_head)

        new_edge_representation = self.edge_linear(hi_hj_concat)  # (B, L, L, num_heads, d_edge_head)

        return new_edge_representation

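# Example usage of DiffEmbeddingLayer + MIM (illustrative sketch; dimensions are
# assumed for the example; DiffEmbeddingLayer outputs d_node-sized vectors, so
# d_diff must equal d_node when wiring them this way):
#     diff_layer = DiffEmbeddingLayer(d_node=64)
#     diff_vec = diff_layer(wt_node, mut_node)       # (B, L, 64)
#     mim = MIM(d_node=64, d_edge=32, d_diff=64, num_heads=4)
#     nodes, edges = mim(mut_node, edges, diff_vec)  # (B, L, 64), (B, L, L, 32)
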
class CrossGraphAttention(nn.Module):
    def __init__(self, d_node, d_cross_edge, d_diff, num_heads, negative_slope=0.2):
        super(CrossGraphAttention, self).__init__()
        assert d_node % num_heads == 0, "d_node must be divisible by num_heads"
        assert d_cross_edge % num_heads == 0, "d_cross_edge must be divisible by num_heads"
        assert d_diff % num_heads == 0, "d_diff must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_node // num_heads
        self.d_edge_head = d_cross_edge // num_heads
        self.d_diff_head = d_diff // num_heads

        self.Wn = nn.Linear(d_node, d_node)
        self.Wh = nn.Linear(self.d_k, self.d_k)
        self.We = nn.Linear(d_cross_edge, d_cross_edge)
        self.Wn_2 = nn.Linear(d_node, d_node)
        self.We_2 = nn.Linear(d_cross_edge, d_cross_edge)
        self.Wd = nn.Linear(d_diff, d_diff)
        self.Wd_2 = nn.Linear(d_diff, d_diff)
        self.attn_linear_target = nn.Linear(self.d_k * 2 + self.d_edge_head + self.d_diff_head, 1, bias=False)
        self.attn_linear_binder = nn.Linear(self.d_k * 2 + self.d_edge_head, 1, bias=False)
        self.edge_linear = nn.Linear(self.d_k * 2 + self.d_edge_head + self.d_diff_head, self.d_edge_head)

        self.out_proj_node = nn.Linear(d_node, d_node)
        self.out_proj_edge = nn.Linear(d_cross_edge, d_cross_edge)

        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, target_representation, binder_representation, edge_representation, diff_vec):
        # target_representation: (B, L1, d_node)
        # binder_representation: (B, L2, d_node)
        # edge_representation: (B, L1, L2, d_cross_edge)
        # diff_vec: (B, L1, d_diff)
        B, L1, d_node = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge = edge_representation.size(-1)

        # Multi-head projection
        target_proj = self.Wn(target_representation).view(B, L1, self.num_heads, self.d_k)
        binder_proj = self.Wn(binder_representation).view(B, L2, self.num_heads, self.d_k)
        edge_proj = self.We(edge_representation).view(B, L1, L2, self.num_heads, self.d_edge_head)
        diff_proj = self.Wd(diff_vec).view(B, L1, self.num_heads, self.d_diff_head)

        # Edge representation update
        new_edge_representation = self.single_head_attention_edge(target_proj, binder_proj, edge_proj, diff_proj)

        concatenated_edge_rep = new_edge_representation.view(B, L1, L2, -1)
        new_edge_representation = self.out_proj_edge(concatenated_edge_rep)

        # Node representation update
        target_proj_2 = self.Wn_2(target_representation).view(B, L1, self.num_heads, self.d_k)
        binder_proj_2 = self.Wn_2(binder_representation).view(B, L2, self.num_heads, self.d_k)
        edge_proj_2 = self.We_2(new_edge_representation).view(B, L1, L2, self.num_heads, self.d_edge_head)
        diff_proj_2 = self.Wd_2(diff_vec).view(B, L1, self.num_heads, self.d_diff_head)

        new_target_representation = self.single_head_attention_target(target_proj_2, binder_proj_2, edge_proj_2, diff_proj_2)
        # The binder update reuses the same attention routine with the roles swapped,
        # so the binder projection is passed as the first argument here.
        new_binder_representation = self.single_head_attention_binder(binder_proj_2, target_proj_2, edge_proj_2)

        concatenated_target_rep = new_target_representation.view(B, L1, -1)
        new_target_representation = self.out_proj_node(concatenated_target_rep)

        concatenated_binder_rep = new_binder_representation.view(B, L2, -1)
        new_binder_representation = self.out_proj_node(concatenated_binder_rep)

        return new_target_representation, new_binder_representation, new_edge_representation

    def single_head_attention_target(self, target_representation, binder_representation, edge_representation, diff_vec):
        # Update target node representation by attending over the binder nodes
        B, L1, num_heads, d_k = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = target_representation.unsqueeze(2)  # (B, L1, 1, num_heads, d_k)
        hj = binder_representation.unsqueeze(1)  # (B, 1, L2, num_heads, d_k)
        diff_i = diff_vec.unsqueeze(2)  # (B, L1, 1, num_heads, d_diff_head)

        # Concatenate hi, hj, edge_representation, and diff_i
        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L2, -1, -1),
            hj.expand(-1, L1, -1, -1, -1),
            edge_representation,
            diff_i.expand(-1, -1, L2, -1, -1)
        ], dim=-1)  # (B, L1, L2, num_heads, 2*d_k + d_edge_head + d_diff_head)

        # Calculate attention scores
        attention_scores = self.attn_linear_target(hi_hj_concat).squeeze(-1)  # (B, L1, L2, num_heads)
        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)  # (B, L1, L2, num_heads)

        # Aggregate binder features along the L2 dimension
        binder_representation_Wh = self.Wh(binder_representation)  # (B, L2, num_heads, d_k)
        binder_representation_Wh = binder_representation_Wh.permute(0, 2, 1, 3)  # (B, num_heads, L2, d_k)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), binder_representation_Wh)  # (B, num_heads, L1, d_k)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)  # (B, L1, num_heads, d_k)

        # Update target representation
        new_target_representation = target_representation + self.leaky_relu(aggregated_features)  # (B, L1, num_heads, d_k)

        return new_target_representation

    def single_head_attention_binder(self, target_representation, binder_representation, edge_representation):
        # Update binder node representation by attending over the target nodes.
        # The caller passes the binder as `target_representation` and the target as
        # `binder_representation`, so here L1 is the binder length and L2 the target length.
        B, L1, num_heads, d_k = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge_head = edge_representation.size(-1)

        hi = target_representation.unsqueeze(2)  # (B, L1, 1, num_heads, d_k)
        hj = binder_representation.unsqueeze(1)  # (B, 1, L2, num_heads, d_k)
        edge_representation = edge_representation.transpose(1, 2)  # (B, L1, L2, num_heads, d_edge_head)

        # Concatenate hi, hj, and edge_representation
        hi_hj_concat = torch.cat([
            hi.expand(-1, -1, L2, -1, -1),
            hj.expand(-1, L1, -1, -1, -1),
            edge_representation,
        ], dim=-1)  # (B, L1, L2, num_heads, 2*d_k + d_edge_head)

        # Calculate attention scores
        attention_scores = self.attn_linear_binder(hi_hj_concat).squeeze(-1)  # (B, L1, L2, num_heads)
        attention_probs = F.softmax(self.leaky_relu(attention_scores), dim=2)  # (B, L1, L2, num_heads)

        # Aggregate target features along the L2 dimension
        binder_representation_Wh = self.Wh(binder_representation)  # (B, L2, num_heads, d_k)
        binder_representation_Wh = binder_representation_Wh.permute(0, 2, 1, 3)  # (B, num_heads, L2, d_k)

        aggregated_features = torch.matmul(attention_probs.permute(0, 3, 1, 2), binder_representation_Wh)  # (B, num_heads, L1, d_k)
        aggregated_features = aggregated_features.permute(0, 2, 1, 3)  # (B, L1, num_heads, d_k)

        # Update binder representation
        new_binder_representation = target_representation + self.leaky_relu(aggregated_features)  # (B, L1, num_heads, d_k)

        return new_binder_representation

    def single_head_attention_edge(self, target_representation, binder_representation, edge_representation, diff_vec):
        # Update cross-edge representation
        B, L1, num_heads, d_k = target_representation.size()
        L2 = binder_representation.size(1)
        d_edge_head = edge_representation.size(-1)
        d_diff_head = diff_vec.size(-1)

        hi = target_representation.unsqueeze(2)  # (B, L1, 1, num_heads, d_k)
        hj = binder_representation.unsqueeze(1)  # (B, 1, L2, num_heads, d_k)
        diff_i = diff_vec.unsqueeze(2)  # (B, L1, 1, num_heads, d_diff_head)

        hi_hj_concat = torch.cat([edge_representation,
                                  hi.expand(-1, -1, L2, -1, -1),
                                  hj.expand(-1, L1, -1, -1, -1),
                                  diff_i.expand(-1, -1, L2, -1, -1)], dim=-1)  # (B, L1, L2, num_heads, 2*d_k + d_edge_head + d_diff_head)

        new_edge_representation = self.edge_linear(hi_hj_concat)  # (B, L1, L2, num_heads, d_edge_head)

        return new_edge_representation
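
# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original commit;
# all dimensions and the random inputs below are assumed). It wires the four
# modules together the way their signatures suggest: wild-type/mutant nodes ->
# DiffEmbeddingLayer -> MIM on the mutant graph, then CrossGraphAttention
# between the target graph and a binder graph.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, L_target, L_binder = 2, 8, 6
    d_node, d_edge, d_cross_edge, num_heads = 64, 32, 32, 4
    d_diff = d_node  # DiffEmbeddingLayer outputs d_node-sized vectors

    wt_node = torch.randn(B, L_target, d_node)
    mut_node = torch.randn(B, L_target, d_node)
    intra_edge = torch.randn(B, L_target, L_target, d_edge)
    binder_node = torch.randn(B, L_binder, d_node)
    cross_edge = torch.randn(B, L_target, L_binder, d_cross_edge)

    intra = IntraGraphAttention(d_node, d_edge, num_heads)
    node_rep, edge_rep = intra(mut_node, intra_edge)

    diff_vec = DiffEmbeddingLayer(d_node)(wt_node, mut_node)

    mim = MIM(d_node, d_edge, d_diff, num_heads)
    node_rep, edge_rep = mim(node_rep, edge_rep, diff_vec)

    cross = CrossGraphAttention(d_node, d_cross_edge, d_diff, num_heads)
    target_rep, binder_rep, cross_rep = cross(node_rep, binder_node, cross_edge, diff_vec)

    # Expected shapes: (2, 8, 64), (2, 6, 64), (2, 8, 6, 32)
    print(target_rep.shape, binder_rep.shape, cross_rep.shape)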