import torch
import torch.nn.functional as F
from torch import nn
import math
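
# The class below implements scaled dot-product attention with multiple heads:
# per head, Attention(Q, K, V) = softmax(Q K^T / sqrt(key_dim)) V, followed by
# a learned projection W_out mapping the concatenated heads to embed_dim.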


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super(MultiHeadAttention, self).__init__()

        # embed_dim is always required, since the output projection W_out maps to it
        assert embed_dim is not None, "Provide embed_dim"
        if val_dim is None:
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization with scale 1/sqrt(fan_in), where fan_in is the
        # size of each parameter's last dimension
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim); defaults to q (self-attention)
        :param mask: boolean mask (batch_size, n_query, graph_size); True entries are not attended to
        :return: output (batch_size, n_query, embed_dim)
        """
        if h is None:
            h = q  # self-attention

        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)

        # Flatten batch and sequence dimensions for the per-head projections
        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, input_dim)
        qflat = q.contiguous().view(-1, input_dim)  # (batch_size * n_query, input_dim)

        # Project inputs into per-head query, key and value spaces
        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)

        # Compute attention scores
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # (n_heads, batch, n_query, graph)

        if mask is not None:
            # Masked (True) entries get a large negative score so softmax assigns them ~0 weight
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility = compatibility.masked_fill(mask, -1e9)

        attn = F.softmax(compatibility, dim=-1)

        # Apply attention to values
        heads = torch.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)

        # Concatenate heads and project
        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
        out = torch.matmul(heads, self.W_out)  # (batch, n_query, embed_dim)

        return out
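

# --- Minimal usage sketch ---
# A smoke test showing the expected tensor shapes. The concrete sizes below
# (batch of 2, graph of 5 nodes, 8 heads, 128-dim embeddings) are illustrative
# assumptions, not values taken from the surrounding code.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, graph_size, embed_dim, n_heads = 2, 5, 128, 8
    mha = MultiHeadAttention(n_heads=n_heads, input_dim=embed_dim, embed_dim=embed_dim)

    h = torch.randn(batch_size, graph_size, embed_dim)

    # Self-attention: queries default to the input h
    out = mha(h)
    assert out.shape == (batch_size, graph_size, embed_dim)

    # Masked self-attention: every query ignores the last node
    mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
    mask[:, :, -1] = True
    out_masked = mha(h, mask=mask)
    assert out_masked.shape == (batch_size, graph_size, embed_dim)

    print("output shape:", tuple(out.shape))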