Spaces:
Running
Running
| import copy | |
| import numbers | |
| from functools import partial | |
| from typing import Any, Callable, List, Optional, Tuple, Union | |
| import torch | |
| from torch import Tensor, nn | |
| from torch.nn import functional as F | |
| from .activation import MultiheadAttention | |
| from .scaling import ActivationBalancer, BalancedDoubleSwish | |
| from .scaling import BasicNorm as _BasicNorm | |
| from .rotary_embedding import RotaryEmbedding | |
| from .conv import ConvolutionModule, MultiLayeredConv1d | |
| _shape_t = Union[int, List[int], torch.Size] | |
class LayerNorm(nn.Module):
    """Layer normalization that can carry a side ``embedding`` along.

    Numerically this is ``F.layer_norm`` over the trailing ``normalized_shape``
    dimensions with optional learnable ``weight``/``bias``. ``forward`` also
    accepts an ``(input, embedding)`` tuple; it then returns
    ``(normalized_input, embedding)`` with the embedding untouched.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            normalized_shape: an int (treated as a 1-tuple) or a sequence of
                ints giving the trailing dimensions to normalize over.
            eps: value added to the variance for numerical stability.
            elementwise_affine: when True, create ``weight`` (init 1) and
                ``bias`` (init 0) parameters of shape ``normalized_shape``.
            device, dtype: forwarded to the parameter allocations.
        """
        tensor_opts = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = (normalized_shape,)
        else:
            shape = normalized_shape
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # Register as None so the attributes exist and F.layer_norm skips the affine step.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_opts)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_opts)
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and ``bias`` to zeros (no-op without affine params)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``; pass an embedding through when given a tuple.

        Raises:
            AssertionError: if a bare tensor is passed together with a
                non-None ``embedding``.
        """
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input
        else:
            assert embedding is None
        normalized = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        if packed:
            return (normalized, embedding)
        return normalized

    def extra_repr(self) -> str:
        template = (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}"
        )
        return template.format(**self.__dict__)
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps an existing norm module and modulates its output with a scale and
    shift predicted from an ``embedding``:
    ``scale * norm(input) + shift``, where ``[scale, shift]`` is a linear
    projection of ``embedding`` to ``2 * d_model`` features split in half.
    ``forward`` accepts either ``(input, embedding)`` as separate arguments or
    an ``(input, embedding)`` tuple; the tuple form returns ``(output, embedding)``.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        # Mirror the wrapped norm's epsilon (the wrapped module must expose `.eps`).
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        # NOTE(review): `embedding` is required in practice — project_layer is
        # applied to it unconditionally.
        as_pair = isinstance(input, tuple)
        if as_pair:
            input, embedding = input
        scale, shift = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = scale * self.norm(input) + shift
        return (modulated, embedding) if as_pair else modulated
class BasicNorm(_BasicNorm):
    """Project ``BasicNorm`` adapted to the tuple-passing norm interface.

    ``forward`` accepts a bare tensor or an ``(input, embedding)`` tuple; in
    the tuple case only ``input`` is normalized and ``(normalized, embedding)``
    is returned.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        # device/dtype are accepted for signature parity with the other norm
        # classes in this file; they are not forwarded to the base class.
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if not isinstance(input, tuple):
            assert embedding is None
            return super(BasicNorm, self).forward(input)
        hidden, embedding = input
        return (super(BasicNorm, self).forward(hidden), embedding)
