# -*- coding: utf-8 -*-
import math

import torch
import torch.nn as nn
from torch import Tensor

# Taken from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py


# pylint: disable=arguments-differ
class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from "Attention is All You Need".

    Implementation modified from OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """

    def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
        """
        Create a multi-headed attention layer.

        :param num_heads: the number of heads
        :param size: model size (must be divisible by num_heads)
        :param dropout: probability of dropping a unit
        """
        super().__init__()

        assert size % num_heads == 0

        self.head_size = head_size = size // num_heads
        self.model_size = size
        self.num_heads = num_heads

        self.k_layer = nn.Linear(size, num_heads * head_size)
        self.v_layer = nn.Linear(size, num_heads * head_size)
        self.q_layer = nn.Linear(size, num_heads * head_size)

        self.output_layer = nn.Linear(size, size)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
    def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
        """
        Computes multi-headed attention.

        :param k: keys   [B, M, D] with M being the sentence length.
        :param v: values [B, M, D]
        :param q: query  [B, M, D]
        :param mask: optional mask [B, 1, M] or [B, M, M]
        :return: output tensor [B, M, D]
        """
        batch_size = k.size(0)
        num_heads = self.num_heads

        # project the queries (q), keys (k), and values (v)
        k = self.k_layer(k)
        v = self.v_layer(v)
        q = self.q_layer(q)

        # reshape q, k, v for our computation to [batch_size, num_heads, ..]
        k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
        v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
        q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)

        # compute scaled dot-product scores:
        # [B, num_heads, query_len, key_len]
        q = q / math.sqrt(self.head_size)
        scores = torch.matmul(q, k.transpose(2, 3))

        # apply the mask (if we have one); a head dimension is added so the
        # mask broadcasts over all heads: [B, 1, 1, M] or [B, 1, M, M]
        if mask is not None:
            scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))

        # normalize scores and apply attention dropout
        attention = self.softmax(scores)     # [B, num_heads, query_len, key_len]
        attention = self.dropout(attention)

        # get context vectors (select values with attention) and reshape
        # back to [B, M, D]
        context = torch.matmul(attention, v)  # [B, num_heads, query_len, head_size]
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, num_heads * self.head_size)

        # combine the heads: project back to model size,
        # one output vector per time step
        output = self.output_layer(context)  # [B, M, D]

        return output
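
# Illustrative sketch (not part of the original module): a minimal shape check
# for MultiHeadedAttention, assuming model size 256, 8 heads, batch size 2 and
# a sentence length of 7. The function name `_demo_multi_headed_attention` is
# ours and is never called on import.
def _demo_multi_headed_attention():
    mha = MultiHeadedAttention(num_heads=8, size=256, dropout=0.1)
    x = torch.rand(2, 7, 256)                     # [B, M, D]
    mask = torch.ones(2, 1, 7, dtype=torch.bool)  # [B, 1, M], nothing masked out
    out = mha(k=x, v=x, q=x, mask=mask)           # self-attention
    assert out.shape == (2, 7, 256)               # output keeps the [B, M, D] shape
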
# pylint: disable=arguments-differ
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward layer.

    Projects to ff_size and then back down to input_size.
    """

    def __init__(self, input_size, ff_size, dropout=0.1):
        """
        Initializes the position-wise feed-forward layer.

        :param input_size: dimensionality of the input.
        :param ff_size: dimensionality of the intermediate representation
        :param dropout: probability of dropping a unit
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
        self.pwff_layer = nn.Sequential(
            nn.Linear(input_size, ff_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_size, input_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # pre-norm, then feed-forward, then residual connection
        x_norm = self.layer_norm(x)
        return self.pwff_layer(x_norm) + x
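
# Illustrative sketch (not part of the original module): the feed-forward block
# is applied to every position independently, so input and output shapes match.
# The sizes (256 model dim, 1024 hidden) are assumptions for the example only.
def _demo_positionwise_feed_forward():
    ff = PositionwiseFeedForward(input_size=256, ff_size=1024, dropout=0.1)
    x = torch.rand(2, 7, 256)     # [B, M, D]
    out = ff(x)
    assert out.shape == x.shape   # [B, M, D], unchanged by the block
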
# pylint: disable=arguments-differ
class PositionalEncoding(nn.Module):
    """
    Pre-compute position encodings (PE).

    In the forward pass, this adds the position encodings to the
    input for as many time steps as necessary.

    Implementation based on OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """

    def __init__(self, size: int = 0, max_len: int = 5000):
        """
        Positional encoding with maximum length max_len.

        :param size: model dimensionality (must be even)
        :param max_len: maximum supported sequence length
        """
        if size % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(size))
        pe = torch.zeros(max_len, size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *
                              -(math.log(10000.0) / size)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)  # shape: [1, max_len, size]
        super().__init__()
        self.register_buffer('pe', pe)
        self.dim = size

    def forward(self, emb):
        """
        Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``
        """
        # Add position encodings up to the input's sequence length
        return emb + self.pe[:, :emb.size(1)]
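
# Illustrative sketch (not part of the original module): the encoding table is
# pre-computed once and sliced to the input's sequence length in forward(), so
# inputs only need to be batch-first [B, M, D] with M <= max_len.
def _demo_positional_encoding():
    pe = PositionalEncoding(size=256, max_len=5000)
    emb = torch.rand(2, 7, 256)   # [B, M, D] word embeddings, batch-first
    out = pe(emb)                 # positions are added element-wise
    assert out.shape == emb.shape
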
class TransformerEncoderLayer(nn.Module):
    """
    One Transformer encoder layer has a Multi-head attention layer plus
    a position-wise feed-forward layer.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        A single Transformer encoder layer.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of attention heads
        :param dropout: dropout probability
        """
        super().__init__()

        self.layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.src_src_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(size,
                                                    ff_size=ff_size,
                                                    dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.size = size

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Forward pass for a single Transformer encoder layer.

        First applies layer norm, then self-attention, then dropout with a
        residual connection (adding the input to the result), and finally a
        position-wise feed-forward layer.

        :param x: layer input
        :param mask: input mask
        :return: output tensor
        """
        x_norm = self.layer_norm(x)
        h = self.src_src_att(x_norm, x_norm, x_norm, mask)
        h = self.dropout(h) + x
        o = self.feed_forward(h)
        return o
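
# Illustrative sketch (not part of the original module): one encoder layer keeps
# the [B, M, D] shape; the boolean source mask marks valid (non-padding)
# positions. All sizes below are example assumptions.
def _demo_transformer_encoder_layer():
    layer = TransformerEncoderLayer(size=256, ff_size=1024, num_heads=8, dropout=0.1)
    x = torch.rand(2, 7, 256)                         # [B, M, D]
    src_mask = torch.ones(2, 1, 7, dtype=torch.bool)  # [B, 1, M]
    out = layer(x, src_mask)
    assert out.shape == (2, 7, 256)
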
class TransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer.

    Consists of self-attention, source-attention, and feed-forward.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        Represents a single Transformer decoder layer.

        It attends to the source representation and the previous decoder states.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of heads
        :param dropout: dropout to apply to input
        """
        super().__init__()
        self.size = size

        self.trg_trg_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)
        self.src_trg_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)

        self.feed_forward = PositionwiseFeedForward(size,
                                                    ff_size=ff_size,
                                                    dropout=dropout)

        self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)

        self.dropout = nn.Dropout(dropout)

    # pylint: disable=arguments-differ
    def forward(self,
                x: Tensor = None,
                memory: Tensor = None,
                src_mask: Tensor = None,
                trg_mask: Tensor = None) -> Tensor:
        """
        Forward pass of a single Transformer decoder layer.

        :param x: inputs
        :param memory: source representations
        :param src_mask: source mask
        :param trg_mask: target mask (so as to not condition on future steps)
        :return: output tensor
        """
        # decoder/target self-attention
        x_norm = self.x_layer_norm(x)  # [B, M, D]
        h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
        h1 = self.dropout(h1) + x

        # source-target attention
        h1_norm = self.dec_layer_norm(h1)  # [B, M, D]
        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)

        # final position-wise feed-forward layer
        o = self.feed_forward(self.dropout(h2) + h1)

        return o
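
# Illustrative sketch (not part of the original module): the decoder layer
# attends to its own causally masked inputs and to the encoder memory. Target
# length 5 and source length 7 are example assumptions; the output follows the
# target shape.
def _demo_transformer_decoder_layer():
    layer = TransformerDecoderLayer(size=256, ff_size=1024, num_heads=8, dropout=0.1)
    trg = torch.rand(2, 5, 256)                       # [B, trg_len, D]
    memory = torch.rand(2, 7, 256)                    # [B, src_len, D] encoder output
    src_mask = torch.ones(2, 1, 7, dtype=torch.bool)  # [B, 1, src_len]
    trg_mask = torch.ones(1, 5, 5).tril().bool()      # causal mask [1, trg_len, trg_len]
    out = layer(x=trg, memory=memory, src_mask=src_mask, trg_mask=trg_mask)
    assert out.shape == (2, 5, 256)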