Create nets/decoder.py
nets/decoder.py (ADDED, +100 -0)
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassifierOutput(nn.Module):
    """Scores each candidate node against the decoder context.

    Logits are clipped to [-C, C] with tanh before the optional softmax.
    """

    def __init__(self, embedding_size, C=10, softmax_output=False):
        super().__init__()
        self.C = C
        self.embedding_size = embedding_size
        self.softmax_output = softmax_output
        self.W_q = nn.Parameter(torch.Tensor(1, embedding_size, embedding_size))
        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, context, V_output):
        # context: (batch, n_query, embedding_size)
        # V_output: (batch, graph_size, embedding_size)
        batch_size = context.shape[0]
        Q = torch.bmm(context, self.W_q.repeat(batch_size, 1, 1))
        # Scaled compatibility of each query with every node:
        # (batch, n_query, graph_size)
        z = torch.bmm(Q, V_output.permute(0, 2, 1))
        z = z / (self.embedding_size ** 0.5)
        z = self.C * torch.tanh(z)  # clip logits to [-C, C]
        # Normalize over the candidate nodes (last dim); softmax over dim=1
        # would normalize across queries and return all ones when n_query == 1.
        return F.softmax(z, dim=-1) if self.softmax_output else z

class Attention(nn.Module):
    """Multi-head attention over precomputed key/value projections.

    K and V are passed in per head, shaped (n_heads, batch_size, graph_size, dim);
    only the query and output projections are learned here.
    """

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super().__init__()
        if val_dim is None:
            assert embed_dim is not None
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim
        self.norm_factor = 1 / math.sqrt(key_dim)  # scaled dot-product attention

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        # The concatenated heads have val_dim features each, so the output
        # projection must be (n_heads, val_dim, embed_dim); sizing it with
        # key_dim only works while key_dim == val_dim.
        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, K, V, mask=None):
        # q: (batch_size, n_query, input_dim)
        # K: (n_heads, batch_size, graph_size, key_dim)
        # V: (n_heads, batch_size, graph_size, val_dim)
        # mask: (batch_size, n_query, graph_size), True where attention is forbidden
        batch_size = K.size(1)
        graph_size = K.size(2)
        n_query = q.size(1)
        qflat = q.contiguous().view(-1, self.input_dim)

        shp_q = (self.n_heads, batch_size, n_query, -1)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility[mask] = -np.inf

        attn = F.softmax(compatibility, dim=-1)
        if mask is not None:
            # A fully masked row softmaxes to NaN; zero those entries out.
            attn = attn.masked_fill(mask, 0)

        heads = torch.matmul(attn, V)  # (n_heads, batch_size, n_query, val_dim)
        # Concatenate the heads and project back to embed_dim.
        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        return out

class Decoder(nn.Module):
    """Embeds the decoder input, attends over the precomputed encoder
    projections, and scores every candidate node."""

    def __init__(self, num_heads, input_size, embedding_size, softmax_output=False, C=10):
        super().__init__()
        self.embedding_size = embedding_size
        self.initial_embedding = nn.Linear(input_size, embedding_size)
        self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
        self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)

    def forward(self, decoder_input, projections, mask, *args, **kwargs):
        # The incoming mask marks feasible nodes with 1; Attention expects
        # True where attention is forbidden, so invert it.
        mask = (mask == 0)
        K = projections['K']
        V = projections['V']
        V_output = projections['V_output']

        embedded_input = self.initial_embedding(decoder_input)
        context = self.attention(embedded_input, K, V, mask)
        output = self.classifier_output(context, V_output)
        return output
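
The commit never shows how the projections dict is built. Below is a minimal smoke test of a single decoding step; the sizes, the input_size choice, and the random encoder outputs are illustrative assumptions inferred from the tensor algebra above, not part of this commit:

import torch

n_heads, batch, graph, embed = 8, 4, 20, 128
head_dim = embed // n_heads  # val_dim == key_dim under the Attention defaults

decoder = Decoder(num_heads=n_heads, input_size=2 * embed, embedding_size=embed)

# Hypothetical encoder outputs with the layouts Attention.forward expects.
projections = {
    'K': torch.randn(n_heads, batch, graph, head_dim),
    'V': torch.randn(n_heads, batch, graph, head_dim),
    'V_output': torch.randn(batch, graph, embed),  # per-node embeddings for scoring
}

decoder_input = torch.randn(batch, 1, 2 * embed)  # a single-step query context
mask = torch.ones(batch, 1, graph)                # 1 = feasible, 0 = forbidden

logits = decoder(decoder_input, projections, mask)
assert logits.shape == (batch, 1, graph)

With softmax_output=False the decoder returns clipped logits over the graph nodes, which a training loop could feed to a cross-entropy loss or normalize itself; with softmax_output=True it returns the node probabilities directly.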