# Multi-head attention projection module (K/V and output projections).
import math

import torch
import torch.nn as nn
class Projections(nn.Module):
    """Multi-head key/value projections plus a single output projection.

    Holds per-head weight tensors ``W_key`` / ``W_val`` of shape
    ``(n_heads, embed_dim, val_dim)`` with ``val_dim = embed_dim // n_heads``,
    and an output weight ``W_output`` of shape ``(1, embed_dim, embed_dim)``.
    """

    def __init__(self, n_heads, embed_dim):
        super(Projections, self).__init__()
        self.n_heads = n_heads
        self.embed_dim = embed_dim
        # Per-head value/key dimension; assumes embed_dim is divisible by n_heads.
        self.val_dim = embed_dim // n_heads
        self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        self.W_output = nn.Parameter(torch.Tensor(1, embed_dim, embed_dim))
        self.init_parameters()

    def init_parameters(self):
        """Uniform init in [-1/sqrt(fan), 1/sqrt(fan)] based on each param's last dim."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """
        :param h: (batch_size, graph_size, embed_dim)
        :return: dict with keys K, V, V_output for attention
                 K, V: (n_heads, batch_size, graph_size, val_dim)
                 V_output: (batch_size, graph_size, embed_dim)
        """
        batch_size, graph_size, input_dim = h.size()
        # reshape (not view) tolerates non-contiguous inputs.
        hflat = h.reshape(-1, input_dim)  # (batch_size * graph_size, embed_dim)
        shp = (self.n_heads, batch_size, graph_size, self.val_dim)

        # Broadcast matmul: (B*G, E) @ (n_heads, E, val_dim) -> (n_heads, B*G, val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Output projection via broadcasting: W_output's leading dim of 1
        # broadcasts over the batch, avoiding the memory copy of
        # W_output.repeat(batch_size, 1, 1) + bmm while computing the same result.
        V_output = torch.matmul(h, self.W_output)  # (batch_size, graph_size, embed_dim)

        return {
            'K': K,
            'V': V,
            'V_output': V_output
        }