a-ragab-h-m committed on
Commit e4a98cf · verified · 1 Parent(s): 2e8db15

Create nets/decoder.py

Files changed (1)
  1. nets/decoder.py +100 -0
nets/decoder.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ import numpy as np
+
+
+ # Scores the attention context against the output value projections to produce
+ # one score per node; optionally returns a softmax distribution instead of raw scores.
+ class ClassifierOutput(nn.Module):
+     def __init__(self, embedding_size, C=10, softmax_output=False):
+         super().__init__()
+         self.C = C
+         self.embedding_size = embedding_size
+         self.softmax_output = softmax_output
+         self.W_q = nn.Parameter(torch.Tensor(1, embedding_size, embedding_size))
+         self.init_parameters()
+
+     def init_parameters(self):
+         for param in self.parameters():
+             stdv = 1. / math.sqrt(param.size(-1))
+             param.data.uniform_(-stdv, stdv)
+
+     def forward(self, context, V_output):
+         # context: (batch_size, n_query, embedding_size); V_output: (batch_size, graph_size, embedding_size)
+         batch_size = context.shape[0]
+         Q = torch.bmm(context, self.W_q.repeat(batch_size, 1, 1))
+         z = torch.bmm(Q, V_output.permute(0, 2, 1))
+         z = z / (self.embedding_size ** 0.5)
+         # Squash the scores into (-C, C) with a scaled tanh before (optionally) normalising.
+         z = self.C * torch.tanh(z)
+         return F.softmax(z, dim=1) if self.softmax_output else z
+
+
+ # Multi-head cross-attention: the query comes from the decoder input, while the
+ # keys and values are precomputed encoder projections with a leading heads dimension.
+ class Attention(nn.Module):
+     def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
+         super().__init__()
+         if val_dim is None:
+             assert embed_dim is not None
+             val_dim = embed_dim // n_heads
+         if key_dim is None:
+             key_dim = val_dim
+
+         self.n_heads = n_heads
+         self.input_dim = input_dim
+         self.embed_dim = embed_dim
+         self.val_dim = val_dim
+         self.key_dim = key_dim
+         self.norm_factor = 1 / math.sqrt(key_dim)
+
+         self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
+         # W_out projects the concatenated heads back to embed_dim; key_dim equals val_dim
+         # in this module's usage, so the shapes line up.
+         self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+         self.init_parameters()
+
+     def init_parameters(self):
+         for param in self.parameters():
+             stdv = 1. / math.sqrt(param.size(-1))
+             param.data.uniform_(-stdv, stdv)
+
+     def forward(self, q, K, V, mask=None):
+         # q: (batch_size, n_query, input_dim)
+         # K, V: (n_heads, batch_size, graph_size, key_dim / val_dim)
+         # mask: boolean, True marks positions that must not be attended
+         batch_size = K.size(1)
+         graph_size = K.size(2)
+         n_query = q.size(1)
+         qflat = q.contiguous().view(-1, self.input_dim)
+
+         shp_q = (self.n_heads, batch_size, n_query, -1)
+         Q = torch.matmul(qflat, self.W_query).view(shp_q)
+         compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
+
+         if mask is not None:
+             mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
+             compatibility[mask] = -np.inf
+
+         attn = F.softmax(compatibility, dim=-1)
+         if mask is not None:
+             attn = attn.masked_fill(mask, 0)
+
+         heads = torch.matmul(attn, V)
+         # Concatenate the heads and project back to embed_dim.
+         out = torch.mm(
+             heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
+             self.W_out.view(-1, self.embed_dim)
+         ).view(batch_size, n_query, self.embed_dim)
+
+         return out
+
+
+ # One decoding step: embed the raw decoder input, attend over the encoder
+ # projections, then score every node with the classifier head.
+ class Decoder(nn.Module):
+     def __init__(self, num_heads, input_size, embedding_size, softmax_output=False, C=10):
+         super().__init__()
+         self.embedding_size = embedding_size
+         self.initial_embedding = nn.Linear(input_size, embedding_size)
+         self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
+         self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
+
+     def forward(self, decoder_input, projections, mask, *args, **kwargs):
+         # The incoming mask uses 0 for "not allowed"; Attention expects True for "not allowed".
+         mask = (mask == 0)
+         K = projections['K']
+         V = projections['V']
+         V_output = projections['V_output']
+
+         embedded_input = self.initial_embedding(decoder_input)
+         context = self.attention(embedded_input, K, V, mask)
+         output = self.classifier_output(context, V_output)
+         return output
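
For reference, a minimal usage sketch of the new decoder follows. It is not part of the commit: the encoder-side projections are stand-in random tensors, and the tensor shapes (heads-first K/V, a single query step, a 0/1 feasibility mask) are assumptions inferred from the code above.

from nets.decoder import Decoder
import torch

num_heads, embedding_size, input_size = 8, 128, 16
batch_size, graph_size = 4, 20
head_dim = embedding_size // num_heads  # key_dim == val_dim == 16 in this configuration

decoder = Decoder(num_heads=num_heads, input_size=input_size,
                  embedding_size=embedding_size, softmax_output=False)

# Stand-ins for the encoder projections this decoder expects (shapes are assumptions).
projections = {
    'K': torch.randn(num_heads, batch_size, graph_size, head_dim),    # (n_heads, B, N, key_dim)
    'V': torch.randn(num_heads, batch_size, graph_size, head_dim),    # (n_heads, B, N, val_dim)
    'V_output': torch.randn(batch_size, graph_size, embedding_size),  # (B, N, embedding_size)
}

decoder_input = torch.randn(batch_size, 1, input_size)  # one query step per batch element
mask = torch.ones(batch_size, 1, graph_size)            # 1 = feasible node, 0 = masked out
scores = decoder(decoder_input, projections, mask)
print(scores.shape)  # torch.Size([4, 1, 20]) -- one score per node

The import path assumes the repository exposes nets as an importable package; adjust it to the actual layout if needed.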