Update nets/multi_headed_attention.py
nets/multi_headed_attention.py (CHANGED: +18, -54)
@@ -1,19 +1,11 @@
 import torch
 import torch.nn.functional as F
-import numpy as np
 from torch import nn
 import math
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(
-            self,
-            n_heads,
-            input_dim,
-            embed_dim=None,
-            val_dim=None,
-            key_dim=None
-    ):
+    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
         super(MultiHeadAttention, self).__init__()
 
         if val_dim is None:
@@ -28,76 +20,48 @@ class MultiHeadAttention(nn.Module):
         self.val_dim = val_dim
         self.key_dim = key_dim
 
-        self.norm_factor = 1 / math.sqrt(key_dim)
+        self.norm_factor = 1 / math.sqrt(key_dim)
 
-        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)
-        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)
-        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)
-
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim), requires_grad=True)
+        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))
 
         self.init_parameters()
 
     def init_parameters(self):
-
         for param in self.parameters():
             stdv = 1. / math.sqrt(param.size(-1))
             param.data.uniform_(-stdv, stdv)
 
-
     def forward(self, q, h=None, mask=None):
-        """
-        :param q: queries (batch_size, n_query, input_dim)
-        :param h: data (batch_size, graph_size, input_dim)
-        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
-        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
-        :return:
-        """
         if h is None:
-            h = q  #
+            h = q  # self-attention
 
-        # h should be (batch_size, graph_size, input_dim)
         batch_size, graph_size, input_dim = h.size()
         n_query = q.size(1)
-        assert q.size(0) == batch_size
-        assert q.size(2) == input_dim
-        assert input_dim == self.input_dim, "Wrong embedding dimension of input"
 
         hflat = h.contiguous().view(-1, input_dim)
         qflat = q.contiguous().view(-1, input_dim)
 
-
-
-
+        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
+        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
+        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)
 
-        #
-
-        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size)
-        K = torch.matmul(hflat, self.W_key).view(shp)
-        V = torch.matmul(hflat, self.W_val).view(shp)
+        # Compute attention scores
+        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # (n_heads, batch, n_query, graph)
 
-        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
-        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
-
-        # Optionally apply mask to prevent attention
         if mask is not None:
             mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
-            compatibility
+            compatibility = compatibility.masked_fill(mask, -1e9)
 
         attn = F.softmax(compatibility, dim=-1)
 
-        #
-
-        attnc = attn.clone()
-        attnc[mask] = 0
-        attn = attnc
-
-        heads = torch.matmul(attn, V)
+        # Apply attention to values
+        heads = torch.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)
 
-
-
-
-        ).view(batch_size, n_query, self.embed_dim)
+        # Concatenate heads and project
+        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
+        out = torch.matmul(heads, self.W_out)  # (batch, n_query, embed_dim)
 
         return out
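
For reference, a minimal usage sketch of the updated module (an illustration, not part of the commit). It assumes the repository root is on the Python path so that nets.multi_headed_attention is importable, and it passes embed_dim, val_dim and key_dim explicitly instead of relying on the constructor defaults that the diff does not show. The mask convention follows the removed docstring: an entry of 1/True marks a query/node pair for which attention is not allowed.

import torch

from nets.multi_headed_attention import MultiHeadAttention

torch.manual_seed(0)

batch_size, graph_size, input_dim = 4, 20, 128

# Hypothetical dimensions chosen for illustration; any consistent choice works.
mha = MultiHeadAttention(n_heads=8, input_dim=input_dim, embed_dim=128, val_dim=16, key_dim=16)

h = torch.rand(batch_size, graph_size, input_dim)

# Self-attention over the node embeddings: h defaults to q inside forward().
out = mha(h)
print(out.shape)  # torch.Size([4, 20, 128])

# Boolean mask of shape (batch_size, n_query, graph_size); True blocks attention.
mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
mask[:, :, -1] = True  # no query may attend to the last node
out_masked = mha(h, mask=mask)
print(out_masked.shape)  # torch.Size([4, 20, 128])

Note that the rewritten __init__ creates W_out unconditionally from n_heads * val_dim and embed_dim, so an output embedding dimension is effectively required even though the embed_dim parameter still defaults to None.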