Update nets/projections.py
Browse files
- nets/projections.py +12 -13
nets/projections.py
CHANGED
@@ -13,7 +13,7 @@ class Projections(nn.Module):
|
|
13 |
|
14 |
self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
15 |
self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
16 |
-
self.W_output = nn.Parameter(torch.Tensor(
|
17 |
|
18 |
self.init_parameters()
|
19 |
|
@@ -24,23 +24,22 @@ class Projections(nn.Module):
|
|
24 |
|
25 |
def forward(self, h):
|
26 |
"""
|
27 |
-
:param h: (batch_size, graph_size, embed_dim)
|
28 |
-
:return: dict with keys K, V, V_output
|
29 |
"""
|
30 |
batch_size, graph_size, input_dim = h.size()
|
31 |
-
hflat = h.view(-1, input_dim) # (batch_size * graph_size, embed_dim)
|
32 |
|
|
|
33 |
shp = (self.n_heads, batch_size, graph_size, self.val_dim)
|
|
|
|
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
V = torch.matmul(hflat, self.W_val).view(shp) # (n_heads, batch_size, graph_size, val_dim)
|
38 |
-
|
39 |
-
# Output projection
|
40 |
-
V_output = torch.bmm(h, self.W_output.repeat(batch_size, 1, 1)) # (batch_size, graph_size, embed_dim)
|
41 |
|
42 |
return {
|
43 |
-
'K': K,
|
44 |
-
'V': V,
|
45 |
-
'V_output': V_output
|
46 |
}
|
|
|
13 |
|
14 |
self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
15 |
self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
16 |
+
self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
|
17 |
|
18 |
self.init_parameters()
|
19 |
|
|
|
24 |
|
25 |
def forward(self, h):
|
26 |
"""
|
27 |
+
:param h: Tensor of shape (batch_size, graph_size, embed_dim)
|
28 |
+
:return: dict with keys: K, V, V_output
|
29 |
"""
|
30 |
batch_size, graph_size, input_dim = h.size()
|
31 |
+
hflat = h.contiguous().view(-1, input_dim) # (batch_size * graph_size, embed_dim)
|
32 |
|
33 |
+
# Compute Keys and Values per head
|
34 |
shp = (self.n_heads, batch_size, graph_size, self.val_dim)
|
35 |
+
K = torch.matmul(hflat, self.W_key).view(shp)
|
36 |
+
V = torch.matmul(hflat, self.W_val).view(shp)
|
37 |
|
38 |
+
# Compute output projection: (batch_size, graph_size, embed_dim)
|
39 |
+
V_output = torch.matmul(h, self.W_output.expand_as(self.W_output))
|
|
|
|
|
|
|
|
|
40 |
|
41 |
return {
|
42 |
+
'K': K, # (n_heads, batch_size, graph_size, val_dim)
|
43 |
+
'V': V, # (n_heads, batch_size, graph_size, val_dim)
|
44 |
+
'V_output': V_output # (batch_size, graph_size, embed_dim)
|
45 |
}
|