Update nets/encoder.py
nets/encoder.py (+31 -24)
@@ -17,6 +17,7 @@ class SkipConnection(nn.Module):
                 input, mask = input[0], input[1]
             else:
                 input = input[0]
+                mask = None
         else:
             mask = None

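The hunk above only shows the tuple-unpacking inside SkipConnection.forward; the rest of the class, and the use_mask flag that the new MultiHeadAttentionLayer passes to it further down, are not part of this diff. A minimal sketch of a wrapper consistent with those fragments, written purely as an assumption about how the pieces fit together:

import torch.nn as nn

# Assumption-only sketch: mirrors the unpacking shown in the hunk above and the
# SkipConnection(module, use_mask=...) calls added later in this diff.
class SkipConnectionSketch(nn.Module):
    def __init__(self, module, use_mask=False):
        super().__init__()
        self.module = module
        self.use_mask = use_mask  # hypothetical flag: pass the mask only to mask-aware modules

    def forward(self, input):
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None
        out = self.module(input, mask=mask) if self.use_mask else self.module(input)
        return input + out, mask  # residual connection; the mask is threaded through unchanged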
@@ -29,7 +30,6 @@ class Normalization(nn.Module):
 class Normalization(nn.Module):
     def __init__(self, embed_dim, normalization='batch'):
         super(Normalization, self).__init__()
-
         normalizer_class = {
             'batch': nn.BatchNorm1d,
             'instance': nn.InstanceNorm1d
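The hunk cuts off before normalizer_class is used; the rest of __init__ is unchanged by this commit and therefore not shown. For orientation only, a small self-contained sketch of how such a mapping is typically resolved into a module (embed_dim = 128 and affine=True are assumptions, not taken from this file):

import torch.nn as nn

normalizer_class = {
    'batch': nn.BatchNorm1d,
    'instance': nn.InstanceNorm1d
}

embed_dim, normalization = 128, 'batch'   # hypothetical values
normalizer = normalizer_class[normalization](embed_dim, affine=True)
print(normalizer)                         # BatchNorm1d(128, eps=1e-05, ...)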
@@ -43,6 +43,7 @@ class Normalization(nn.Module):
                 input, mask = input[0], input[1]
             else:
                 input = input[0]
+                mask = None
         else:
             mask = None

@@ -51,28 +52,33 @@ class Normalization(nn.Module):
         elif isinstance(self.normalizer, nn.InstanceNorm1d):
             return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
         else:
-            assert self.normalizer is None, "Unknown normalizer type"
             return input, mask


-class MultiHeadAttentionLayer(nn.Sequential):
+class MultiHeadAttentionLayer(nn.Module):
     def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
-        super(MultiHeadAttentionLayer, self).__init__(
-
-
-
-
-
-
-
-
-
-
-
-            ),
-            Normalization(embed_dim, normalization)
+        super(MultiHeadAttentionLayer, self).__init__()
+        self.attention = SkipConnection(
+            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
+            use_mask=True
+        )
+        self.norm1 = Normalization(embed_dim, normalization)
+        self.ff = SkipConnection(
+            nn.Sequential(
+                nn.Linear(embed_dim, feed_forward_hidden),
+                nn.ReLU(),
+                nn.Linear(feed_forward_hidden, embed_dim)
+            ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
+            use_mask=False
         )
+        self.norm2 = Normalization(embed_dim, normalization)
+
+    def forward(self, input):
+        h, mask = self.attention(input)
+        h, mask = self.norm1((h, mask))
+        h, mask = self.ff((h, mask))
+        h, mask = self.norm2((h, mask))
+        return h, mask


 class Encoder(nn.Module):
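Taken together, the rewritten layer consumes and returns an (h, mask) pair instead of relying on nn.Sequential chaining. A hypothetical usage sketch; the constructor arguments come from the signature above, while the shapes, the boolean-mask convention (True meaning "blocked", as implied by mask == 0 in the Encoder below), and the import path nets.encoder are assumptions:

import torch
from nets.encoder import MultiHeadAttentionLayer

layer = MultiHeadAttentionLayer(n_heads=8, embed_dim=128,
                                feed_forward_hidden=512, normalization='instance')

h = torch.randn(2, 20, 128)                      # (batch, num_nodes, embed_dim)
mask = torch.zeros(2, 20, 20, dtype=torch.bool)  # all False = attend to every node

h_out, mask_out = layer((h, mask))               # the mask is threaded through every sub-block
print(h_out.shape)                               # expected: torch.Size([2, 20, 128])

normalization='instance' is used in the sketch because the InstanceNorm1d branch is the one visible in this diff.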
@@ -82,13 +88,13 @@ class Encoder(nn.Module):

         self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None

-        self.layers = nn.
+        self.layers = nn.ModuleList([
             MultiHeadAttentionLayer(
                 n_heads, embed_dim,
                 feed_forward_hidden=feed_forward_hidden,
                 normalization=normalization
             ) for _ in range(n_layers)
-        )
+        ])

     def forward(self, input, mask=None):
         device = input.device
@@ -97,10 +103,11 @@ class Encoder(nn.Module):

         if mask is None:
             mask = torch.ones(batch_size, num_nodes, num_nodes).to(device).float()
+        mask = (mask == 0)

-
+        x = self.init_embed(input) if self.init_embed is not None else input
+        h = x
+        for layer in self.layers:
+            h, mask = layer((h, mask))

-        x = input
-        h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x
-        h, _ = self.layers((h, mask))  # Pass both h and mask through layers
         return h
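Since the default mask is all ones and mask == 0 inverts it, a plain call masks nothing. A hypothetical end-to-end sketch: the Encoder constructor's argument names are inferred from how they are used in __init__ and may not match the actual signature, and node_dim=2 (planar coordinates) is only an example:

import torch
from nets.encoder import Encoder

encoder = Encoder(n_heads=8, embed_dim=128, n_layers=3, node_dim=2,
                  feed_forward_hidden=512, normalization='instance')

coords = torch.rand(4, 50, 2)   # (batch_size, num_nodes, node_dim)
h = encoder(coords)             # mask defaults to all ones, so (mask == 0) is all False
print(h.shape)                  # expected: torch.Size([4, 50, 128])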