import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, dimension_for_model, num_of_heads, dropout=0.1):
        '''
        Initializes the multi-head attention module.

        dimension_for_model: the same variable as the one in the embeddings,
            meaning the dimensionality of the embeddings
        num_of_heads: the number of attention heads
        dropout: as explained in positional_encodings, the dropout rate,
            defaulted to 0.1
        '''
        super(MultiHeadAttention, self).__init__()
        assert dimension_for_model % num_of_heads == 0, \
            "dimension_for_model must be divisible by num_of_heads"

        self.num_of_heads = num_of_heads
        self.dimension_for_model = dimension_for_model
        # Per-head dimensionality: the model dimension split evenly across heads
        self.d_k = dimension_for_model // num_of_heads

        # Learned projections for queries, keys, values, and the final output
        self.linear_query = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_key = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_value = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_out = nn.Linear(dimension_for_model, dimension_for_model)

        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        '''
        Forward pass for multi-head attention.

        query: tensor of shape (batch_size, sequence_length, dimension_for_model)
        key: same shape as query (its sequence length may differ from the
            query's, e.g. in encoder-decoder cross-attention)
        value: same shape as key
        mask: optional tensor broadcastable to the attention scores; positions
            where mask == 0 are blocked from attending
        '''
        batch_size = query.size(0)

        # Project the inputs into query, key, and value spaces
        Q = self.linear_query(query)
        K = self.linear_key(key)
        V = self.linear_value(value)

        # Split the model dimension into num_of_heads heads of size d_k and
        # move the head dimension ahead of the sequence dimension:
        # (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k).
        # Using -1 for the sequence length lets the key/value sequences be a
        # different length than the query sequence.
        Q = Q.view(batch_size, -1, self.num_of_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_of_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_of_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Fill blocked positions with a large negative value so their
        # attention weights are ~0 after the softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = self.softmax(scores)
        attn = self.dropout(attn)

        # Weighted sum of the value vectors
        output = torch.matmul(attn, V)

        # Recombine the heads:
        # (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.dimension_for_model)

        # Final linear projection back into the model dimension
        output = self.linear_out(output)

        return output, attn
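

# A minimal usage sketch, not part of the original module: the batch size,
# sequence length, and model dimension below are illustrative assumptions,
# chosen only to show the expected input/output shapes of MultiHeadAttention.
if __name__ == "__main__":
    batch_size, seq_len, d_model, heads = 2, 5, 64, 8  # hypothetical values
    mha = MultiHeadAttention(dimension_for_model=d_model, num_of_heads=heads)

    x = torch.randn(batch_size, seq_len, d_model)
    # Lower-triangular causal mask: position i may attend only to j <= i;
    # shape (1, 1, seq_len, seq_len) broadcasts over batch and heads
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

    out, attn = mha(x, x, x, mask=causal_mask)
    print(out.shape)   # torch.Size([2, 5, 64])
    print(attn.shape)  # torch.Size([2, 8, 5, 5])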