Spaces:
Running
Running
| import argparse | |
| import models | |
| import time | |
| import sys | |
| import warnings | |
| #from pathlib import Path | |
| import torch | |
| import numpy as np | |
| from scipy.stats import norm | |
| from scipy.io.wavfile import write | |
| from torch.nn.utils.rnn import pad_sequence | |
| #import style_controller | |
| from common.utils import load_wav_to_torch | |
| from langid.langid import WordLid | |
| from common import utils, layers | |
| from common.text.text_processing import TextProcessing | |
| from collections import Counter | |
| import os | |
| #os.environ["CUDA_VISIBLE_DEVICES"]="" | |
# Torch device used for every model and tensor in this script.
# Falls back to the CPU when CUDA is not available, so the script can still
# load checkpoints and run on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Name of the vocoder family used by this script.
vocoder = "hifigan"
# When True, Synthesizer.speak post-processes the generated mel-spectrogram
# (unsharp masking, spectral tilt, range rescaling) before vocoding.
SHARPEN = True
| from hifigan.data_function import MAX_WAV_VALUE, mel_spectrogram | |
| from hifigan.models import Denoiser | |
| import json | |
| from scipy import ndimage | |
| import os | |
def parse_args(parser):
    """Register the command-line options of the inference script.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, with all options and the two argument
        groups (text processing, conditioning) added.
    """
    parser.add_argument('-i', '--input', type=str, required=False,
                        help='Full path to the input text (phrases separated by newlines)')
    parser.add_argument('-o', '--output', default=None,
                        help='Output folder to save audio (file per phrase)')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')
    parser.add_argument('--save-mels', action='store_true', help='')
    parser.add_argument('--cuda', action='store_true',
                        help='Run inference on a GPU using CUDA')
    parser.add_argument('--cudnn-benchmark', action='store_true',
                        help='Enable cudnn benchmark mode')
    parser.add_argument('--fastpitch', type=str,
                        default='output_6lang/FastPitch_checkpoint_50.pt',
                        help='Full path to the generator checkpoint file (skip to use ground truth mels)')
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='WaveGlow denoising')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--stft-hop-length', type=int, default=256,
                        help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp', action='store_true', default=False,
                        help='Inference with AMP')
    parser.add_argument('-bs', '--batch-size', type=int, default=1)
    parser.add_argument('--ema', action='store_true',
                        help='Use EMA averaged model (if saved in checkpoints)')
    parser.add_argument('--speaker', type=int, default=0,
                        help='Speaker ID for a multi-speaker model')
    parser.add_argument('--language', type=int, default=0,
                        help='Language ID for a multilingual model')
    parser.add_argument('--p-arpabet', type=float, default=0.0, help='')

    text_processing = parser.add_argument_group('Text processing parameters')
    text_processing.add_argument('--text-cleaners', nargs='*',
                                 default=['basic_cleaners'], type=str,
                                 help='Type of text cleaners for input text')
    text_processing.add_argument('--symbol-set', type=str, default='uralic',
                                 help='Define symbol set for input text')

    cond = parser.add_argument_group('conditioning on additional attributes')
    cond.add_argument('--n-speakers', type=int, default=30,
                      help='Number of speakers in the model.')
    cond.add_argument('--n-languages', type=int, default=6,
                      help='Number of languages in the model.')
    return parser
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d['key']`` always refer to the same entry.
        self.__dict__ = self
