Upload multi_headed_attention.py
nets/multi_headed_attention.py +103 -0
nets/multi_headed_attention.py
ADDED
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(
            self,
            n_heads,
            input_dim,
            embed_dim=None,
            val_dim=None,
            key_dim=None
    ):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)  # Scaled dot-product attention, see "Attention Is All You Need"

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim), requires_grad=True)

        if embed_dim is not None:
            # Output projection maps the concatenated heads (n_heads * val_dim) back to embed_dim
            self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim), requires_grad=True)

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization scaled by the fan-in (last dimension) of each parameter
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim)
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
                     Mask should contain 1 (True) where attention is not allowed (i.e. it is a negative adjacency)
        :return: output (batch_size, n_query, embed_dim)
        """
        if h is None:
            h = q  # compute self-attention

        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # last dimension can be different for keys and values
        shp = (self.n_heads, batch_size, graph_size, -1)
        shp_q = (self.n_heads, batch_size, n_query, -1)

        # Calculate queries, (n_heads, batch_size, n_query, key_dim)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # Calculate keys and values, (n_heads, batch_size, graph_size, key_dim/val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        # Optionally apply mask to prevent attention
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility[mask] = -np.inf

        attn = F.softmax(compatibility, dim=-1)

        # If there are nodes with no neighbours then softmax returns nan, so we fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask] = 0
            attn = attnc

        # Weighted sum of values per head, (n_heads, batch_size, n_query, val_dim)
        heads = torch.matmul(attn, V)

        # Concatenate heads and project back to embed_dim
        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        return out
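For reference, a minimal usage sketch (not part of the commit). The import path follows the file location above; the tensor sizes and the boolean mask are illustrative assumptions only:

    import torch
    from nets.multi_headed_attention import MultiHeadAttention

    # Illustrative sizes (assumptions, not taken from the commit)
    batch_size, graph_size, input_dim, embed_dim, n_heads = 2, 5, 16, 128, 8

    mha = MultiHeadAttention(n_heads=n_heads, input_dim=input_dim, embed_dim=embed_dim)

    h = torch.rand(batch_size, graph_size, input_dim)   # node embeddings
    out = mha(h)                                        # h defaults to q, i.e. self-attention
    assert out.shape == (batch_size, graph_size, embed_dim)

    # Masked self-attention: True marks pairs that must NOT attend to each other
    mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
    mask[:, :, -1] = True                               # e.g. hide the last node from all queries
    out_masked = mha(h, mask=mask)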