import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.multi_headed_attention import MultiHeadAttention


class SkipConnection(nn.Module):
    """Residual wrapper that threads an optional attention mask alongside the tensor."""

    def __init__(self, module, use_mask=True):
        super(SkipConnection, self).__init__()
        self.use_mask = use_mask
        self.module = module

    def forward(self, input):
        # Accept either a bare tensor or a (tensor, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None

        if self.use_mask:
            # The wrapped module (e.g. MultiHeadAttention) takes the mask as a keyword.
            return input + self.module(input, mask=mask), mask
        else:
            return input + self.module(input), mask


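# Illustrative sketch of the (tensor, mask) convention used by SkipConnection
# (not part of the model): the wrapped module is assumed to map
# (batch, n_nodes, embed_dim) -> (batch, n_nodes, embed_dim), and to accept a
# `mask` keyword when use_mask=True, as MultiHeadAttention is used below.
#
#     ff = SkipConnection(nn.Linear(128, 128), use_mask=False)
#     h = torch.rand(2, 10, 128)
#     out, mask = ff(h)            # out = h + ff.module(h); mask is None
#     out, mask = ff((h, None))    # a (tensor, mask) tuple is also accepted

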
class Normalization(nn.Module):
    """Batch or instance normalization of node embeddings; passes the mask through unchanged."""

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)

        # Fall back to no normalization if an unknown option is given.
        self.normalizer = normalizer_class(embed_dim, affine=True) if normalizer_class is not None else None

    def forward(self, input):
        # Accept either a bare tensor or a (tensor, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None

        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): flatten the batch and node dimensions.
            return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move the embedding dimension to the channel axis.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            return input, mask


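# Shape-handling sketch (illustrative only): both normalization branches preserve
# the (batch, n_nodes, embed_dim) layout and pass the mask through unchanged.
#
#     norm = Normalization(embed_dim=128, normalization='instance')
#     h = torch.rand(4, 10, 128)
#     out, mask = norm((h, None))  # out.shape == (4, 10, 128); mask stays None

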
class MultiHeadAttentionLayer(nn.Module):
    """One encoder block: masked multi-head self-attention and a feed-forward sublayer,
    each wrapped in a skip connection and followed by normalization."""

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super(MultiHeadAttentionLayer, self).__init__()
        self.attention = SkipConnection(
            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
            use_mask=True
        )
        self.norm1 = Normalization(embed_dim, normalization)
        self.ff = SkipConnection(
            nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim)
            ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
            use_mask=False
        )
        self.norm2 = Normalization(embed_dim, normalization)

    def forward(self, input):
        h, mask = self.attention(input)
        h, mask = self.norm1((h, mask))
        h, mask = self.ff((h, mask))
        h, mask = self.norm2((h, mask))
        return h, mask


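# Layer composition sketch (illustrative only): inputs and outputs are
# (tensor, mask) tuples, so blocks chain directly. The boolean mask follows the
# convention assumed by Encoder.forward below (True marks entries to be masked out).
#
#     layer = MultiHeadAttentionLayer(n_heads=8, embed_dim=128)
#     h = torch.rand(4, 10, 128)
#     mask = torch.zeros(4, 10, 10, dtype=torch.bool)  # nothing masked
#     h, mask = layer((h, mask))   # h.shape == (4, 10, 128)

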
class Encoder(nn.Module):
    """Graph attention encoder: an optional linear node embedding followed by a stack of
    MultiHeadAttentionLayer blocks."""

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super(Encoder, self).__init__()

        # Project raw node features to the embedding size, if a node dimension is given.
        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None

        self.layers = nn.ModuleList([
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization
            ) for _ in range(n_layers)
        ])

    def forward(self, input, mask=None):
        device = input.device
        batch_size = input.shape[0]
        num_nodes = input.shape[1]

        if mask is None:
            # All-False mask: no attention entries are masked out.
            mask = torch.zeros(batch_size, num_nodes, num_nodes, dtype=torch.bool, device=device)

        x = self.init_embed(input) if self.init_embed is not None else input
        h = x
        for layer in self.layers:
            h, mask = layer((h, mask))

        return h
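

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the library): builds a small
    # encoder and runs a random batch of 2-D node coordinates through it. It assumes
    # nets.multi_headed_attention.MultiHeadAttention preserves the
    # (batch, n_nodes, embed_dim) shape, as it is used in the layers above.
    encoder = Encoder(n_heads=8, embed_dim=128, n_layers=2, node_dim=2)
    coords = torch.rand(4, 10, 2)   # 4 graphs, 10 nodes each, 2-D coordinates
    embeddings = encoder(coords)    # expected shape: (4, 10, 128)
    print(embeddings.shape)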