from torch import nn

from .mlp_loca import MLP

class TransformerEncoder(nn.Module):
    """Stack of transformer encoder layers with an optional final LayerNorm."""

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation
            ) for _ in range(num_layers)
        ])
        # Optional final normalization; nn.Identity keeps the forward pass uniform.
        self.norm = nn.LayerNorm(emb_dim, eps=layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        # Pass the input through each encoder layer in sequence, then normalize.
        output = src
        for layer in self.layers:
            output = layer(output, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(output)

class TransformerEncoderLayer(nn.Module):
    """A single encoder layer: multi-head self-attention followed by an MLP block,
    with residual connections and either pre- or post-normalization."""

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super().__init__()

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        # Add the positional embedding to x, if one is provided.
        return x if emb is None else x + emb
    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        if self.norm_first:
            # Pre-norm: normalize before attention/MLP, then add the residual.
            src_norm = self.norm1(src)
            q = k = self.with_emb(src_norm, pos_emb)
            src = src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src_norm,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0])
            src_norm = self.norm2(src)
            src = src + self.dropout2(self.mlp(src_norm))
        else:
            # Post-norm: add the residual first, then normalize.
            q = k = self.with_emb(src, pos_emb)
            src = self.norm1(src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0]))
            src = self.norm2(src + self.dropout2(self.mlp(src)))
        return src
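

# A minimal usage sketch, not part of the original module. Assumptions: MLP from
# .mlp_loca takes (in_dim, hidden_dim, dropout, activation) as used above, with the
# activation passed as a module class (e.g. nn.GELU); if MLP expects an instance,
# pass nn.GELU() instead. nn.MultiheadAttention is used with its default
# batch_first=False, so tensors are shaped (seq_len, batch, emb_dim). Hyperparameter
# values below are illustrative only. Because of the relative import at the top,
# run this as a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    import torch

    encoder = TransformerEncoder(
        num_layers=3,
        emb_dim=256,
        num_heads=8,
        dropout=0.1,
        layer_norm_eps=1e-5,
        mlp_factor=4,
        norm_first=True,
        activation=nn.GELU,
        norm=True,
    )

    seq_len, batch_size, emb_dim = 196, 2, 256
    src = torch.randn(seq_len, batch_size, emb_dim)
    pos_emb = torch.randn(seq_len, batch_size, emb_dim)

    out = encoder(src, pos_emb, src_mask=None, src_key_padding_mask=None)
    print(out.shape)  # expected: torch.Size([196, 2, 256])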