def load_model_from_ckpt(checkpoint_path, ema, model):
    """Load a checkpoint file and return the model it describes.

    If the checkpoint contains a ``'state_dict'`` entry, the weights are
    loaded into ``model`` (non-strictly); otherwise the object stored under
    ``'model'`` replaces ``model`` entirely.

    Args:
        checkpoint_path: path of the checkpoint file passed to ``torch.load``.
        ema: when true, prefer the ``'ema_state_dict'`` weights if present.
        model: module that receives the weights.

    Returns:
        ``model`` with weights loaded, or the module stored in the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint_data = torch.load(checkpoint_path, map_location=device)
    status = ''

    if 'state_dict' in checkpoint_data:
        sd = checkpoint_data['state_dict']
        if ema:
            if 'ema_state_dict' in checkpoint_data:
                sd = checkpoint_data['ema_state_dict']
                status += ' (EMA)'
            else:
                # Report the file path; printing the checkpoint object itself
                # would dump every stored tensor to the console.
                print(f'WARNING: EMA weights missing for {checkpoint_path}')

        # Drop the 'module.' component from parameter names when any key
        # starts with it.
        if any(key.startswith('module.') for key in sd):
            sd = {k.replace('module.', ''): v for k, v in sd.items()}
        # strict=False: the result lists missing/unexpected keys and is
        # appended to the status line instead of raising.
        status += ' ' + str(model.load_state_dict(sd, strict=False))
    else:
        model = checkpoint_data['model']

    print(f'Loaded {checkpoint_path}{status}')
    return model
def load_and_setup_model(model_name, parser, checkpoint, amp, device,
                         unk_args=None, forward_is_infer=False, ema=True,
                         jitable=False):
    """Build a model from command-line arguments and optionally load weights.

    Args:
        model_name: model identifier understood by the ``models`` module.
        parser: base argument parser, extended with model-specific options.
        checkpoint: checkpoint path, or ``None`` to keep the initial weights.
        amp: accepted for interface compatibility; half-precision inference
            is disabled in this script (see the warning below).
        device: device the model is created on and moved to.
        unk_args: optional list of unrecognised command-line arguments. It is
            updated in place to keep only the entries the model parser did
            not recognise either.
        forward_is_infer: forwarded to ``models.get_model``.
        ema: prefer EMA weights when loading the checkpoint.
        jitable: forwarded to ``models.get_model``.

    Returns:
        The model in eval mode on ``device``.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    if unk_args is not None:
        # In-place update so the caller's list reflects the filtering.
        unk_args[:] = list(set(unk_args) & set(model_unk_args))

    # Energy conditioning is always enabled, regardless of the command line.
    setattr(model_args, "energy_conditioning", True)
    model_config = models.get_model_config(model_name, model_args)

    model = models.get_model(model_name, model_config, device,
                             forward_is_infer=forward_is_infer,
                             jitable=jitable)

    if checkpoint is not None:
        model = load_model_from_ckpt(checkpoint, ema, model)

    # The previous code reset ``amp`` to False right before testing it, so
    # ``model.half()`` was never reached. That behaviour is kept, but the
    # ignored request is now reported instead of being dropped silently.
    if amp:
        print('WARNING: AMP requested but half-precision inference is '
              'disabled; running in full precision')

    model.eval()
    return model.to(device)
