initial commits
- TTSInferencing.py +267 -0
- hyperparams.yaml +187 -0
- model.ckpt +3 -0
- module_classes.py +214 -0
TTSInferencing.py
ADDED
@@ -0,0 +1,267 @@
import re
import logging
import torch
import torchaudio
import random
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

logger = logging.getLogger(__name__)


class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = ["modules", "input_encoder"]

    MODULES_NEEDED = [
        "encoder_prenet", "pos_emb_enc",
        "decoder_prenet", "pos_emb_dec",
        "Seq2SeqTransformer", "mel_lin",
        "stop_lin", "decoder_postnet",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        lexicon = self.hparams.lexicon
        lexicon = ["@@"] + lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()

        self.modules = self.hparams.modules

        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def generate_padded_phonemes(self, texts):
        """Converts a list of texts into a zero-padded batch of encoded
        phoneme sequences.

        Arguments
        ---------
        texts: List[str]
            texts to be converted to phoneme sequences

        Returns
        -------
        a padded tensor of encoded phoneme sequences
        """
        # Preprocessing required at inference time for the input text:
        # "label" below contains an input text, and "phoneme_labels" contains
        # the phoneme sequences corresponding to the input text labels.
        phoneme_labels = list()

        for label in texts:
            phoneme_label = list()

            label = self.custom_clean(label).upper()

            words = label.split()
            words = [word.strip() for word in words]
            words_phonemes = self.g2p(words)

            for words_phonemes_seq in words_phonemes:
                for phoneme in words_phonemes_seq:
                    if not phoneme.isspace():
                        phoneme_label.append(phoneme)
            phoneme_labels.append(phoneme_label)

        # Encode the phonemes with the input text encoder
        encoded_phonemes = list()
        for phoneme_label in phoneme_labels:
            encoded_phoneme = torch.LongTensor(
                self.input_encoder.encode_sequence(phoneme_label)
            ).to(self.device)
            encoded_phonemes.append(encoded_phoneme)

        # Right zero-pad all phoneme sequences to the max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x) for x in encoded_phonemes]),
            dim=0,
            descending=True,
        )

        max_input_len = input_lengths[0]

        phoneme_padded = torch.LongTensor(
            len(encoded_phonemes), max_input_len
        ).to(self.device)
        phoneme_padded.zero_()

        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq

        return phoneme_padded.to(self.device, non_blocking=True).float()

    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts.

        Texts must be sorted in decreasing order of their lengths.

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrograms

        Returns
        -------
        tensors of output spectrograms
        """
        # Generate phonemes and pad the input texts
        encoded_phoneme_padded = self.generate_padded_phonemes(texts)
        phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
        # Positional embeddings
        phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
        # Sum up the embeddings
        enc_phoneme_emb = phoneme_prenet_emb.permute(0, 2, 1) + phoneme_pos_emb
        enc_phoneme_emb = enc_phoneme_emb.to(self.device)

        with torch.no_grad():
            # Generate sequential predictions via the transformer decoder,
            # starting from an all-zero "go" frame with a marker in row 1
            start_token = torch.full((80, 1), fill_value=0)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
            decoder_input = decoder_input.to(self.device, non_blocking=True).float()

            num_itr = 0
            stop_condition = [False] * decoder_input.size(0)
            max_iter = 100

            # while not all(stop_condition) and num_itr < max_iter:
            while num_itr < max_iter:
                # Decoder prenet
                mel_prenet_emb = (
                    self.modules['decoder_prenet'](decoder_input)
                    .to(self.device)
                    .permute(0, 2, 1)
                )

                # Positional embeddings
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
                # Sum up the embeddings
                dec_mel_spec = mel_prenet_emb + mel_pos_emb

                # Target mask to avoid looking ahead
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)

                # Source mask (no masking over the source)
                src_mask = torch.zeros(
                    enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]
                ).to(self.device)

                # Padding masks for source and target
                src_key_padding_mask = self.hparams.padding_mask(
                    enc_phoneme_emb, pad_idx=self.hparams.blank_index
                ).to(self.device)
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)

                # Run the Seq2Seq transformer
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Mel linears
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0, 2, 1)
                mel_postnet = self.modules['decoder_postnet'](mel_linears)  # postnet residual
                mel_pred = mel_linears + mel_postnet  # refined mel tensor output

                stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)

                stop_condition_list = self.check_stop_condition(stop_token_pred)

                # Update the values of the main stop conditions
                stop_condition = [
                    stop_condition_list[i] or stop_condition[i]
                    for i in range(len(stop_condition))
                ]

                # Prepare the transformer input for the next iteration
                current_output = mel_pred[:, :, -1:]

                decoder_input = torch.cat([decoder_input, current_output], dim=2)
                num_itr += 1

        # Drop the initial "go" frame
        mel_outputs = decoder_input[:, :, 1:]

        return mel_outputs

    def encode_text(self, text):
        """Runs inference for a single text str."""
        return self.encode_batch([text])

    def forward(self, text_list):
        """Encodes the input texts."""
        return self.encode_batch(text_list)

    def check_stop_condition(self, stop_token_pred):
        """
        Checks whether the stop token / EOS has been reached for each
        mel_spec in the batch.
        """
        # Apply a sigmoid to get per-frame stop probabilities
        sigmoid_output = torch.sigmoid(stop_token_pred)
        # Check whether every frame's probability exceeds the 0.8 threshold
        stop_results = sigmoid_output > 0.8
        stop_output = [bool(result.all()) for result in stop_results]

        return stop_output

    def custom_clean(self, text):
        """
        Uses custom criteria to clean text.

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """
        _abbreviations = [
            (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
            for x in [
                ("mrs", "missus"),
                ("mr", "mister"),
                ("dr", "doctor"),
                ("st", "saint"),
                ("co", "company"),
                ("jr", "junior"),
                ("maj", "major"),
                ("gen", "general"),
                ("drs", "doctors"),
                ("rev", "reverend"),
                ("lt", "lieutenant"),
                ("hon", "honorable"),
                ("sgt", "sergeant"),
                ("capt", "captain"),
                ("esq", "esquire"),
                ("ltd", "limited"),
                ("col", "colonel"),
                ("ft", "fort"),
            ]
        ]

        text = re.sub(" +", " ", text)

        for regex, replacement in _abbreviations:
            text = re.sub(regex, replacement, text)
        return text
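For reference, a minimal usage sketch for this interface (not part of the commit). It assumes the four committed files sit in a local folder, here named ./tts_model as a placeholder; a Hub repo id would work the same way as source. from_hparams is inherited from the SpeechBrain Pretrained base class.

from TTSInferencing import TTSInferencing

# Load hyperparams.yaml, instantiate the modules, and let the pretrainer
# pull model.ckpt into savedir (both paths below are placeholders).
tts = TTSInferencing.from_hparams(
    source="./tts_model",
    savedir="./pretrained_tts",
)

# Greedy autoregressive decoding of one sentence into a mel-spectrogram
mel_spec = tts.encode_text("Mary had a little lamb.")
print(mel_spec.shape)  # (1, 80, num_decoded_frames)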
hyperparams.yaml
ADDED
@@ -0,0 +1,187 @@
############################################################################
# Model: TTS with attention-based mechanism
# Tokens: g2p + positional embeddings
# Losses: MSE & BCE
# Training: LJSpeech
############################################################################

###################################
# Experiment Parameters and setup #
###################################
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Folder set up
# output_folder: !ref .\\results\\tts\\<seed>
# save_folder: !ref <output_folder>\\save

output_folder: !ref ./results/<seed>
save_folder: !ref <output_folder>/save


####################
# Input parameters #
####################
lexicon:
    - AA
    - AE
    - AH
    - AO
    - AW
    - AY
    - B
    - CH
    - D
    - DH
    - EH
    - ER
    - EY
    - F
    - G
    - HH
    - IH
    - IY
    - JH
    - K
    - L
    - M
    - N
    - NG
    - OW
    - OY
    - P
    - R
    - S
    - SH
    - T
    - TH
    - UH
    - UW
    - V
    - W
    - Y
    - Z
    - ZH

input_encoder: !new:speechbrain.dataio.encoder.TextEncoder


##########################
# Transformer parameters #
##########################
d_model: 512
nhead: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward: 512
dropout: 0.1


# Decoder parameters
# The number of frames in the target per encoder step
n_frames_per_step: 1
decoder_rnn_dim: 1024
prenet_dim: 256
max_decoder_steps: 1000
gate_threshold: 0.5
p_decoder_dropout: 0.1
decoder_no_early_stopping: False

blank_index: 0  # This special token is used for padding


# Masks
lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask


################################
# CNN 3-layer Prenets          #
################################
# Encoder Prenet
encoder_prenet: !new:module_classes.CNNPrenet

# Decoder Prenet
decoder_prenet: !new:module_classes.CNNDecoderPrenet

################################
# Positional Encodings         #
################################

# Encoder
pos_emb_enc: !new:module_classes.ScaledPositionalEncoding
    input_size: !ref <d_model>
    max_len: 5000

# Decoder
pos_emb_dec: !new:module_classes.ScaledPositionalEncoding
    input_size: !ref <d_model>
    max_len: 5000


################################
# Seq2Seq Transformer          #
################################
Seq2SeqTransformer: !new:torch.nn.Transformer
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dim_feedforward: !ref <dim_feedforward>
    dropout: !ref <dropout>
    batch_first: True


################################
# CNN 5-layer PostNet          #
################################
decoder_postnet: !new:speechbrain.lobes.models.Tacotron2.Postnet

# Linear transformation on top of the decoder (stop-token prediction).
stop_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 1

# Linear transformation on top of the decoder (mel prediction).
mel_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 80

modules:
    encoder_prenet: !ref <encoder_prenet>
    pos_emb_enc: !ref <pos_emb_enc>
    decoder_prenet: !ref <decoder_prenet>
    pos_emb_dec: !ref <pos_emb_dec>
    Seq2SeqTransformer: !ref <Seq2SeqTransformer>
    mel_lin: !ref <mel_lin>
    stop_lin: !ref <stop_lin>
    decoder_postnet: !ref <decoder_postnet>


model: !new:torch.nn.ModuleList
    - [!ref <encoder_prenet>, !ref <pos_emb_enc>,
       !ref <decoder_prenet>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>,
       !ref <mel_lin>, !ref <stop_lin>, !ref <decoder_postnet>]


pretrained_model_path: ./model.ckpt

# The pretrainer allows a mapping between pretrained files and instances that
# are declared in the yaml. E.g. here, we will download the file model.ckpt
# and it will be loaded into "model", which points to the <model> defined
# above.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    collect_in: !ref <save_folder>
    loadables:
        model: !ref <model>
    paths:
        model: !ref <pretrained_model_path>
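As a sketch of how this config is consumed (not part of the commit; the exact Pretrainer calls can differ between SpeechBrain versions): load_hyperpyyaml materializes every !new: node, resolves !ref substitutions, and the pretrainer then maps model.ckpt onto the model ModuleList.

from hyperpyyaml import load_hyperpyyaml

with open("hyperparams.yaml") as f:
    hparams = load_hyperpyyaml(f)

# Fetch model.ckpt into <save_folder>, then load the weights into <model>
hparams["pretrainer"].collect_files()
hparams["pretrainer"].load_collected()
print(type(hparams["Seq2SeqTransformer"]))  # torch.nn.Transformer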
model.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4e5421fe987116817841652862ce070a421d7f5d7c8bbef68c83bec876b1eafb
size 95804314
module_classes.py
ADDED
@@ -0,0 +1,214 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class CNNPrenet(torch.nn.Module):
    def __init__(self):
        super(CNNPrenet, self).__init__()

        # Define the layers using a Sequential container
        self.conv_layers = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        # Add a new dimension for the channel
        x = x.unsqueeze(1)

        # Pass the input through the convolutional layers
        x = self.conv_layers(x)

        # Remove the channel dimension if it is singleton
        # (a no-op here, since the last conv outputs 512 channels)
        x = x.squeeze(1)

        # Scale the output to the range [-1, 1]
        x = torch.tanh(x)

        return x


class CNNDecoderPrenet(nn.Module):
    def __init__(self, input_dim=80, hidden_dim=256, output_dim=256,
                 final_dim=512, dropout_rate=0.5):
        super(CNNDecoderPrenet, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        # Linear projection up to the transformer dimension
        self.linear_projection = nn.Linear(output_dim, final_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Transpose the input tensor so the feature dimension is last
        x = x.transpose(1, 2)
        # Apply the linear layers
        x = F.relu(self.layer1(x))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)
        # Apply the linear projection
        x = self.linear_projection(x)
        x = x.transpose(1, 2)

        return x


class CNNPostNet(torch.nn.Module):
    """
    Conv Postnet

    Arguments
    ---------
    n_mel_channels: int
        input feature dimension for convolution layers
    postnet_embedding_dim: int
        output feature dimension for convolution layers
    postnet_kernel_size: int
        postnet convolution kernel size
    postnet_n_convolutions: int
        number of convolution layers
    postnet_dropout: float
        dropout probability for postnet
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        postnet_dropout=0.1,
    ):
        super(CNNPostNet, self).__init__()

        self.conv_pre = nn.Conv1d(
            in_channels=n_mel_channels,
            out_channels=postnet_embedding_dim,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.convs_intermedite = nn.ModuleList()
        for i in range(1, postnet_n_convolutions - 1):
            self.convs_intermedite.append(
                nn.Conv1d(
                    in_channels=postnet_embedding_dim,
                    out_channels=postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    padding="same",
                ),
            )

        self.conv_post = nn.Conv1d(
            in_channels=postnet_embedding_dim,
            out_channels=n_mel_channels,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.tanh = nn.Tanh()
        self.ln1 = nn.LayerNorm(postnet_embedding_dim)
        self.ln2 = nn.LayerNorm(postnet_embedding_dim)
        self.ln3 = nn.LayerNorm(n_mel_channels)
        self.dropout1 = nn.Dropout(postnet_dropout)
        self.dropout2 = nn.Dropout(postnet_dropout)
        self.dropout3 = nn.Dropout(postnet_dropout)

    def forward(self, x):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, features, time_steps) input tensor

        Returns
        -------
        output: torch.Tensor (the spectrogram residual predicted)
        """
        x = self.conv_pre(x)
        # LayerNorm normalizes the last dimension, so permute to
        # (batch, time_steps, features) and back
        x = self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.tanh(x)
        x = self.dropout1(x)

        for conv in self.convs_intermedite:
            x = conv(x)
        x = self.ln2(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.tanh(x)
        x = self.dropout2(x)

        x = self.conv_post(x)
        x = self.ln3(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.dropout3(x)

        return x


class ScaledPositionalEncoding(nn.Module):
    """
    This class implements the absolute sinusoidal positional encoding function
    with an adaptive weight parameter alpha.

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))

    Arguments
    ---------
    input_size: int
        Embedding dimension.
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = ScaledPositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
            )
        self.max_len = max_len
        self.alpha = nn.Parameter(torch.ones(1))  # Trainable scale parameter
        pe = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        denominator = torch.exp(
            torch.arange(0, input_size, 2).float()
            * -(math.log(10000.0) / input_size)
        )

        pe[:, 0::2] = torch.sin(positions * denominator)
        pe[:, 1::2] = torch.cos(positions * denominator)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea)
        """
        # Scale the positional encoding by alpha; only x's time length is used,
        # and the encoding itself (not x + pe) is returned
        pe_scaled = self.pe[:, : x.size(1)].clone().detach() * self.alpha
        return pe_scaled
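To make the expected tensor layouts concrete, here is a small shape check (not part of the commit), using the dimensions fixed in hyperparams.yaml (d_model=512, 80 mel channels); the batch and time sizes are arbitrary.

import torch
from module_classes import CNNDecoderPrenet, ScaledPositionalEncoding

mels = torch.rand(4, 80, 37)        # (batch, n_mel_channels, time)
prenet = CNNDecoderPrenet()         # 80 -> 256 -> 256 -> 512
pos_enc = ScaledPositionalEncoding(input_size=512, max_len=5000)

emb = prenet(mels)                  # (4, 512, 37): features back in dim 1
pe = pos_enc(emb.permute(0, 2, 1))  # (1, 37, 512): broadcasts over the batch
dec_in = emb.permute(0, 2, 1) + pe  # (4, 37, 512): batch-first transformer input
print(dec_in.shape)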