import math

import torch
import torch.nn as nn

from .positional_encodings import PositionalEncoding
from .multihead_attention import MultiHeadAttention
from .encoding_layers import position_wide_feed_forward, Residual_layer
from .masking_for_attention import mask
from .embeddings import Embeddings


class EncoderLayer(nn.Module):
    def __init__(self, dimension_for_model, num_of_heads, dim_feedforward, dropout=0.1):
        '''
        dimension_for_model: the model (embedding) dimension specified at the embeddings layer
        num_of_heads: the number of attention heads in the multi-head attention sublayer
        dim_feedforward: the hidden dimension of the position-wise feed-forward sublayer
        dropout: dropout probability applied to each sublayer's output during training
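
        Example (a minimal sketch; the sizes below are illustrative and assume that
        MultiHeadAttention preserves the [batch_size, seq_len, d_model] shape, as the
        residual connections require):

            layer = EncoderLayer(dimension_for_model=512, num_of_heads=8, dim_feedforward=2048)
            x = torch.randn(2, 10, 512)   # [batch_size, seq_len, d_model]
            out = layer(x)                # same shape: [2, 10, 512]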
        '''
        super().__init__()

        # Self-attention sublayer with its own layer norm and dropout
        self.self_attn = MultiHeadAttention(dimension_for_model, num_of_heads, dropout)
        self.norm1 = nn.LayerNorm(dimension_for_model)
        self.dropout1 = nn.Dropout(dropout)

        # Position-wise feed-forward sublayer with its own layer norm and dropout
        self.ffn = position_wide_feed_forward(dimension_for_model, dim_feedforward, dropout)
        self.norm2 = nn.LayerNorm(dimension_for_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        # Self-attention sublayer: residual connection followed by layer norm (post-norm)
        _src = src
        attn_output, _ = self.self_attn(src, src, src, mask=src_mask)
        src = self.norm1(_src + self.dropout1(attn_output))

        # Feed-forward sublayer: residual connection followed by layer norm
        _src = src
        ff_output = self.ffn(src)
        src = self.norm2(_src + self.dropout2(ff_output))
        return src


class Encoder(nn.Module):
    """
    Stacked Transformer encoder:
    - embedding + positional encoding
    - N encoder layers
    - final layer norm
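
    Example (a minimal sketch; the sizes below are illustrative, and the role/turn
    ids are assumed to stay within num_of_roles and max_turns):

        encoder = Encoder(vocab_size=1000, dimension_of_model=512, num_of_heads=8, num_layers=6)
        src_ids = torch.randint(0, 1000, (2, 10))      # [batch_size, seq_len] token ids
        roles = torch.zeros(2, 10, dtype=torch.long)   # [batch_size, seq_len] role ids
        turns = torch.zeros(2, 10, dtype=torch.long)   # [batch_size, seq_len] turn ids
        out = encoder(src_ids, roles, turns)           # [batch_size, seq_len, dimension_of_model]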
    """
    def __init__(self, vocab_size, dimension_of_model, num_of_heads, num_layers,
                 dim_feedforward=2048, dropout=0.1, max_len=5000, num_of_roles=2, max_turns=16):
        super().__init__()

        # Token, role, and turn embeddings
        self.embed = Embeddings(vocab_size, dimension_for_model=dimension_of_model,
                                num_of_roles=num_of_roles, max_turns=max_turns)

        # Positional encoding (with dropout) added on top of the embeddings
        self.pe = PositionalEncoding(dimension_of_model, dropout=dropout, max_len=max_len)

        # Stack of identical encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(dimension_of_model, num_of_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm applied to the output of the last encoder layer
        self.norm = nn.LayerNorm(dimension_of_model)

    def forward(self, src_ids, roles, turns, src_mask=None) -> torch.Tensor:
        """
        Args:
            src_ids: [batch_size, seq_len] input token indices
            roles: [batch_size, seq_len] role ids
            turns: [batch_size, seq_len] turn ids
            src_mask: [batch_size, 1, 1, seq_len] mask that prevents attending to padding tokens
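
        Example (a sketch of one way to build the padding mask; pad_id is a
        placeholder not defined in this module, and whether True marks kept or
        masked positions depends on MultiHeadAttention's convention):

            pad_id = 0
            src_mask = (src_ids != pad_id).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]
            out = encoder(src_ids, roles, turns, src_mask=src_mask)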
        """
        # Token + role + turn embeddings, followed by positional encoding
        x = self.embed(src_ids, roles, turns)
        x = self.pe(x)

        # Pass through the stack of encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)

        # Final layer norm
        return self.norm(x)

    def load_state_dict(self, state_dict, strict=True):
        """
        Custom state-dict loading that handles backward compatibility with the old
        checkpoint format, which stored a single token-embedding table under
        'encoder.embed.weight'.
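
        Example (a sketch; the checkpoint path is a placeholder, the keys are assumed
        to carry the 'encoder.' prefix this method checks for, and strict=False simply
        tolerates any remaining key mismatches):

            old_state = torch.load("old_checkpoint.pt", map_location="cpu")
            encoder.load_state_dict(old_state, strict=False)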
        """
        if 'encoder.embed.weight' in state_dict:
            # Work on a copy so the caller's state dict is not mutated in place
            state_dict = dict(state_dict)
            old_embed_weight = state_dict.pop('encoder.embed.weight')

            # The old format held only a token-embedding table: map it onto the new
            # token lookup table, and zero-initialise the role/turn tables from this
            # module's own parameter shapes (the old table has vocab_size rows, so
            # zeros_like(old_embed_weight) would produce the wrong shape here).
            state_dict['encoder.embed.lut.weight'] = old_embed_weight
            state_dict['encoder.embed.lut_roles.weight'] = torch.zeros_like(self.embed.lut_roles.weight)
            state_dict['encoder.embed.lut_turns.weight'] = torch.zeros_like(self.embed.lut_turns.weight)

            # Identity initialisation for the embedding layer norm
            state_dict['encoder.embed.norm.weight'] = torch.ones(old_embed_weight.size(1))
            state_dict['encoder.embed.norm.bias'] = torch.zeros(old_embed_weight.size(1))

        return super().load_state_dict(state_dict, strict=strict)