a-ragab-h-m committed (verified)
Commit fd3eb7b · 1 Parent(s): 722e008

Update nets/multi_headed_attention.py

Files changed (1)
  1. nets/multi_headed_attention.py +18 -54
nets/multi_headed_attention.py CHANGED
@@ -1,19 +1,11 @@
 import torch
 import torch.nn.functional as F
-import numpy as np
 from torch import nn
 import math
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(
-            self,
-            n_heads,
-            input_dim,
-            embed_dim=None,
-            val_dim=None,
-            key_dim=None
-    ):
+    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
         super(MultiHeadAttention, self).__init__()
 
         if val_dim is None:
@@ -28,76 +20,48 @@ class MultiHeadAttention(nn.Module):
         self.val_dim = val_dim
         self.key_dim = key_dim
 
-        self.norm_factor = 1 / math.sqrt(key_dim)  # See Attention is all you need
+        self.norm_factor = 1 / math.sqrt(key_dim)
 
-        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
-        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
-        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim), requires_grad=True)
-
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim), requires_grad=True)
+        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))
 
         self.init_parameters()
 
     def init_parameters(self):
-
         for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
 
-
    def forward(self, q, h=None, mask=None):
-        """
-        :param q: queries (batch_size, n_query, input_dim)
-        :param h: data (batch_size, graph_size, input_dim)
-        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
-        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
-        :return:
-        """
        if h is None:
-            h = q  # compute self-attention
+            h = q  # self-attention
 
-        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
-        assert q.size(0) == batch_size
-        assert q.size(2) == input_dim
-        assert input_dim == self.input_dim, "Wrong embedding dimension of input"
 
        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)
 
-        # last dimension can be different for keys and values
-        shp = (self.n_heads, batch_size, graph_size, -1)
-        shp_q = (self.n_heads, batch_size, n_query, -1)
+        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
+        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
+        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)
 
-        # Calculate queries, (n_heads, n_query, graph_size, key/val_size)
-        Q = torch.matmul(qflat, self.W_query).view(shp_q)
-        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size)
-        K = torch.matmul(hflat, self.W_key).view(shp)
-        V = torch.matmul(hflat, self.W_val).view(shp)
+        # Compute attention scores
+        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # (n_heads, batch, n_query, graph)
 
-        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
-        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
-
-        # Optionally apply mask to prevent attention
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
-            compatibility[mask] = -np.inf
+            compatibility = compatibility.masked_fill(mask, -1e9)
 
        attn = F.softmax(compatibility, dim=-1)
 
-        # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
-        if mask is not None:
-            attnc = attn.clone()
-            attnc[mask] = 0
-            attn = attnc
-
-        heads = torch.matmul(attn, V)
+        # Apply attention to values
+        heads = torch.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)
 
-        out = torch.mm(
-            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
-            self.W_out.view(-1, self.embed_dim)
-        ).view(batch_size, n_query, self.embed_dim)
+        # Concatenate heads and project
+        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
+        out = torch.matmul(heads, self.W_out)  # (batch, n_query, embed_dim)
 
        return out
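
For reference, here is a minimal usage sketch of the module as it stands after this commit. The import path and tensor sizes are assumptions for illustration (the repo root is taken to be on PYTHONPATH), and val_dim/key_dim are passed explicitly because the defaulting logic for those arguments lies outside the hunks shown above. With the new masked_fill-based masking, mask should be a boolean tensor in which True marks query/node pairs that are not allowed to attend.

import torch
from nets.multi_headed_attention import MultiHeadAttention

# Illustrative sizes only (not taken from the repo).
batch_size, graph_size, n_heads = 2, 10, 8
input_dim, embed_dim, val_dim, key_dim = 128, 128, 16, 16

mha = MultiHeadAttention(n_heads=n_heads, input_dim=input_dim,
                         embed_dim=embed_dim, val_dim=val_dim, key_dim=key_dim)

h = torch.randn(batch_size, graph_size, input_dim)  # node embeddings
# Boolean mask: True = attention not allowed (here: nobody may attend to the last node).
mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
mask[:, :, -1] = True

out = mha(h, mask=mask)  # second argument h is omitted, so h = q (self-attention)
print(out.shape)         # torch.Size([2, 10, 128])

One side effect of replacing compatibility[mask] = -np.inf with masked_fill(mask, -1e9): a fully masked row no longer makes the softmax return NaN, it simply yields uniform attention weights, which is presumably why the old NaN clean-up branch could be dropped.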