import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.multi_headed_attention import MultiHeadAttention


class SkipConnection(nn.Module):
    """Residual wrapper that passes an optional attention mask through unchanged."""

    def __init__(self, module, use_mask=True):
        super(SkipConnection, self).__init__()
        self.use_mask = use_mask
        self.module = module

    def forward(self, input):
        # Accept either a bare tensor or an (input, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None
        # Residual connection; forward the mask only if the wrapped module expects it.
        if self.use_mask:
            return input + self.module(input, mask=mask), mask
        else:
            return input + self.module(input), mask


class Normalization(nn.Module):
    """Batch or instance normalization over the embedding dimension."""

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)
        self.normalizer = normalizer_class(embed_dim, affine=True)

    def forward(self, input):
        # Accept either a bare tensor or an (input, mask) tuple.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None
        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): flatten the batch and node dimensions.
            return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move the embedding dimension to axis 1.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            # No recognized normalizer: pass the input through unchanged.
            return input, mask


class MultiHeadAttentionLayer(nn.Module):
    """One encoder block: masked multi-head self-attention and a feed-forward
    sub-layer, each wrapped in a skip connection and followed by normalization."""

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super(MultiHeadAttentionLayer, self).__init__()
        self.attention = SkipConnection(
            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
            use_mask=True
        )
        self.norm1 = Normalization(embed_dim, normalization)
        self.ff = SkipConnection(
            nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim)
            ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim),
            use_mask=False
        )
        self.norm2 = Normalization(embed_dim, normalization)

    def forward(self, input):
        h, mask = self.attention(input)
        h, mask = self.norm1((h, mask))
        h, mask = self.ff((h, mask))
        h, mask = self.norm2((h, mask))
        return h, mask


class Encoder(nn.Module):
    """Attention-based encoder: an optional linear node embedding followed by a
    stack of multi-head attention layers."""

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super(Encoder, self).__init__()
        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None
        self.layers = nn.ModuleList([
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization
            ) for _ in range(n_layers)
        ])

    def forward(self, input, mask=None):
        device = input.device
        batch_size = input.shape[0]
        num_nodes = input.shape[1]
        if mask is None:
            # Default: an all-False mask, i.e. no attention pairs are blocked
            # (assuming True marks positions to be masked out).
            mask = torch.ones(batch_size, num_nodes, num_nodes).to(device).float()
            mask = (mask == 0)
        x = self.init_embed(input) if self.init_embed is not None else input
        h = x
        for layer in self.layers:
            h, mask = layer((h, mask))
        return h
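

# The block below is a minimal usage sketch, not part of the original module.
# It assumes MultiHeadAttention (from nets.multi_headed_attention) takes a
# (batch, nodes, embed_dim) tensor plus a boolean mask where True marks blocked
# attention pairs, and it simply checks the encoder's output shape on random
# 2-D node coordinates.
if __name__ == "__main__":
    encoder = Encoder(n_heads=8, embed_dim=128, n_layers=3, node_dim=2)
    coords = torch.rand(4, 20, 2)    # batch of 4 graphs, 20 nodes, 2-D coordinates
    embeddings = encoder(coords)     # default mask: no pairs are blocked
    print(embeddings.shape)          # expected: torch.Size([4, 20, 128])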