# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import numpy as np
from tqdm import tqdm
import torch
import json

from models.tts.base.tts_inferece import TTSInference
from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
from models.tts.vits.vits import SynthesizerTrn
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
from utils.data_utils import *


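# Inference wrapper for the VITS acoustic model: builds the SynthesizerTrn
# generator from the experiment config and exposes batch and single-utterance
# text-to-speech synthesis on top of the shared TTSInference base class.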
class VitsInference(TTSInference):
    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

    def _build_model(self):
        net_g = SynthesizerTrn(
            # symbols_len,
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )
        return net_g

    def _build_test_dataset(self):
        return VITSTestDataset, VITSTestCollator

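    # Save path layout produced below:
    #   <output_dir>/tts_am_step-<am_restore_step>_<mode>[/data_<dataset>][/spk_<speaker>]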
    def build_save_dir(self, dataset, speaker):
        save_dir = os.path.join(
            self.args.output_dir,
            "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

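    # Batch synthesis over the test dataloader. Speaker ids are only passed to the
    # model when the config enables speaker embeddings and multi-speaker training;
    # each generated waveform is trimmed to its true length using the output mask.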
    def inference_for_batches(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                spk_id = None
                if (
                    self.cfg.preprocess.use_spkid
                    and self.cfg.train.multi_speaker_training
                ):
                    spk_id = batch_data["spk_id"]
                outputs = self.model.infer(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    spk_id,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )
                audios = outputs["y_hat"]
                masks = outputs["mask"]

                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].data.cpu().float()
                    mask = masks[idx, :, :]
                    audio_length = (
                        mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    audio_length = audio_length.cpu().numpy()
                    audio = audio[:audio_length]
                    pred_res.append(audio)

        return pred_res

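    # Single-utterance synthesis from `self.args.text`: text -> phones -> phone ids
    # (optionally interspersed with blank token 0) -> SynthesizerTrn.infer, returning
    # a mono waveform as a float numpy array.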
    def inference_for_single_utterance(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = os.path.join(
            self.exp_dir, self.cfg.preprocess.symbols_dict
        )
        assert os.path.exists(phone_symbol_file)

        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list

        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        if self.cfg.preprocess.add_blank:
            phone_id_seq = intersperse(phone_id_seq, 0)

        # convert phone id sequence to a tensor
        phone_id_seq = np.array(phone_id_seq)
        phone_id_seq = torch.from_numpy(phone_id_seq)

        # look up the speaker id only when speaker ids are used and the model is
        # multi-speaker; otherwise keep it None so single-speaker inference works
        speaker_id = None
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
            speaker_id = spk2id[self.args.speaker_name]
            print("speaker name:", self.args.speaker_name)
            print("speaker id:", speaker_id)
            speaker_id = torch.from_numpy(
                np.array([speaker_id], dtype=np.int32)
            ).unsqueeze(0)

        with torch.no_grad():
            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
            if speaker_id is not None:
                speaker_id = speaker_id.to(self.device)
            outputs = self.model.infer(
                x_tst,
                x_tst_lengths,
                sid=speaker_id,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )

            audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()

        return audio
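
# Minimal usage sketch, assuming an Amphion-style entry script has already built
# `args` (with `text`, `speaker_name`, `output_dir`, `mode`) and the experiment `cfg`;
# the exact CLI wiring lives in the recipe scripts and is not shown here.
#
#   inferencer = VitsInference(args=args, cfg=cfg)
#   audio = inferencer.inference_for_single_utterance()  # float numpy waveform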