import torch
import torch.nn as nn
import math


class Projections(nn.Module):
    def __init__(self, n_heads, embed_dim):
        super().__init__()

        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.val_dim = embed_dim // n_heads

        # Per-head projection matrices for keys and values
        self.W_key = nn.Parameter(torch.empty(n_heads, embed_dim, self.val_dim))
        self.W_val = nn.Parameter(torch.empty(n_heads, embed_dim, self.val_dim))
        # Single shared output projection over the full embedding
        self.W_output = nn.Parameter(torch.empty(embed_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization U(-stdv, stdv), with stdv scaled by the
        # size of each parameter's last dimension
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """
        Project node embeddings into per-head keys and values, plus a
        shared output projection of the embeddings.

        :param h: Tensor of shape (batch_size, graph_size, embed_dim)
        :return: dict with keys 'K', 'V', 'V_output'
        """
        batch_size, graph_size, input_dim = h.size()
        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, embed_dim)

        # Compute keys and values per head: hflat (batch_size * graph_size, embed_dim)
        # broadcasts against W_key/W_val (n_heads, embed_dim, val_dim), giving
        # (n_heads, batch_size * graph_size, val_dim), which is then reshaped
        shp = (self.n_heads, batch_size, graph_size, self.val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Compute output projection: (batch_size, graph_size, embed_dim)
        V_output = torch.matmul(h, self.W_output)

        return {
            'K': K,             # (n_heads, batch_size, graph_size, val_dim)
            'V': V,             # (n_heads, batch_size, graph_size, val_dim)
            'V_output': V_output  # (batch_size, graph_size, embed_dim)
        }
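

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the dimensions
    # below are illustrative assumptions. Runs the projections on a random
    # batch and checks the returned shapes.
    n_heads, embed_dim = 8, 128
    batch_size, graph_size = 4, 20

    proj = Projections(n_heads, embed_dim)
    h = torch.rand(batch_size, graph_size, embed_dim)
    out = proj(h)

    # Per-head tensors carry a leading n_heads dimension; the output
    # projection keeps the original embedding shape.
    assert out['K'].shape == (n_heads, batch_size, graph_size, embed_dim // n_heads)
    assert out['V'].shape == (n_heads, batch_size, graph_size, embed_dim // n_heads)
    assert out['V_output'].shape == (batch_size, graph_size, embed_dim)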