class Synthesizer:
    """Text-to-speech pipeline: FastPitch generator followed by a HiFi-GAN vocoder.

    Constructing an instance parses the command line (see ``parse_args``) and
    loads the generator, the vocoder, a denoiser, the text processor and the
    word-level language identifier.
    """

    def _load_pyt_or_ts_model(self, model_name, ckpt_path, format='pyt'):
        """Load a model through the ``models`` module.

        Args:
            model_name: model identifier understood by ``models``.
            ckpt_path: path of the checkpoint to load.
            format: ``'ts'`` loads through ``models.load_and_setup_ts_model``;
                any other value loads through ``models.load_and_setup_model``.

        Returns:
            ``(model, model_train_setup)``; the setup dict is empty for
            ``'ts'`` models.
        """
        if format == 'ts':
            model = models.load_and_setup_ts_model(model_name, ckpt_path,
                                                   False, device)
            return model, {}

        model, _, model_train_setup = models.load_and_setup_model(
            model_name, self.parser, ckpt_path, False, device,
            unk_args=self.unk_args, forward_is_infer=True, jitable=False)
        return model, model_train_setup

    def __init__(self):
        """Parse command-line options and load all models."""
        parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                         allow_abbrev=False)
        self.parser = parse_args(parser)
        self.args, self.unk_args = self.parser.parse_known_args()

        self.generator = load_and_setup_model(
            'FastPitch', parser, self.args.fastpitch, self.args.amp, device,
            unk_args=self.unk_args, forward_is_infer=True, ema=self.args.ema,
            jitable=False)

        self.hifigan_model = "pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"
        self.vocoder, _ = self._load_pyt_or_ts_model('HiFi-GAN', self.hifigan_model)
        self.denoiser = Denoiser(self.vocoder, device=device)
        # Honour --p-arpabet (default 0.0) instead of a hard-coded 0.0.
        self.tp = TextProcessing(self.args.symbol_set, self.args.text_cleaners,
                                 p_arpabet=self.args.p_arpabet)
        self.lid = WordLid("langid/lang_id_model_q.bin")

    def unsharp_mask(self, img, radius=1, amount=1):
        """Sharpen ``img`` by adding back its difference from a Gaussian blur.

        Args:
            img: array to sharpen.
            radius: sigma passed to ``ndimage.gaussian_filter``.
            amount: weight of the (img - blurred) detail term; 0 returns the
                input values unchanged.

        Returns:
            A new array ``img + amount * (img - blurred)``.
        """
        blurred = ndimage.gaussian_filter(img, radius)
        return img + amount * (img - blurred)

    def _enhance_mel(self, mel_np, clarity=1, tgt_min=-11, tgt_max=1.5):
        """Sharpen, tilt and rescale one mel-spectrogram (bins x frames array)."""
        mel_np = self.unsharp_mask(mel_np, radius=0.5, amount=0.5)
        mel_np = self.unsharp_mask(mel_np, radius=3, amount=.05)

        # Linear offset per bin: rows above index 30 are raised and rows
        # below it lowered, scaled by ``clarity``. Works for any bin count.
        tilt = (np.arange(mel_np.shape[0]) - 30) * clarity * 0.02
        mel_np = mel_np + tilt[:, None].astype(mel_np.dtype)

        lowest = np.min(mel_np)
        span = np.max(mel_np) - lowest
        if span == 0:
            # A constant spectrogram has no range to stretch; rescaling
            # would divide by zero.
            return mel_np
        return (mel_np - lowest) / span * (tgt_max - tgt_min) + tgt_min

    @staticmethod
    def _to_int16_pcm(audio):
        """Peak-normalise a waveform tensor to 95% of full scale as an int16 array."""
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            # Silent audio is left as zeros instead of dividing by zero.
            audio = audio / peak * 0.95 * 32768
        return audio.cpu().numpy().astype('int16')

    def speak(self, text, output_file="/tmp/tmp", spkr=0, lang=0, l_weight=1, s_weight=1, pace=0.95, clarity=1, guess_lang=True):
        """Synthesise ``text`` and write it to ``output_file + ".wav"``.

        Args:
            text: text to speak; a space is added on both sides.
            output_file: output path without the ``.wav`` extension.
            spkr: speaker passed to the generator.
            lang: language passed to the generator when ``guess_lang`` is
                false; otherwise replaced by the language identifier's output.
            l_weight: language weight passed to the generator.
            s_weight: speaker weight passed to the generator.
            pace: pace passed to the generator.
            clarity: strength of the spectral tilt applied when ``SHARPEN``
                is enabled.
            guess_lang: when true, languages are predicted from the text by
                the word-level language identifier.

        Returns:
            The synthesised audio as an int16 NumPy array.
        """
        text = " " + text + " "
        if guess_lang:
            # NOTE(review): earlier code also built a per-entry weight tensor
            # (l_weight, halved where the language differs from the most
            # common one) but never passed it to the generator; it has been
            # dropped. Confirm whether it was meant to replace ``l_weight``.
            lang = torch.tensor(self.lid.get_lang_array(text)).to(device)
        else:
            lang = torch.tensor(lang).to(device)

        text = self.tp.encode_text(text)
        if guess_lang and len(text) != len(lang):
            print("text length not equal to language list length!")
            # Fall back to the first predicted language.
            lang = lang[0]
            # Only index the weight when it is a sequence/tensor; the default
            # is a plain number, which cannot be indexed.
            if not isinstance(l_weight, (int, float)):
                l_weight = l_weight[0]

        text = torch.LongTensor([text]).to(device)

        with torch.no_grad():
            print(s_weight, l_weight)
            mel, mel_lens, *_ = self.generator(
                text, pace, max_duration=15, speaker=spkr, language=lang,
                speaker_weight=s_weight, language_weight=l_weight)

        if SHARPEN:
            mel_np = mel.float().data.cpu().numpy()[0]
            mel_np = self._enhance_mel(mel_np, clarity=clarity)
            mel[0] = torch.from_numpy(mel_np).float().to(device)

        with torch.no_grad():
            y_g_hat = self.vocoder(mel).float()

        # Normalise volume and convert to 16-bit PCM.
        audio = self._to_int16_pcm(y_g_hat.squeeze())
        write(output_file + ".wav", self.args.sampling_rate, audio)
        return audio
if __name__ == '__main__':
    # Interactive loop: read a line of text and synthesise it twice, first
    # with automatic language identification and then with the fixed
    # language ID, both as speaker 14.
    syn = Synthesizer()

    # Report which checkpoints are in use (file extension stripped).
    hifigan_n = syn.hifigan_model.replace(".pt", "")
    fastpitch_n = syn.args.fastpitch.replace(".pt", "")
    print(hifigan_n + " " + fastpitch_n)

    while True:
        try:
            text = input(">")
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the loop without a traceback.
            break
        if not text.strip():
            # Nothing to synthesise on a blank line.
            continue
        # speak() appends ".wav" itself, so the extension is not included
        # here (otherwise the file would be named "/tmp/tmp.wav.wav").
        syn.speak(text, output_file="/tmp/tmp", spkr=14, lang=4)
        syn.speak(text, output_file="/tmp/tmp", spkr=14, lang=4, guess_lang=False)