import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ClassifierOutput(nn.Module):
    """Scores each node against the decoder context, producing tanh-clipped
    logits (optionally softmax-normalized) over the graph nodes."""

    def __init__(self, embedding_size, C=10, softmax_output=False):
        super().__init__()
        self.C = C  # logit clipping range: logits lie in [-C, C] before softmax
        self.embedding_size = embedding_size
        self.softmax_output = softmax_output
        # Query projection, broadcast over the batch in forward().
        self.W_q = nn.Parameter(torch.Tensor(1, embedding_size, embedding_size))
        self.init_parameters()

    def init_parameters(self):
        # Uniform init scaled by each parameter's last dimension.
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, context, V_output):
        # context: (batch_size, n_query, embedding_size)
        # V_output: (batch_size, graph_size, embedding_size)
        # Project the context into query space; matmul broadcasts W_q over the batch.
        Q = torch.matmul(context, self.W_q)
        # Compatibility of each query with each node: (batch_size, n_query, graph_size)
        z = torch.bmm(Q, V_output.permute(0, 2, 1))
        z = z / math.sqrt(self.embedding_size)  # scaled dot product
        z = self.C * torch.tanh(z)  # clip logits to [-C, C]
        # Normalize over the graph nodes (last dim), not over the queries.
        return F.softmax(z, dim=-1) if self.softmax_output else z


class Attention(nn.Module):
    """Multi-head attention that takes precomputed per-head key/value
    projections (K, V) and projects only the query."""

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super().__init__()
        if val_dim is None:
            assert embed_dim is not None, "need embed_dim to derive val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim
        self.norm_factor = 1 / math.sqrt(key_dim)  # scaled dot-product factor

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        # The output projection operates on per-head values, so its middle
        # dimension is val_dim (key_dim and val_dim only coincide by default).
        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, K, V, mask=None):
        # q: (batch_size, n_query, input_dim)
        # K: (n_heads, batch_size, graph_size, key_dim), precomputed
        # V: (n_heads, batch_size, graph_size, val_dim), precomputed
        # mask: (batch_size, n_query, graph_size), True where attention is blocked
        batch_size = K.size(1)
        graph_size = K.size(2)
        n_query = q.size(1)

        # Project queries for all heads at once: (n_heads, batch_size, n_query, key_dim)
        qflat = q.contiguous().view(-1, self.input_dim)
        shp_q = (self.n_heads, batch_size, n_query, -1)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)

        # Scaled dot-product scores: (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility = compatibility.masked_fill(mask, float('-inf'))

        attn = F.softmax(compatibility, dim=-1)
        # Fully masked rows softmax to NaN; treat them as zero attention.
        attn = attn.masked_fill(torch.isnan(attn), 0)

        # Weighted sum of values: (n_heads, batch_size, n_query, val_dim)
        heads = torch.matmul(attn, V)
        # Concatenate heads head-major: (batch_size, n_query, n_heads * val_dim)
        heads_combined = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
        # Flatten W_out in the same head-major order before the final projection.
        W_out_combined = self.W_out.reshape(self.n_heads * self.val_dim, self.embed_dim)
        out = torch.matmul(heads_combined, W_out_combined)
        return out
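

# Hypothetical helper (an assumption, not part of the original module): one way
# the per-head K/V projections consumed by Attention.forward could be built
# from encoder node embeddings h of shape (batch_size, graph_size, input_dim),
# mirroring how W_query is applied above.
def make_projections(h, W_key, W_val):
    # W_key: (n_heads, input_dim, key_dim); W_val: (n_heads, input_dim, val_dim)
    n_heads = W_key.size(0)
    batch_size, graph_size, input_dim = h.size()
    hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, input_dim)
    K = torch.matmul(hflat, W_key).view(n_heads, batch_size, graph_size, -1)
    V = torch.matmul(hflat, W_val).view(n_heads, batch_size, graph_size, -1)
    return K, V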


class Decoder(nn.Module):
    def __init__(self, num_heads, embedding_size, decoder_input_size, softmax_output=False, C=10):
        super().__init__()
        self.embedding_size = embedding_size
        # decoder_input is expected to carry decoder_input_size - 1 features.
        self.initial_embedding = nn.Linear(decoder_input_size - 1, embedding_size)
        self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
        self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)

    def forward(self, decoder_input, projections, mask, *args, **kwargs):
        # Invert the mask: the incoming mask is nonzero where attention is
        # allowed, while Attention expects True where it is blocked.
        mask = (mask == 0)
        K = projections['K']
        V = projections['V']
        V_output = projections['V_output']

        # Embed the raw input, attend over the precomputed encoder projections,
        # then score every node against the resulting context.
        embedded_input = self.initial_embedding(decoder_input)
        context = self.attention(embedded_input, K, V, mask)
        output = self.classifier_output(context, V_output)
        return output
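

# Minimal usage sketch (an assumption, not part of the original module): shapes
# are inferred from the code above, and the random tensors stand in for whatever
# encoder normally produces the K / V / V_output projections.
if __name__ == "__main__":
    n_heads, embed, graph, batch, n_query, input_size = 8, 128, 20, 4, 1, 3
    decoder = Decoder(num_heads=n_heads, embedding_size=embed,
                      decoder_input_size=input_size)

    head_dim = embed // n_heads
    projections = {
        # Per-head keys/values: (n_heads, batch_size, graph_size, head_dim)
        'K': torch.randn(n_heads, batch, graph, head_dim),
        'V': torch.randn(n_heads, batch, graph, head_dim),
        # Node embeddings for the classifier head: (batch_size, graph_size, embed)
        'V_output': torch.randn(batch, graph, embed),
    }
    decoder_input = torch.randn(batch, n_query, input_size - 1)
    mask = torch.ones(batch, n_query, graph)  # nonzero = attention allowed

    logits = decoder(decoder_input, projections, mask)
    print(logits.shape)  # torch.Size([4, 1, 20]) == (batch, n_query, graph_size)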