# vrp-shanghai-transformer / nets / multi_headed_attention.py
import math

import torch
import torch.nn.functional as F
from torch import nn

class MultiHeadAttention(nn.Module):
    """Multi-head (scaled dot-product) attention with per-head projection matrices."""

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        # Scaling factor for the dot products (1 / sqrt(d_k))
        self.norm_factor = 1 / math.sqrt(key_dim)

        # Per-head projection matrices for queries, keys and values, plus the
        # output projection that maps the concatenated heads back to embed_dim
        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
        self.W_out = nn.Parameter(torch.Tensor(n_heads * val_dim, embed_dim))

        self.init_parameters()
    def init_parameters(self):
        # Uniform initialization in [-1/sqrt(d), 1/sqrt(d)], where d is the last dimension of each parameter
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data to attend over (batch_size, graph_size, input_dim); defaults to q (self-attention)
        :param mask: optional boolean mask (batch_size, n_query, graph_size); True entries are masked out
        :return: output embeddings (batch_size, n_query, embed_dim)
        """
        if h is None:
            h = q  # self-attention

        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)

        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, input_dim)
        qflat = q.contiguous().view(-1, input_dim)  # (batch_size * n_query, input_dim)

        # Project inputs to per-head keys, values and queries
        K = torch.matmul(hflat, self.W_key).view(self.n_heads, batch_size, graph_size, self.key_dim)
        V = torch.matmul(hflat, self.W_val).view(self.n_heads, batch_size, graph_size, self.val_dim)
        Q = torch.matmul(qflat, self.W_query).view(self.n_heads, batch_size, n_query, self.key_dim)

        # Compute scaled dot-product attention scores
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # (n_heads, batch, n_query, graph)

        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility = compatibility.masked_fill(mask, -1e9)

        attn = F.softmax(compatibility, dim=-1)

        # Apply attention to values
        heads = torch.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)

        # Concatenate heads and project back to the embedding dimension
        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
        out = torch.matmul(heads, self.W_out)  # (batch, n_query, embed_dim)

        return out
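

# --- Usage sketch (not part of the original module) ---
# A minimal smoke test, assuming embed_dim is divisible by n_heads and that the
# mask follows the convention above (True marks entries to be masked out).
# The dimensions and random inputs below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, graph_size, embed_dim, n_heads = 2, 5, 128, 8
    mha = MultiHeadAttention(n_heads=n_heads, input_dim=embed_dim, embed_dim=embed_dim)

    h = torch.randn(batch_size, graph_size, embed_dim)

    # Self-attention: keys/values default to the queries when h is not given
    out = mha(h)
    print(out.shape)  # torch.Size([2, 5, 128])

    # Masked attention: forbid attending to the last node for every query
    mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
    mask[:, :, -1] = True
    out_masked = mha(h, mask=mask)
    print(out_masked.shape)  # torch.Size([2, 5, 128])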