File size: 3,727 Bytes
2e8db15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1384eb
2e8db15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1384eb
2e8db15
 
 
 
 
 
 
 
 
 
 
c1384eb
2e8db15
c1384eb
 
 
 
 
 
 
 
 
 
 
 
 
2e8db15
c1384eb
 
 
 
 
 
 
 
2e8db15
 
 
 
 
 
 
 
 
c1384eb
2e8db15
 
 
 
 
c1384eb
2e8db15
 
 
 
 
 
 
 
c1384eb
2e8db15
c1384eb
 
 
 
2e8db15
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.multi_headed_attention import MultiHeadAttention
import math


class SkipConnection(nn.Module):
    """Residual wrapper that returns ``(input + module(input), mask)``.

    ``forward`` accepts either a bare tensor or a tuple whose first item is
    the tensor and whose optional second item is a mask (a 1-tuple yields a
    ``None`` mask). When ``use_mask`` is true the wrapped module is called as
    ``module(input, mask=mask)``, otherwise as ``module(input)``. The mask is
    handed back untouched so several wrappers can be chained.
    """

    def __init__(self, module, use_mask=True):
        super().__init__()
        self.use_mask = use_mask
        self.module = module

    def forward(self, input):
        mask = None
        if isinstance(input, tuple):
            # Only the first two entries matter; a missing mask becomes None.
            input, mask = input[0], (input[1] if len(input) > 1 else None)

        residual = self.module(input, mask=mask) if self.use_mask else self.module(input)
        return input + residual, mask


class Normalization(nn.Module):
    """Normalize features along the last (embedding) dimension.

    Wraps ``nn.BatchNorm1d`` or ``nn.InstanceNorm1d`` (both with learnable
    affine parameters). ``forward`` accepts either a bare tensor or an
    ``(input, mask)`` tuple and always returns ``(output, mask)`` so it can be
    chained with other tuple-passing layers; the mask is not used here.

    Args:
        embed_dim: Size of the last dimension of the input.
        normalization: ``'batch'``, ``'instance'``, or ``None`` for an
            identity (input returned unchanged).

    Raises:
        ValueError: If ``normalization`` is not one of the values above.
    """

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)

        if normalizer_class is not None:
            self.normalizer = normalizer_class(embed_dim, affine=True)
        elif normalization is None:
            # Explicit "no normalization": forward() falls through to identity.
            self.normalizer = None
        else:
            # Previously this crashed with "'NoneType' object is not callable".
            raise ValueError(
                "Unsupported normalization {!r}; expected 'batch', 'instance' "
                "or None".format(normalization)
            )

    def forward(self, input):
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None

        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): fold all leading dims into N so the
            # statistics are per feature. reshape (unlike view) also accepts
            # non-contiguous inputs, e.g. permuted tensors.
            flat = input.reshape(-1, input.size(-1))
            return self.normalizer(flat).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L); the permute assumes a 3-D
            # input with features on the last axis.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            return input, mask


class MultiHeadAttentionLayer(nn.Module):
    """One encoder layer: residual attention and residual feed-forward, each normalized.

    Data flow: ``attention`` (residual, called with the mask) -> ``norm1`` ->
    ``ff`` (residual, mask not passed to it) -> ``norm2``. ``forward`` takes a
    tensor or an ``(h, mask)`` tuple and returns an ``(h, mask)`` tuple.

    Args:
        n_heads: Number of attention heads.
        embed_dim: Input and output feature size.
        feed_forward_hidden: Hidden width of the feed-forward MLP; if not
            positive, a single ``embed_dim -> embed_dim`` linear map is used.
        normalization: Passed through to ``Normalization``.
    """

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super().__init__()
        # Construction order is kept stable (it determines parameter order and
        # the order in which random initialisation consumes the RNG).
        self.attention = SkipConnection(
            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
            use_mask=True,
        )
        self.norm1 = Normalization(embed_dim, normalization)

        if feed_forward_hidden > 0:
            # Two-layer MLP (Linear -> ReLU -> Linear) over the last dimension.
            ff_net = nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim),
            )
        else:
            # No hidden width requested: fall back to one linear map.
            ff_net = nn.Linear(embed_dim, embed_dim)
        self.ff = SkipConnection(ff_net, use_mask=False)
        self.norm2 = Normalization(embed_dim, normalization)

    def forward(self, input):
        state = self.attention(input)
        # Every sublayer consumes and produces an (h, mask) pair.
        for sublayer in (self.norm1, self.ff, self.norm2):
            state = sublayer(state)
        h, mask = state
        return h, mask


class Encoder(nn.Module):
    """Optional input projection followed by a stack of attention layers.

    Args:
        n_heads: Attention heads per layer.
        embed_dim: Hidden feature size used by every layer.
        n_layers: Number of stacked ``MultiHeadAttentionLayer`` modules.
        node_dim: If given, inputs are first projected from ``node_dim`` to
            ``embed_dim`` by a linear layer; if ``None`` the input is fed to
            the layers unchanged (presumably already ``embed_dim`` wide).
        normalization: Normalization type forwarded to each layer.
        feed_forward_hidden: Feed-forward hidden width forwarded to each layer.

    ``forward(input, mask=None)`` reads the batch size and node count from
    ``input.shape[0]`` and ``input.shape[1]`` and returns only the final
    hidden tensor (the mask is not returned).
    """

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super().__init__()

        self.init_embed = None if node_dim is None else nn.Linear(node_dim, embed_dim)

        self.layers = nn.ModuleList(
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization,
            )
            for _ in range(n_layers)
        )

    def forward(self, input, mask=None):
        batch_size, num_nodes = input.shape[0], input.shape[1]

        if mask is None:
            # No mask given: an all-False boolean mask on the input's device
            # (same result as building ones and testing "== 0").
            mask = torch.zeros(batch_size, num_nodes, num_nodes,
                               dtype=torch.bool, device=input.device)
        else:
            # Entries equal to 0 become True. Presumably True marks pairs the
            # attention should exclude -- confirm against MultiHeadAttention.
            mask = mask == 0

        h = input if self.init_embed is None else self.init_embed(input)
        for layer in self.layers:
            h, mask = layer((h, mask))

        return h