# Spaces: Runtime error  (status banner captured with this file; not part of the code)
"""
test_phone_encoder.py
Desc: Check that the Grad-TTS text (phoneme) encoder can be built, run on a
parsed transcript, and aligned against a reference mel spectrogram.
"""
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make packages in the current working directory (expected to be the repo root)
# importable so the local `text` and `models` packages resolve.
sys.path.append('./')

from text import text_to_sequence, cmudict
from text.symbols import symbols
from models.utils import intersperse
from models.phoneme_encoder import TextEncoder
from models.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility
from models import monotonic_align
from text import get_arpabet, _symbol_to_id
def test_cmu_parser():
    """Smoke-test text-to-symbol parsing with the CMU pronouncing dictionary.

    Converts a sentence into a batch of symbol ids, interspersed with the filler
    id ``len(symbols)``, and looks up the ARPAbet form of a single word. The
    original version computed these values without checking them, so it could
    only fail by raising. The assertions below pin the basic invariants.
    """
    cmu = cmudict.CMUDict('./resources/cmu_dictionary')
    text = "Here I go breaking audio models again."
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])

    # `[None]` adds a leading batch axis, so x must be a single (1, L) batch.
    assert x.dim() == 2 and x.shape[0] == 1, f"expected a (1, L) batch, got {tuple(x.shape)}"
    assert x.shape[-1] > 0, "text_to_sequence/intersperse produced an empty sequence"
    assert int(x_lengths[0]) == x.shape[-1]
    # The encoder in this file is built with n_vocab = len(symbols) + 1, so every
    # id (including the interspersed filler len(symbols)) must fit in that range.
    assert int(x.min()) >= 0 and int(x.max()) <= len(symbols), "symbol id out of vocabulary range"

    arpabet_example = get_arpabet("Here", cmu)
    assert arpabet_example, "get_arpabet returned an empty result for 'Here'"
