a-ragab-h-m committed on
Commit 2e8db15 · verified · 1 Parent(s): d639276

Create nets/encoder.py

Files changed (1)
  1. nets/encoder.py +106 -0
nets/encoder.py ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.multi_headed_attention import MultiHeadAttention
import math


class SkipConnection(nn.Module):
    """Residual connection that also threads the attention mask through the layer."""

    def __init__(self, module, use_mask=True):
        super(SkipConnection, self).__init__()
        self.use_mask = use_mask
        self.module = module

    def forward(self, input):
        # Accept either a bare tensor or an (input, mask) tuple from the previous layer.
        mask = None
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]

        if self.use_mask:
            return input + self.module(input, mask=mask), mask
        else:
            return input + self.module(input), mask


class Normalization(nn.Module):
    """Batch or instance normalization over the embedding dimension; the mask passes through unchanged."""

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()

        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)

        # Unknown normalization types leave the normalizer as None, so forward acts as the identity.
        self.normalizer = normalizer_class(embed_dim, affine=True) if normalizer_class is not None else None

    def forward(self, input):
        # Accept either a bare tensor or an (input, mask) tuple from the previous layer.
        mask = None
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]

        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): flatten the batch and node dimensions, then restore the shape.
            return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move the embedding dimension into the channel slot.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            assert self.normalizer is None, "Unknown normalizer type"
            return input, mask


class MultiHeadAttentionLayer(nn.Sequential):
    """One encoder block: masked multi-head attention and a feed-forward sublayer,
    each wrapped in a skip connection and followed by normalization."""

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super(MultiHeadAttentionLayer, self).__init__(
            SkipConnection(
                MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
                use_mask=True
            ),
            Normalization(embed_dim, normalization),
            SkipConnection(
                nn.Sequential(
                    nn.Linear(embed_dim, feed_forward_hidden),
                    nn.ReLU(),
                    nn.Linear(feed_forward_hidden, embed_dim)
                ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
                use_mask=False
            ),
            Normalization(embed_dim, normalization)
        )


class Encoder(nn.Module):
    """Stack of MultiHeadAttentionLayer blocks with an optional initial node embedding."""

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super(Encoder, self).__init__()

        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None

        self.layers = nn.Sequential(*(
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization
            ) for _ in range(n_layers)
        ))

    def forward(self, input, mask=None):
        device = input.device
        batch_size = input.shape[0]
        num_nodes = input.shape[1]

        # Default to an all-ones mask (no positions masked) when none is given.
        if mask is None:
            mask = torch.ones(batch_size, num_nodes, num_nodes, device=device)

        mask = (mask == 0)  # invert mask: True where attention should be blocked

        x = input
        # Project node features to the embedding dimension if an initial embedding is configured.
        h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x
        h, _ = self.layers((h, mask))  # pass both the embeddings and the mask through the stack
        return h
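
For reference, a minimal usage sketch of the new module (not part of the commit): the sizes below are illustrative assumptions, and it presumes nets.multi_headed_attention.MultiHeadAttention accepts the constructor arguments and mask keyword used above and returns a tensor of the same shape as its input.

import torch
from nets.encoder import Encoder

# Illustrative sizes only (assumptions, not taken from the commit).
batch_size, num_nodes, node_dim, embed_dim = 4, 20, 2, 128

encoder = Encoder(n_heads=8, embed_dim=embed_dim, n_layers=3, node_dim=node_dim)

nodes = torch.rand(batch_size, num_nodes, node_dim)   # e.g. 2-D node coordinates
mask = torch.ones(batch_size, num_nodes, num_nodes)   # 1 = attention allowed, 0 = masked
h = encoder(nodes, mask=mask)                         # expected shape: (batch_size, num_nodes, embed_dim)
print(h.shape)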