Update nets/multi_headed_attention.py

nets/multi_headed_attention.py  +18 -54
@@ -1,19 +1,11 @@
 import torch
 import torch.nn.functional as F
-import numpy as np
 from torch import nn
 import math


 class MultiHeadAttention(nn.Module):
-    def __init__(
-            self,
-            n_heads,
-            input_dim,
-            embed_dim=None,
-            val_dim=None,
-            key_dim=None
-    ):
+    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
         super(MultiHeadAttention, self).__init__()

         if val_dim is None:
@@ -28,76 +20,48 @@ class MultiHeadAttention(nn.Module):
         self.val_dim = val_dim
         self.key_dim = key_dim

-        self.norm_factor = 1 / math.sqrt(key_dim)
+        self.norm_factor = 1 / math.sqrt(key_dim)

-        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
-        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
-        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
-
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim), requires_grad=True)
+        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))

         self.init_parameters()

     def init_parameters(self):
-
         for param in self.parameters():
             stdv = 1. / math.sqrt(param.size(-1))
             param.data.uniform_(-stdv, stdv)

-
     def forward(self, q, h=None, mask=None):
-        """
-        :param q: queries (batch_size, n_query, input_dim)
-        :param h: data (batch_size, graph_size, input_dim)
-        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
-        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
-        :return:
-        """
         if h is None:
-            h = q  # compute self-attention
+            h = q  # self-attention

-        # h should be (batch_size, graph_size, input_dim)
         batch_size, graph_size, input_dim = h.size()
         n_query = q.size(1)
-        assert q.size(0) == batch_size
-        assert q.size(2) == input_dim
-        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

         hflat = h.contiguous().view(-1, input_dim)
         qflat = q.contiguous().view(-1, input_dim)

-        # last dimension can be different for keys and values
-        shp = (self.n_heads, batch_size, graph_size, -1)
-        shp_q = (self.n_heads, batch_size, n_query, -1)
+        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
+        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
+        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)

-        # Calculate queries (n_heads, batch_size, n_query, key/val_size)
-        Q = torch.matmul(qflat, self.W_query).view(shp_q)
-        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size)
-        K = torch.matmul(hflat, self.W_key).view(shp)
-        V = torch.matmul(hflat, self.W_val).view(shp)
+        # Compute attention scores
+        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # (n_heads, batch, n_query, graph)

-        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
-        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
-
-        # Optionally apply mask to prevent attention
         if mask is not None:
             mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
-            compatibility[mask] = -np.inf
+            compatibility = compatibility.masked_fill(mask, -1e9)

         attn = F.softmax(compatibility, dim=-1)

-        # Fix rows where every entry is masked (softmax returns nan there)
-        if mask is not None:
-            attnc = attn.clone()
-            attnc[mask] = 0
-            attn = attnc
-
-        heads = torch.matmul(attn, V)
+        # Apply attention to values
+        heads = torch.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)

-        out = torch.mm(
-            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
-            self.W_out.view(-1, self.embed_dim)
-        ).view(batch_size, n_query, self.embed_dim)
+        # Concatenate heads and project
+        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
+        out = torch.matmul(heads, self.W_out)  # (batch, n_query, embed_dim)

         return out
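
Note on the masking change: the removed code set masked compatibilities to -np.inf (hence the dropped numpy import) and then had to clone attn and zero out rows where every entry was masked, because softmax over a row of all -inf yields nan. The new masked_fill with a large finite value keeps softmax finite, which is why the nan fix-up block could be deleted. A standalone sketch (not part of the PR) to verify the difference:

import torch
import torch.nn.functional as F

scores = torch.zeros(1, 3)
fully_masked = torch.ones(1, 3, dtype=torch.bool)

# -inf on a fully masked row makes softmax return nan
print(F.softmax(scores.masked_fill(fully_masked, float('-inf')), dim=-1))  # tensor([[nan, nan, nan]])

# a large finite penalty keeps the row finite (uniform weights over the masked entries)
print(F.softmax(scores.masked_fill(fully_masked, -1e9), dim=-1))  # tensor([[0.3333, 0.3333, 0.3333]])

The trade-off is that a fully masked query now silently attends uniformly over masked nodes instead of surfacing nan, so callers no longer get an obvious signal that the mask excluded everything.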
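
For a quick sanity check of the rewritten module, something like the following should work. The sizes are illustrative, the import path is taken from the file name above, and it assumes the unchanged `if val_dim is None:` branch (not shown in this diff) derives the per-head dimensions from embed_dim; per the old docstring, mask entries are truthy where attention is not allowed.

import torch
from nets.multi_headed_attention import MultiHeadAttention

batch_size, graph_size, embed_dim, n_heads = 2, 20, 128, 8  # illustrative sizes

mha = MultiHeadAttention(n_heads=n_heads, input_dim=embed_dim, embed_dim=embed_dim)

h = torch.rand(batch_size, graph_size, embed_dim)                          # node embeddings
mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)   # True = attention not allowed

out = mha(h, mask=mask)  # h is reused as q, i.e. self-attention
assert out.size() == (batch_size, graph_size, embed_dim)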