import torch
import torch.nn.functional as F
from torch import nn
import math


class MultiHeadAttention(nn.Module):
    """Multi-head (scaled dot-product) attention with per-head projection matrices."""

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        # Scale dot products by 1 / sqrt(d_k), as in scaled dot-product attention
        self.norm_factor = 1 / math.sqrt(key_dim)

        # Per-head projections: queries and keys map to key_dim, values to val_dim
        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
        # Output projection that mixes the concatenated heads back to embed_dim
        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialisation scaled by the last dimension of each parameter
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim); defaults to q (self-attention)
        :param mask: bool mask (batch_size, n_query, graph_size); True entries are masked out
        """
        if h is None:
            h = q  # self-attention

        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # Project inputs for every head: (n_heads, batch_size, *, key_dim or val_dim)
        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)

        # Scaled dot-product scores: (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        if mask is not None:
            # Broadcast the mask over heads and push masked scores towards -inf
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility = compatibility.masked_fill(mask, -1e9)

        attn = F.softmax(compatibility, dim=-1)

        # Weighted sum of values per head: (n_heads, batch_size, n_query, val_dim)
        heads = torch.matmul(attn, V)

        # Concatenate heads and project back to embed_dim
        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
        out = torch.matmul(heads, self.W_out)

        return out
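

# Minimal usage sketch: the sizes below are illustrative assumptions, not values
# required by the module, and serve only to exercise the forward pass.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, graph_size, embed_dim, n_heads = 2, 5, 128, 8
    attention = MultiHeadAttention(n_heads=n_heads, input_dim=embed_dim, embed_dim=embed_dim)

    h = torch.randn(batch_size, graph_size, embed_dim)

    # Self-attention: queries default to h when no separate h is passed
    out = attention(h)
    print(out.shape)  # torch.Size([2, 5, 128])

    # Masked attention: True entries are excluded from the softmax
    mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
    mask[:, :, -1] = True  # e.g. hide the last node from every query
    print(attention(h, mask=mask).shape)  # torch.Size([2, 5, 128])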