a-ragab-h-m committed (verified)
Commit c1384eb · 1 Parent(s): 7494603

Update nets/encoder.py

Files changed (1)
  1. nets/encoder.py +31 -24
nets/encoder.py CHANGED

@@ -17,6 +17,7 @@ class SkipConnection(nn.Module):
                 input, mask = input[0], input[1]
             else:
                 input = input[0]
+                mask = None
         else:
             mask = None
 
@@ -29,7 +30,6 @@ class SkipConnection(nn.Module):
 
 class Normalization(nn.Module):
     def __init__(self, embed_dim, normalization='batch'):
         super(Normalization, self).__init__()
-
         normalizer_class = {
             'batch': nn.BatchNorm1d,
             'instance': nn.InstanceNorm1d
@@ -43,6 +43,7 @@ class Normalization(nn.Module):
                 input, mask = input[0], input[1]
             else:
                 input = input[0]
+                mask = None
         else:
             mask = None
 
@@ -51,28 +52,33 @@ class Normalization(nn.Module):
         elif isinstance(self.normalizer, nn.InstanceNorm1d):
             return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
         else:
-            assert self.normalizer is None, "Unknown normalizer type"
             return input, mask
 
 
-class MultiHeadAttentionLayer(nn.Sequential):
+class MultiHeadAttentionLayer(nn.Module):
     def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
-        super(MultiHeadAttentionLayer, self).__init__(
-            SkipConnection(
-                MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
-                use_mask=True
-            ),
-            Normalization(embed_dim, normalization),
-            SkipConnection(
-                nn.Sequential(
-                    nn.Linear(embed_dim, feed_forward_hidden),
-                    nn.ReLU(),
-                    nn.Linear(feed_forward_hidden, embed_dim)
-                ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
-                use_mask=False
-            ),
-            Normalization(embed_dim, normalization)
+        super(MultiHeadAttentionLayer, self).__init__()
+        self.attention = SkipConnection(
+            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
+            use_mask=True
+        )
+        self.norm1 = Normalization(embed_dim, normalization)
+        self.ff = SkipConnection(
+            nn.Sequential(
+                nn.Linear(embed_dim, feed_forward_hidden),
+                nn.ReLU(),
+                nn.Linear(feed_forward_hidden, embed_dim)
+            ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
+            use_mask=False
         )
+        self.norm2 = Normalization(embed_dim, normalization)
+
+    def forward(self, input):
+        h, mask = self.attention(input)
+        h, mask = self.norm1((h, mask))
+        h, mask = self.ff((h, mask))
+        h, mask = self.norm2((h, mask))
+        return h, mask
 
 
 class Encoder(nn.Module):
@@ -82,13 +88,13 @@ class Encoder(nn.Module):
 
         self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None
 
-        self.layers = nn.Sequential(*(
+        self.layers = nn.ModuleList([
             MultiHeadAttentionLayer(
                 n_heads, embed_dim,
                 feed_forward_hidden=feed_forward_hidden,
                 normalization=normalization
             ) for _ in range(n_layers)
-        ))
+        ])
 
     def forward(self, input, mask=None):
         device = input.device
@@ -97,10 +103,11 @@ class Encoder(nn.Module):
 
         if mask is None:
             mask = torch.ones(batch_size, num_nodes, num_nodes).to(device).float()
+        mask = (mask == 0)
 
-        mask = (mask == 0) # invert mask: 1s where we want to mask
-
-        x = input
-        h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x
-        h, _ = self.layers((h, mask)) # Pass both h and mask through layers
+        x = self.init_embed(input) if self.init_embed is not None else input
+        h = x
+        for layer in self.layers:
+            h, mask = layer((h, mask))
+
         return h
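The diff replaces the nn.Sequential containers with explicit submodules and an nn.ModuleList driven by an explicit loop, so each layer can receive and return the (h, mask) pair. Below is a minimal, self-contained sketch of that pattern, not code from this repository: ToyLayer and ToyEncoder are illustrative stand-ins, and the boolean mask convention (True marks positions to ignore, matching the mask == 0 inversion in the diff) is inferred from the changed lines.

# Sketch of the ModuleList pattern adopted in this commit (assumed names).
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, input):
        h, mask = input              # unpack the (h, mask) tuple, as in the diff
        h = h + self.lin(h)          # residual / skip connection
        return h, mask               # hand the mask on to the next layer

class ToyEncoder(nn.Module):
    def __init__(self, dim, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(dim) for _ in range(n_layers)])

    def forward(self, h, mask):
        for layer in self.layers:    # explicit loop, as in the new Encoder.forward
            h, mask = layer((h, mask))
        return h

h = torch.rand(4, 20, 16)                          # (batch, nodes, embed_dim)
mask = torch.zeros(4, 20, 20, dtype=torch.bool)    # True = masked position, matching (mask == 0)
print(ToyEncoder(16, 3)(h, mask).shape)            # torch.Size([4, 20, 16])

nn.ModuleList still registers every layer's parameters with the parent module, but unlike nn.Sequential it does not impose a single-argument call chain, which is what allows the mask to flow alongside the embeddings through every layer.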