# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch

from Modules.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention


class GSTStyleEncoder(torch.nn.Module):
    """Style encoder.

    This module is the style encoder introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        idim (int, optional): Dimension of the input features.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_token_dim (int, optional): Dimension of each GST embedding.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        conv_layers (int, optional): The number of conv layers in the reference encoder.
        conv_chans_list (Sequence[int], optional):
            List of the number of channels of conv layers in the reference encoder.
        conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder.
        conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        gst_layers (int, optional): The number of GRU layers in the reference encoder.
        gst_units (int, optional): The number of GRU units in the reference encoder.
    """

    def __init__(
            self,
            idim: int = 128,
            gst_tokens: int = 512,  # AdaSpeech suggests using many more "basis vectors", but this already appears to be sufficient
            gst_token_dim: int = 64,
            gst_heads: int = 8,
            conv_layers: int = 8,
            conv_chans_list=(32, 32, 64, 64, 128, 128, 256, 256),
            conv_kernel_size: int = 3,
            conv_stride: int = 2,
            gst_layers: int = 2,
            gst_units: int = 256,
    ):
        """Initialize global style encoder module."""
        super(GSTStyleEncoder, self).__init__()
        self.num_tokens = gst_tokens
        self.ref_enc = ReferenceEncoder(idim=idim,
                                        conv_layers=conv_layers,
                                        conv_chans_list=conv_chans_list,
                                        conv_kernel_size=conv_kernel_size,
                                        conv_stride=conv_stride,
                                        gst_layers=gst_layers,
                                        gst_units=gst_units)
        self.stl = StyleTokenLayer(ref_embed_dim=gst_units,
                                   gst_tokens=gst_tokens,
                                   gst_token_dim=gst_token_dim,
                                   gst_heads=gst_heads)

    def forward(self, speech):
        """Calculate forward propagation.

        Args:
            speech (Tensor): Batch of padded target features (B, Lmax, idim).

        Returns:
            Tensor: Style token embeddings (B, gst_token_dim).
        """
        ref_embs = self.ref_enc(speech)
        style_embs = self.stl(ref_embs)
        return style_embs

    def calculate_ada4_regularization_loss(self):
        """Sum of pairwise cosine similarities between all style token embeddings.

        Minimizing this term encourages the tokens to stay dissimilar from each other.
        """
        losses = list()
        for emb1_index in range(self.num_tokens):
            # the inner loop starts at emb1_index + 1, so each unordered pair is counted exactly once
            for emb2_index in range(emb1_index + 1, self.num_tokens):
                losses.append(torch.nn.functional.cosine_similarity(self.stl.gst_embs[emb1_index],
                                                                    self.stl.gst_embs[emb2_index], dim=0))
        return sum(losses)
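

# Shape walkthrough with the GSTStyleEncoder defaults above (illustrative only,
# assuming a padded feature batch as input):
#   speech (B, Lmax, 128) --ReferenceEncoder--> ref_embs (B, 256)
#   ref_embs (B, 256) --StyleTokenLayer (512 tokens, 8 heads)--> style_embs (B, 64)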


class ReferenceEncoder(torch.nn.Module):
    """Reference encoder module.

    This module is the reference encoder introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        idim (int, optional): Dimension of the input features.
        conv_layers (int, optional): The number of conv layers in the reference encoder.
        conv_chans_list (Sequence[int], optional):
            List of the number of channels of conv layers in the reference encoder.
        conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder.
        conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        gst_layers (int, optional): The number of GRU layers in the reference encoder.
        gst_units (int, optional): The number of GRU units in the reference encoder.
    """

    def __init__(
            self,
            idim=80,
            conv_layers: int = 6,
            conv_chans_list=(32, 32, 64, 64, 128, 128),
            conv_kernel_size: int = 3,
            conv_stride: int = 2,
            gst_layers: int = 1,
            gst_units: int = 128,
    ):
        """Initialize reference encoder module."""
        super(ReferenceEncoder, self).__init__()

        # check that the hyperparameters are valid
        assert conv_kernel_size % 2 == 1, "kernel size must be odd."
        assert (
            len(conv_chans_list) == conv_layers), "the number of conv layers and the length of the channels list must be the same."

        convs = []
        padding = (conv_kernel_size - 1) // 2
        for i in range(conv_layers):
            conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1]
            conv_out_chans = conv_chans_list[i]
            convs += [torch.nn.Conv2d(conv_in_chans,
                                      conv_out_chans,
                                      kernel_size=conv_kernel_size,
                                      stride=conv_stride,
                                      padding=padding,
                                      # no bias because of the following batch norm
                                      bias=False),
                      torch.nn.BatchNorm2d(conv_out_chans),
                      torch.nn.ReLU(inplace=True)]
        self.convs = torch.nn.Sequential(*convs)

        self.conv_layers = conv_layers
        self.kernel_size = conv_kernel_size
        self.stride = conv_stride
        self.padding = padding

        # get the number of GRU input units by tracing the frequency axis through the conv stack
        gst_in_units = idim
        for i in range(conv_layers):
            gst_in_units = (gst_in_units - conv_kernel_size + 2 * padding) // conv_stride + 1
        gst_in_units *= conv_out_chans
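        # Worked example with this class's own defaults (illustrative only):
        # idim=80, kernel 3, stride 2, padding 1 over 6 layers shrinks the frequency
        # axis 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2, and with conv_out_chans = 128
        # the flattened GRU input size becomes 2 * 128 = 256.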
        self.gst = torch.nn.GRU(gst_in_units, gst_units, gst_layers, batch_first=True)

    def forward(self, speech):
        """Calculate forward propagation.

        Args:
            speech (Tensor): Batch of padded target features (B, Lmax, idim).

        Returns:
            Tensor: Reference embedding (B, gst_units)
        """
        batch_size = speech.size(0)
        xs = speech.unsqueeze(1)  # (B, 1, Lmax, idim)
        hs = self.convs(xs).transpose(1, 2)  # (B, Lmax', conv_out_chans, idim')
        time_length = hs.size(1)
        hs = hs.contiguous().view(batch_size, time_length, -1)  # (B, Lmax', gst_in_units)
        self.gst.flatten_parameters()
        # NOTE: if sequence lengths are available, the padded frames could be masked out with
        # pack_padded_sequence(hs, speech_lens, enforce_sorted=False, batch_first=True)
        _, ref_embs = self.gst(hs)  # (gst_layers, batch_size, gst_units)
        ref_embs = ref_embs[-1]  # (batch_size, gst_units)
        return ref_embs


class StyleTokenLayer(torch.nn.Module):
    """Style token layer module.

    This module is the style token layer introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        ref_embed_dim (int, optional): Dimension of the input reference embedding.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_token_dim (int, optional): Dimension of each GST embedding.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        dropout_rate (float, optional): Dropout rate in multi-head attention.
    """

    def __init__(
            self,
            ref_embed_dim: int = 128,
            gst_tokens: int = 10,
            gst_token_dim: int = 128,
            gst_heads: int = 4,
            dropout_rate: float = 0.0,
    ):
        """Initialize style token layer module."""
        super(StyleTokenLayer, self).__init__()

        gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads)
        self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs))
        self.mha = MultiHeadedAttention(q_dim=ref_embed_dim,
                                        k_dim=gst_token_dim // gst_heads,
                                        v_dim=gst_token_dim // gst_heads,
                                        n_head=gst_heads,
                                        n_feat=gst_token_dim,
                                        dropout_rate=dropout_rate)

    def forward(self, ref_embs):
        """Calculate forward propagation.

        Args:
            ref_embs (Tensor): Reference embeddings (B, ref_embed_dim).

        Returns:
            Tensor: Style token embeddings (B, gst_token_dim).
        """
        batch_size = ref_embs.size(0)
        # (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim)
        gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1)
        # NOTE(kan-bayashi): Should we apply Tanh?
        ref_embs = ref_embs.unsqueeze(1)  # (batch_size, 1, ref_embed_dim)
        style_embs = self.mha(ref_embs, gst_embs, gst_embs, None)
        return style_embs.squeeze(1)
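

# Note on dimensions (derived from the code above): each stored token has size
# gst_token_dim // gst_heads. The attention below projects keys and values up to
# n_feat = gst_token_dim and splits them across gst_heads heads, so the
# concatenated multi-head output is again gst_token_dim-dimensional, which is
# why StyleTokenLayer.forward returns (B, gst_token_dim).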


class MultiHeadedAttention(BaseMultiHeadedAttention):
    """Multi-head attention module with different input dimensions."""

    def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
        """Initialize multi-head attention module."""
        # NOTE(kan-bayashi): Do not use super().__init__() here, since we want to
        #   override the BaseMultiHeadedAttention.__init__() method completely.
        torch.nn.Module.__init__(self)
        assert n_feat % n_head == 0
        # we assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = torch.nn.Linear(q_dim, n_feat)
        self.linear_k = torch.nn.Linear(k_dim, n_feat)
        self.linear_v = torch.nn.Linear(v_dim, n_feat)
        self.linear_out = torch.nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
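

# Minimal smoke-test sketch (not part of the original module): it feeds a random
# batch of mel-like features through the style encoder and checks the output shape.
# The batch size, sequence length, and feature dimension below are arbitrary
# assumptions chosen to match the default idim of 128.
if __name__ == "__main__":
    dummy_speech = torch.randn(2, 300, 128)  # (B, Lmax, idim)
    style_encoder = GSTStyleEncoder()
    style_embedding = style_encoder(dummy_speech)
    print(style_embedding.shape)  # expected: torch.Size([2, 64])
    # calculate_ada4_regularization_loss() would also run here, but note that it
    # iterates over roughly 131k token pairs with the default 512 tokens.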