class BalancedBasicNorm(nn.Module):
    """``BasicNorm`` applied after an ``ActivationBalancer`` on the input.

    Accepts a bare tensor or an ``(input, embedding)`` tuple. For a tuple, the
    balanced input and the embedding are handed to the inner norm as a tuple.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        # Balancer constraints over the last (channel) dim; exact semantics
        # live in ActivationBalancer.
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if not isinstance(input, tuple):
            assert embedding is None
            return self.norm(self.balancer(input))
        hidden, embedding = input
        return self.norm((self.balancer(hidden), embedding))
class IdentityNorm(nn.Module):
    """No-op norm sharing the constructor signature of the other norm classes.

    All constructor arguments are ignored. ``forward`` returns its ``input``
    object unchanged (tensor or ``(x, embedding)`` tuple).
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        # A bare tensor must not be accompanied by a separate embedding.
        if not isinstance(input, tuple):
            assert embedding is None
        return input
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization, optionally partial and biased.

    Output is ``scale * x / (rms + eps)`` (plus ``offset`` when ``bias`` is
    set), where ``rms`` is computed over the last dim, or only over its first
    ``int(d * p)`` features when ``0 <= p <= 1``.
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias
        # Assigning an nn.Parameter attribute registers it; no explicit
        # register_parameter call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        partial_disabled = self.p < 0. or self.p > 1.
        if partial_disabled:
            count = self.d
            subset = x
        else:
            # NOTE(review): p == 0 gives count == 0 and 0 ** -0.5 raises
            # ZeroDivisionError below.
            count = int(self.d * self.p)
            subset = torch.split(x, [count, self.d - count], dim=-1)[0]
        rms = subset.norm(2, dim=-1, keepdim=True) * count ** (-1. / 2)
        scaled = self.scale * (x / (rms + self.eps))
        return scaled + self.offset if self.bias else scaled
