import torch
import torch.nn as nn

from nets.multi_headed_attention import MultiHeadAttention


class SkipConnection(nn.Module):
    """Residual wrapper that adds the wrapped module's output to its input.

    Accepts either a tensor or an (input, mask) tuple so the attention mask
    can be threaded through a stack of layers.
    """

    def __init__(self, module, use_mask=True):
        super(SkipConnection, self).__init__()
        self.use_mask = use_mask
        self.module = module

    def forward(self, input):
        # Unpack an optional (input, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None

        if self.use_mask:
            return input + self.module(input, mask=mask), mask
        else:
            return input + self.module(input), mask


class Normalization(nn.Module):
    """Batch or instance normalization over the embedding dimension."""

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)
        # Fall back to a no-op (see forward) when an unknown normalization
        # type is requested instead of failing on a None class.
        self.normalizer = normalizer_class(embed_dim, affine=True) if normalizer_class is not None else None

    def forward(self, input):
        # Unpack an optional (input, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None

        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): flatten the batch and node dimensions.
            return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move the feature dimension to axis 1.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            return input, mask


class MultiHeadAttentionLayer(nn.Module):
    """One encoder layer: masked multi-head self-attention followed by a
    feed-forward sub-layer, each with a skip connection and normalization."""

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super(MultiHeadAttentionLayer, self).__init__()
        self.attention = SkipConnection(
            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
            use_mask=True
        )
        self.norm1 = Normalization(embed_dim, normalization)
        self.ff = SkipConnection(
            nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim)
            ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
            use_mask=False
        )
        self.norm2 = Normalization(embed_dim, normalization)

    def forward(self, input):
        h, mask = self.attention(input)
        h, mask = self.norm1((h, mask))
        h, mask = self.ff((h, mask))
        h, mask = self.norm2((h, mask))
        return h, mask


class Encoder(nn.Module):
    """Stack of MultiHeadAttentionLayer blocks with an optional linear node embedding."""

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super(Encoder, self).__init__()
        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None
        self.layers = nn.ModuleList([
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization
            ) for _ in range(n_layers)
        ])

    def forward(self, input, mask=None):
        # input: (batch_size, num_nodes, node_dim or embed_dim)
        batch_size, num_nodes = input.shape[0], input.shape[1]
        if mask is None:
            # All-False mask: no attention positions are masked out.
            mask = torch.zeros(batch_size, num_nodes, num_nodes,
                               dtype=torch.bool, device=input.device)

        h = self.init_embed(input) if self.init_embed is not None else input
        for layer in self.layers:
            h, mask = layer((h, mask))
        return h
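

# Minimal usage sketch, not part of the module's API: it assumes
# nets.multi_headed_attention.MultiHeadAttention is importable and follows the
# constructor/forward signature used above, and the shapes are illustrative only
# (e.g. 2-D node coordinates embedded into a 128-dimensional space).
if __name__ == "__main__":
    batch_size, num_nodes, node_dim, embed_dim = 4, 20, 2, 128
    encoder = Encoder(n_heads=8, embed_dim=embed_dim, n_layers=3, node_dim=node_dim)

    coords = torch.rand(batch_size, num_nodes, node_dim)
    embeddings = encoder(coords)  # no mask given: every node attends to every other node
    print(embeddings.shape)       # expected: torch.Size([4, 20, 128])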