import math

import torch
import torch.nn as nn

from .positional_encodings import PositionalEncoding
from .encoding_layers import position_wide_feed_forward


class DecoderLayer(nn.Module):
    def __init__(self, dimension_for_model, num_of_heads, dim_feedforward=2048, dropout=0.1):
        '''
        dimension_for_model: the model (embedding) dimension, as specified by the embedding layer and shared with the encoder
        num_of_heads: the number of heads used by the multi-head attention mechanisms, matching the encoder layers
        dim_feedforward: the hidden dimension of the position-wise feed-forward sub-layer, defaults to 2048
        dropout: the dropout probability applied to each sub-layer's output, defaults to 0.1
        '''
        super().__init__()
        # Masked self-attention over the target sequence and cross-attention over the encoder memory.
        self.self_attn = nn.MultiheadAttention(dimension_for_model, num_of_heads, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(dimension_for_model, num_of_heads, dropout=dropout)
        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(dimension_for_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dimension_for_model),
        )

        # Layer norm and dropout for each of the three sub-layers.
        self.norm1 = nn.LayerNorm(dimension_for_model)
        self.norm2 = nn.LayerNorm(dimension_for_model)
        self.norm3 = nn.LayerNorm(dimension_for_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # tgt and memory are batch-first: (batch, tgt_len, d_model) and (batch, src_len, d_model).
        # nn.MultiheadAttention expects (seq, batch, d_model) by default, so transpose around each call.
        memory_t = memory.transpose(0, 1)

        # Masked self-attention sub-layer with residual connection and layer norm.
        _tgt = tgt
        tgt_t = tgt.transpose(0, 1)
        tgt2, _ = self.self_attn(tgt_t, tgt_t, tgt_t, attn_mask=tgt_mask)
        tgt2 = tgt2.transpose(0, 1)
        tgt = self.norm1(_tgt + self.dropout1(tgt2))

        # Cross-attention sub-layer: queries come from the updated target, keys and values from the encoder memory.
        _tgt = tgt
        tgt_t = tgt.transpose(0, 1)
        tgt2, _ = self.cross_attn(tgt_t, memory_t, memory_t, attn_mask=memory_mask)
        tgt2 = tgt2.transpose(0, 1)
        tgt = self.norm2(_tgt + self.dropout2(tgt2))

        # Position-wise feed-forward sub-layer with residual connection and layer norm.
        _tgt = tgt
        tgt2 = self.ffn(tgt)
        tgt = self.norm3(_tgt + self.dropout3(tgt2))

        return tgt
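

# A sketch added for illustration (not part of the original module): a helper for building
# the causal `tgt_mask` that DecoderLayer.forward expects. nn.MultiheadAttention accepts an
# additive float attn_mask of shape (tgt_len, tgt_len), where -inf entries block attention
# to future positions. The name `generate_causal_mask` is an assumption, not an existing API.
def generate_causal_mask(size, device=None):
    # -inf above the diagonal, 0.0 on and below it, so position i attends only to positions <= i.
    mask = torch.full((size, size), float('-inf'), device=device)
    return torch.triu(mask, diagonal=1)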


class Decoder(nn.Module):
    def __init__(self, vocab_size, dimension_for_model, num_layers, num_of_heads, dim_feedforward=2048, dropout=0.1, max_len=5000):
        super().__init__()
        # Token embedding followed by positional encoding for the target sequence.
        self.embed = nn.Embedding(vocab_size, dimension_for_model)
        self.pe = PositionalEncoding(dimension_for_model, dropout=dropout, max_len=max_len)
        # Stack of identical decoder layers with a final layer norm on the output.
        self.layers = nn.ModuleList([
            DecoderLayer(dimension_for_model, num_of_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dimension_for_model)

    def forward(self, tgt_seq, memory, tgt_mask=None, memory_mask=None):
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        x = self.embed(tgt_seq) * math.sqrt(self.embed.embedding_dim)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return self.norm(x)
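

# Minimal usage sketch (illustrative only): the hyperparameters and shapes below are
# assumptions, and it presumes PositionalEncoding accepts batch-first input of shape
# (batch, seq_len, d_model), as the Decoder above does. Because of the relative imports,
# run it as a module from the package root (e.g. `python -m <package>.decoder`).
if __name__ == "__main__":
    vocab_size, d_model, num_layers, num_heads = 1000, 512, 2, 8
    decoder = Decoder(vocab_size, d_model, num_layers, num_heads)

    batch_size, src_len, tgt_len = 4, 12, 10
    tgt_seq = torch.randint(0, vocab_size, (batch_size, tgt_len))  # target token ids
    memory = torch.randn(batch_size, src_len, d_model)             # stand-in for encoder output
    tgt_mask = generate_causal_mask(tgt_len)                       # causal mask from the sketch above

    out = decoder(tgt_seq, memory, tgt_mask=tgt_mask)
    print(out.shape)  # expected: torch.Size([4, 10, 512])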