a-ragab-h-m committed (verified)
Commit 9d169e7 · 1 Parent(s): 5186a44

Upload multi_headed_attention.py

Files changed (1)
  1. nets/multi_headed_attention.py +103 -0
nets/multi_headed_attention.py ADDED
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(
            self,
            n_heads,
            input_dim,
            embed_dim=None,
            val_dim=None,
            key_dim=None
    ):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)  # See Attention Is All You Need

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim), requires_grad=True)

        if embed_dim is not None:
            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim), requires_grad=True)

        self.init_parameters()

    def init_parameters(self):

        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)


    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim)
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
                     Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
        :return:
        """
        if h is None:
            h = q  # compute self-attention

        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # last dimension can be different for keys and values
        shp = (self.n_heads, batch_size, graph_size, -1)
        shp_q = (self.n_heads, batch_size, n_query, -1)

        # Calculate queries, (n_heads, batch_size, n_query, key_dim)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        # Optionally apply mask to prevent attention
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility[mask] = -np.inf

        attn = F.softmax(compatibility, dim=-1)

        # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask] = 0
            attn = attnc

        heads = torch.matmul(attn, V)

        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        return out
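
For reference, a minimal usage sketch of the module above (not part of the committed file). It assumes the repository root is on the Python path so that nets/multi_headed_attention.py is importable as nets.multi_headed_attention; the batch, graph, and embedding sizes are illustrative only. Per the forward docstring, a 1/True mask entry marks a query/node pair that must not attend.

# Usage sketch (illustrative only): self-attention and masked attention with
# the MultiHeadAttention module defined in this commit. The import path below
# assumes the repository root is on sys.path.
import torch

from nets.multi_headed_attention import MultiHeadAttention

batch_size, graph_size, embed_dim, n_heads = 2, 5, 128, 8

mha = MultiHeadAttention(n_heads=n_heads, input_dim=embed_dim, embed_dim=embed_dim)

h = torch.randn(batch_size, graph_size, embed_dim)

# Self-attention: when h is omitted, the module uses q as both queries and data.
out = mha(h)  # -> (batch_size, graph_size, embed_dim)

# Masked attention: a True entry at (b, i, j) forbids query i from attending
# to node j in batch element b (the docstring's "negative adjacency").
mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
mask[:, :, -1] = True  # e.g. no query may attend to the last node
out_masked = mha(h, h, mask=mask)  # -> (batch_size, graph_size, embed_dim)

print(out.shape, out_masked.shape)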