import torch
import torch.nn as nn
import math
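
# Projection module for multi-head attention over graph node embeddings:
# precomputes per-head key/value projections (K, V) and a full-width output
# projection from a batch of node embeddings h.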
class Projections(nn.Module):
    def __init__(self, n_heads, embed_dim):
        super(Projections, self).__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.val_dim = embed_dim // n_heads
        # One (embed_dim, val_dim) weight matrix per head for keys and values,
        # plus a single (embed_dim, embed_dim) output projection
        self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization scaled by the square root of each
        # parameter's last dimension
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """
        :param h: Tensor of shape (batch_size, graph_size, embed_dim)
        :return: dict with keys: K, V, V_output
        """
        batch_size, graph_size, input_dim = h.size()
        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, embed_dim)

        # Compute keys and values per head: hflat broadcasts against each head's
        # (embed_dim, val_dim) weight matrix, then the result is reshaped
        shp = (self.n_heads, batch_size, graph_size, self.val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Compute output projection: (batch_size, graph_size, embed_dim)
        V_output = torch.matmul(h, self.W_output)

        return {
            'K': K,                # (n_heads, batch_size, graph_size, val_dim)
            'V': V,                # (n_heads, batch_size, graph_size, val_dim)
            'V_output': V_output,  # (batch_size, graph_size, embed_dim)
        }
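

# Minimal usage sketch. The batch, graph, and embedding sizes below are
# arbitrary toy values (assumptions, not taken from this file); they only
# illustrate the shapes the forward pass documents.
if __name__ == "__main__":
    n_heads, embed_dim = 8, 128
    batch_size, graph_size = 2, 20

    proj = Projections(n_heads, embed_dim)
    h = torch.randn(batch_size, graph_size, embed_dim)
    out = proj(h)

    # Per-head keys/values: (n_heads, batch_size, graph_size, embed_dim // n_heads)
    assert out['K'].shape == (n_heads, batch_size, graph_size, embed_dim // n_heads)
    assert out['V'].shape == (n_heads, batch_size, graph_size, embed_dim // n_heads)
    # The output projection keeps the full embedding dimension
    assert out['V_output'].shape == (batch_size, graph_size, embed_dim)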