def _save_heatmap(data, title, path, xlabel='Time', ylabel='Mel Frequency Bands'):
    """Save a 2-D array as a heatmap image, then close the figure.

    Args:
        data: 2-D array. Rows are drawn on the y axis and columns on the x axis
            (``origin='lower'``), with one unit of extent per row/column.
        title: Figure title.
        path: Output image path. Its parent directory is created if missing.
        xlabel: Label for the x axis (columns of ``data``).
        ylabel: Label for the y axis (rows of ``data``).
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        # savefig() does not create directories; a fresh checkout may lack ./assets.
        os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.imshow(data, aspect='auto', origin='lower',
                   extent=[0, data.shape[1], 0, data.shape[0]])
        plt.colorbar(label='Intensity')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(path)
    finally:
        # Without this every call leaves an open figure behind.
        plt.close(fig)


def test_phone_encoder(mel_path='/data/jiachenlian/VCTK/mels_16k/p225/p225_219.npy',
                       text="They did not attack the themes of the book.",
                       assets_dir='./assets'):
    """Check that the Grad-TTS text encoder works end to end with these params.

    Builds an untrained ``TextEncoder``, encodes ``text`` and exercises both
    alignment paths:

    * inference: expand the encoder means ``mu_x`` by the predicted durations
      ``exp(logw)`` using ``generate_path``;
    * training: align ``mu_x`` to a reference mel spectrogram with
      ``monotonic_align.maximum_path`` and compute the duration and prior losses.

    Heatmaps of the intermediate results are written to ``assets_dir``.

    All parameters have defaults, so pytest still collects this as a
    fixture-free test.

    Args:
        mel_path: ``.npy`` mel spectrogram of shape (T, C), where C is the number
            of mel channels. C must equal ``n_feats`` below. The default is a
            machine-specific path; override it when running elsewhere.
        text: Transcript to encode.
        assets_dir: Directory that receives the PNG plots. Created if missing.

    Raises:
        ValueError: if the loaded mel is not 2-D with ``n_feats`` channels.
    """
    # Reference mel spectrogram, stored as (T, C).
    mel = np.load(mel_path)

    # Feature / encoder hyper-parameters.
    n_feats = 80
    n_enc_channels = 192
    filter_channels = 768
    filter_channels_dp = 256
    n_enc_layers = 6
    enc_kernel = 3
    enc_dropout = 0.1
    n_heads = 2
    window_size = 4
    length_scale = 1.0

    # The MAS log-prior below multiplies (T_x, n_feats) by (n_feats, T_y), so the
    # mel's channel count must match n_feats. Fail early with a readable message.
    if mel.ndim != 2 or mel.shape[1] != n_feats:
        raise ValueError(
            f"expected a (T, {n_feats}) mel spectrogram at {mel_path!r}, got shape {mel.shape}")

    # n_vocab is len(symbols) + 1 because intersperse() is given len(symbols) as
    # its filler id, one past the last regular symbol.
    encoder = TextEncoder(len(symbols) + 1, n_feats, n_enc_channels,
                          filter_channels, filter_channels_dp, n_heads,
                          n_enc_layers, enc_kernel, enc_dropout, window_size)

    cmu = cmudict.CMUDict('./resources/cmu_dictionary')
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])
    mu_x, logw, x_mask = encoder(x, x_lengths, None)  # no speaker conditioning yet

    # --- Inference path: expand mu_x by the predicted (ceiled) durations. ---
    w = torch.exp(logw) * x_mask
    w_ceil = torch.ceil(w) * length_scale
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = int(y_lengths.max())
    y_max_length_ = fix_len_compatibility(y_max_length)
    y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)).transpose(1, 2)
    encoder_outputs = mu_y[:, :, :y_max_length][0].detach().numpy()
    _save_heatmap(encoder_outputs, 'Phoneme Encoding',
                  os.path.join(assets_dir, 'example_untrained_phone_encoding.png'))

    # --- Training path: align mu_x to the reference mel with MAS. ---
    y = torch.Tensor(mel.T).unsqueeze(0)  # (1, n_feats, T_y)
    y_max_length = y.shape[-1]
    y_lengths = torch.LongTensor([y_max_length])
    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    with torch.no_grad():
        # Unit-variance Gaussian log-likelihood of every mel frame under every
        # text position's mean: -0.5*y^2 + y*mu - 0.5*mu^2 + const, summed over features.
        const = -0.5 * math.log(2 * math.pi) * n_feats
        factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
        y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
        y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
        mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
        log_prior = y_square - y_mu_double + mu_square + const  # (1, T_x, T_y)
        attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
        attn = attn.detach()
    attn_np = attn.numpy()

    # Duration loss between predicted log-durations and those implied by MAS.
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)

    # Align the encoder means to mel frames, then compute the prior loss.
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)).transpose(1, 2)
    mu_y_np = mu_y.detach().numpy()
    prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)

    # Rows of the alignment map are text positions (T_x), not mel bands.
    _save_heatmap(attn_np.squeeze(0), 'Untrained Duration',
                  os.path.join(assets_dir, 'example_duration.png'),
                  ylabel='Phoneme Index')
    _save_heatmap(mu_y_np.squeeze(0), 'Untrained Alignment',
                  os.path.join(assets_dir, 'example_MAS.png'))
    _save_heatmap(mel.T, 'Goal Mel Spectrogram',
                  os.path.join(assets_dir, 'example_mel.png'))

    # The original only computed these values. Check them after plotting so the
    # images are still available for debugging when a check fails.
    assert mu_y.shape == y.shape, f"aligned means {tuple(mu_y.shape)} != mel {tuple(y.shape)}"
    assert torch.isfinite(torch.as_tensor(dur_loss)).all(), "duration loss is not finite"
    assert torch.isfinite(prior_loss).all(), "prior loss is not finite"
if __name__ == "__main__":
    # Run both checks directly (python test_phone_encoder.py) without pytest:
    # the parser check first, then the encoder check.
    for smoke_test in (test_cmu_parser, test_phone_encoder):
        smoke_test()