class TransformerEncoderLayer(nn.Module):
    r"""Transformer encoder layer with optional convolution module and RoPE.

    The layer applies self-attention, then (when ``use_conv_module=True``) a
    convolution module, then a feed-forward block, each with a residual
    connection. ``norm_first`` selects pre-norm (normalize each sub-block's
    input) or post-norm (normalize after each residual add).

    Inputs may be a bare tensor or an ``(x, stage_embedding)`` tuple. In the
    tuple case ``stage_embedding`` is passed to every norm call and the layer
    returns ``(x, stage_embedding)``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
        use_conv_module: bool = False,
        use_depth_wise_conv: bool = False,
        conv_ignore_prefix_len: int = 0,
        cross_attention: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Always defined (not only when True) so it can be read without hasattr().
        self.has_cross_attention = bool(cross_attention)
        if cross_attention:
            # NOTE(review): these modules are created (and hold parameters)
            # but forward() does not currently apply cross-attention.
            self.cross_attn = nn.MultiheadAttention(
                d_model, nhead, 0.1, batch_first=True
            )
            self.norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        # Implementation of Feedforward model: two linear layers, or a stack
        # of 1-D convolutions when use_depth_wise_conv is set.
        self.use_depth_wise_conv = use_depth_wise_conv
        self.use_conv_module = use_conv_module
        if not use_depth_wise_conv:
            self.linear1 = linear1_feedforward_cls(
                d_model, dim_feedforward, **factory_kwargs
            )
            # Only exists on the linear-FFN path (used inside _ff_block).
            self.dropout = nn.Dropout(dropout)
            self.linear2 = linear2_feedforward_cls(
                dim_feedforward, d_model, **factory_kwargs
            )
        else:
            self.dw_ffn = MultiLayeredConv1d(
                in_chans=d_model,
                hidden_chans=dim_feedforward,
                kernel_size=5,
                dropout_rate=dropout,
            )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function; a partial or the
        # BalancedDoubleSwish class is instantiated with d_model.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # With identity norms, the second norm is still a balanced BasicNorm.
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

        # Rotary embedding sized to one attention head.
        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

        if use_conv_module:
            self.conv_module = ConvolutionModule(
                d_model,
                kernel_size=31,
                activation=activation,
                ignore_prefix_len=conv_ignore_prefix_len,
            )
            self.norm_conv = LayerNorm(d_model)  # for the CNN module
            if adaptive_layer_norm:
                self.norm_conv = AdaptiveLayerNorm(d_model, self.norm_conv)
        else:
            self.conv_module = None

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        # Fall back to ReLU when the unpickled state has no `activation`.
        if not hasattr(self, "activation"):
            self.activation = F.relu

    @staticmethod
    def _validate_key_padding_mask(key_padding_mask: Optional[Tensor]) -> None:
        """Raise AssertionError unless the mask is None, bool, or floating point."""
        if key_padding_mask is None:
            return
        if key_padding_mask.dtype != torch.bool and not torch.is_floating_point(
            key_padding_mask
        ):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )

    def forward(
        self,
        src: Tensor,
        context: Optional[Tensor] = None,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer, or an
                ``(x, stage_embedding)`` tuple (required).
            context: currently unused (cross-attention is not applied).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch
                (optional); must be bool or floating point.
            use_rope: forwarded to self-attention along with ``self.rotary_emb``.

        Returns:
            ``x``, or ``(x, stage_embedding)`` when ``src`` was a tuple.

        Raises:
            AssertionError: if ``src_key_padding_mask`` has an unsupported dtype.

        Shape:
            see the docs in Transformer class.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None
        self._validate_key_padding_mask(src_key_padding_mask)

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                use_rope=use_rope,
            )
            if self.conv_module is not None:
                residual = x
                x = self.norm_conv(x, stage_embedding)
                x = residual + self.dropout1(self.conv_module(x))
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask, use_rope=use_rope),
                stage_embedding,
            )
            if self.conv_module is not None:
                # dropout1 (same rate as self.dropout) is used because
                # self.dropout does not exist when use_depth_wise_conv=True.
                x = x + self.dropout1(self.conv_module(x))
                x = self.norm_conv(x, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def infer(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Incremental pass using ``self.self_attn.infer`` with ``past_kv``.

        Mirrors ``forward`` for both ``norm_first`` settings, but skips the
        attention dropout and the convolution module.
        NOTE(review): confirm that skipping the conv module here is intended.

        Returns:
            ``(x, kv)`` where ``kv`` comes from ``self_attn.infer``; when
            ``src`` was a tuple, ``(x, stage_embedding)`` instead.

        Raises:
            AssertionError: if ``src_key_padding_mask`` has an unsupported dtype.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None
        self._validate_key_padding_mask(src_key_padding_mask)

        attn_input = self.norm1(x, stage_embedding) if self.norm_first else x
        x_attn_out, kv = self.self_attn.infer(
            attn_input,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
            past_kv=past_kv,
            use_cache=use_cache,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )
        if self.norm_first:
            x = x + x_attn_out
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # Post-norm path previously left `kv` unbound (tensor input) or
            # returned the input unchanged (tuple input).
            x = self.norm1(x + x_attn_out, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return (x, kv)

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        """Self-attention (query = key = value = x) followed by dropout1."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sub-block (linear or depth-wise conv) followed by dropout2."""
        if self.use_depth_wise_conv:
            x = self.dw_ffn(x)
        else:
            x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers applied in sequence.

    Users can build the BERT(https://arxiv.org/abs/1810.04805) model with
    corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class
            (required); it is deep-copied ``num_layers`` times.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component applied to the final
            output (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder, or an ``(x, stage_embedding)``
                tuple (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: also return each layer's hidden states
                (optional).
            use_rope: forwarded to every layer.

        Returns:
            ``output``, or ``(layer_states, output)`` when
            ``return_layer_states`` is True. ``layer_states`` holds each
            layer's hidden tensor (the first element when a layer returns a
            tuple), taken before the final ``norm``.

        Shape:
            see the docs in Transformer class.
        """
        layer_states = [] if return_layer_states else None
        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                use_rope=use_rope,
            )
            if layer_states is not None:
                # Layers return (x, stage_embedding) for tuple input and a bare
                # tensor otherwise; indexing a bare tensor would slice it.
                layer_states.append(
                    output[0] if isinstance(output, tuple) else output
                )
        if self.norm is not None:
            output = self.norm(output)
        if layer_states is not None:
            return layer_states, output
        return output

    def infer(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Incremental pass through every layer's ``infer``.

        Args:
            past_kv: per-layer cached key/values, or None for an empty cache.
            use_cache: when True, collect each layer's returned ``kv``.
            return_layer_states: accepted for signature parity; ignored here.

        Returns:
            ``(output, new_kv)`` where ``new_kv`` is a tuple with one entry
            per layer when ``use_cache`` is True, else None.
        """
        if past_kv is None:
            past_kv = (None,) * self.num_layers
        new_kv = () if use_cache else None
        output = src
        for mod, past_layer_kv in zip(self.layers, past_kv):
            output, kv = mod.infer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                past_kv=past_layer_kv,
                use_cache=use_cache,
                use_rope=use_rope,
            )
            if use_cache:
                new_kv = new_kv + (kv,)
        if self.norm is not None:
            output = self.norm(output)
        return output, new_kv
class TransformerDecoderLayer(nn.Module):
    r"""Transformer decoder layer: self-attention, attention over ``memory``,
    and a feed-forward block, each with a residual connection.

    ``norm_first`` selects pre-norm vs post-norm. ``tgt`` may be a bare tensor
    or an ``(x, stage_embedding)`` tuple; the embedding is passed to every norm.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function; a partial or the
        # BalancedDoubleSwish class is instantiated with d_model.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            if layer_norm_cls == IdentityNorm:
                # With identity norms, the last norm is still a balanced BasicNorm.
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

        # Rotary embedding sized to one attention head.
        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ) -> Tuple[Any, Any]:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer, or an
                ``(x, stage_embedding)`` tuple (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            use_rope: forwarded to both attention blocks with ``self.rotary_emb``.

        Returns:
            ``(x, attn_map)``; when ``tgt`` was a tuple, ``(x, stage_embedding)``.
            NOTE(review): the tuple-input return drops ``attn_map``, which
            TransformerDecoder.forward unpacks positionally — confirm intended.

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = isinstance(tgt, tuple)
        if tgt_is_tuple:
            x, stage_embedding = tgt
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                tgt_mask,
                tgt_key_padding_mask,
                use_rope=use_rope,
            )
            x_mha_out, attn_map = self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = x + x_mha_out
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, use_rope=use_rope),
                stage_embedding,
            )
            # _mha_block returns (output, attn_map); the post-norm path used
            # to add that tuple to a tensor and never bound attn_map.
            x_mha_out, attn_map = self._mha_block(
                x,
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = self.norm2(x + x_mha_out, stage_embedding)
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            return (x, stage_embedding)
        return x, attn_map

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        """Self-attention (query = key = value = x) followed by dropout1."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Attention from ``x`` over ``mem``; returns ``(dropout2(out), attn_map)``."""
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        # NOTE(review): this unpacking assumes the project MultiheadAttention
        # returns a nested structure whose first element is (output, attn_map),
        # with output needing a further [0] — confirm against activation.py.
        x, attn_map = x
        return self.dropout2(x[0]), attn_map

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Linear -> activation -> dropout -> linear, followed by dropout3."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
class TransformerDecoder(nn.Module):
    r"""A stack of N decoder layers applied one after another.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class
            (required); it is deep-copied ``num_layers`` times.
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component applied to the final
            output (optional).

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> tgt = torch.rand(10, 32, 512)
        >>> memory = torch.rand(20, 32, 512)
        >>> out, attn_maps = transformer_decoder(tgt, memory)
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        return_attn: bool = False,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Run ``tgt`` through every decoder layer against ``memory``.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            return_attn: collect each layer's attention map (optional).
            use_rope: forwarded to every layer.

        Returns:
            ``(output, attn_maps)``; ``attn_maps`` is empty unless
            ``return_attn`` is True.

        Shape:
            see the docs in Transformer class.
        """
        collected = []
        hidden = tgt
        for layer in self.layers:
            hidden, layer_attn = layer(
                hidden,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                use_rope=use_rope,
            )
            if return_attn:
                collected.append(layer_attn)

        if self.norm is not None:
            hidden = self.norm(hidden)
        return hidden, collected
| def _get_clones(module, N): | |
| return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) | |
| def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: | |
| if activation == "relu": | |
| return F.relu | |
| elif activation == "gelu": | |
| return F.gelu | |
| raise RuntimeError( | |
| "activation should be relu/gelu, not {}".format(activation) | |
| ) | |