|
import torch |
|
import torch.nn as nn |
|
import math |
|
|
|
|
|
class Projections(nn.Module):
    """Per-head key/value projections plus a full-width output projection.

    Given node embeddings ``h`` of shape (batch_size, graph_size, embed_dim),
    produces:
      - ``K``, ``V``: (n_heads, batch_size, graph_size, val_dim)
      - ``V_output``: (batch_size, graph_size, embed_dim)
    where ``val_dim = embed_dim // n_heads``.
    """

    def __init__(self, n_heads, embed_dim):
        """
        :param n_heads: number of attention heads
        :param embed_dim: embedding dimension; must be divisible by n_heads
        """
        super(Projections, self).__init__()

        # Guard against silently losing dimensions: with a non-divisible
        # embed_dim, the floor division below would drop the remainder.
        assert embed_dim % n_heads == 0, \
            "embed_dim must be divisible by n_heads"

        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.val_dim = embed_dim // n_heads

        # One (embed_dim, val_dim) projection matrix per head.
        # torch.empty is the modern equivalent of the legacy torch.Tensor
        # constructor: uninitialized storage, filled by init_parameters().
        self.W_key = nn.Parameter(torch.empty(n_heads, embed_dim, self.val_dim))
        self.W_val = nn.Parameter(torch.empty(n_heads, embed_dim, self.val_dim))
        # Single full-width output projection shared across heads.
        self.W_output = nn.Parameter(torch.empty(embed_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        """Uniform init in [-1/sqrt(fan), 1/sqrt(fan)] per parameter.

        The fan is each parameter's last dimension, mirroring the
        nn.Linear-style default initialization.
        """
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """
        :param h: Tensor of shape (batch_size, graph_size, embed_dim)
        :return: dict with keys:
            'K': (n_heads, batch_size, graph_size, val_dim)
            'V': (n_heads, batch_size, graph_size, val_dim)
            'V_output': (batch_size, graph_size, embed_dim)
        """
        batch_size, graph_size, input_dim = h.size()
        # Flatten batch and graph dims so one batched matmul covers all heads.
        hflat = h.contiguous().view(-1, input_dim)

        shp = (self.n_heads, batch_size, graph_size, self.val_dim)
        # (B*G, E) @ (H, E, D) broadcasts to (H, B*G, D), reshaped per head.
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Full-width projection. The original called
        # `self.W_output.expand_as(self.W_output)`, a no-op (expanding a
        # tensor to its own shape); removed.
        V_output = torch.matmul(h, self.W_output)

        return {
            'K': K,
            'V': V,
            'V_output': V_output
        }
|
|