|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. If no local model is specified,
the required files will be automatically downloaded from HuggingFace.

Usage:

Note: If you are having trouble connecting to HuggingFace,
    try switching the endpoint to a mirror site:
    export HF_ENDPOINT=https://hf-mirror.com

(1) Inference of a single sentence:

    python3 -m zipvoice.bin.infer_zipvoice \
        --model-name "zipvoice" \
        --prompt-wav prompt.wav \
        --prompt-text "I am a prompt." \
        --text "I am a sentence." \
        --res-wav-path result.wav

(2) Inference of a list of sentences:

    python3 -m zipvoice.bin.infer_zipvoice \
        --model-name "zipvoice" \
        --test-list test.tsv \
        --res-dir results

`--model-name` can be `zipvoice` or `zipvoice_distill`,
    which are the models before and after distillation, respectively.

Each line of `test.tsv` is in the format of
    `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
|
""" |
|
|
|
import argparse |
|
import datetime as dt |
|
import json |
|
import os |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import safetensors.torch |
|
import torch |
|
import torchaudio |
|
from huggingface_hub import hf_hub_download |
|
from lhotse.utils import fix_random_seed |
|
from vocos import Vocos |
|
|
|
from zipvoice.models.zipvoice import ZipVoice |
|
from zipvoice.models.zipvoice_distill import ZipVoiceDistill |
|
from zipvoice.tokenizer.tokenizer import ( |
|
EmiliaTokenizer, |
|
EspeakTokenizer, |
|
LibriTTSTokenizer, |
|
SimpleTokenizer, |
|
) |
|
from zipvoice.utils.checkpoint import load_checkpoint |
|
from zipvoice.utils.common import AttributeDict |
|
from zipvoice.utils.feature import VocosFbank |
|
|
|
# HuggingFace repository name used for the pre-trained assets.
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"

# Per-model relative paths, keyed by model name
# ("zipvoice" / "zipvoice_distill").
PRETRAINED_MODEL = dict(
    zipvoice="zipvoice/model.pt",
    zipvoice_distill="zipvoice_distill/model.pt",
)
TOKEN_FILE = dict(
    zipvoice="zipvoice/tokens.txt",
    zipvoice_distill="zipvoice_distill/tokens.txt",
)
MODEL_CONFIG = dict(
    zipvoice="zipvoice/zipvoice_base.json",
    zipvoice_distill="zipvoice_distill/zipvoice_base.json",
)
|
|
|
# Restrict PyTorch to a single intra-op thread and a single inter-op thread.
_NUM_TORCH_THREADS = 1
torch.set_num_threads(_NUM_TORCH_THREADS)
torch.set_num_interop_threads(_NUM_TORCH_THREADS)
|
|
|
def get_vocoder(vocos_local_path: Optional[str] = None):
    """
    Build the Vocos vocoder.

    Args:
        vocos_local_path (Optional[str]): Directory holding `config.yaml`
            and `pytorch_model.bin`. When empty or None, the
            `charactr/vocos-mel-24khz` model is obtained through
            `Vocos.from_pretrained` instead.
    Returns:
        The vocoder instance (weights loaded onto CPU when read from
        the local directory).
    """
    # No local directory given: use the published pre-trained vocoder.
    if not vocos_local_path:
        return Vocos.from_pretrained("charactr/vocos-mel-24khz")

    # Local directory given: build from its hparams, then load the weights.
    local_vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    weights = torch.load(
        f"{vocos_local_path}/pytorch_model.bin",
        map_location="cpu",
        weights_only=True,
    )
    local_vocoder.load_state_dict(weights)
    return local_vocoder
|
|
|
|
|
def generate_sentence(
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Prompts quieter than this are boosted to it before feature
            extraction. Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform; the
            prompt is resampled to it when its own rate differs.
            Defaults to 24000.
    Returns:
        torch.Tensor: The generated waveform, moved to CPU. It is clamped
            to [-1, 1] and, if the prompt was boosted, scaled back down
            by the same factor.
    """
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Keep the loaded audio in its own name rather than rebinding the
    # `prompt_wav` path parameter to a tensor.
    prompt_audio, prompt_sampling_rate = torchaudio.load(prompt_wav)

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_audio = resampler(prompt_audio)

    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_audio)))
    # Boost quiet prompts up to target_rms. A completely silent prompt
    # (rms == 0) is left untouched: dividing by zero would turn the prompt
    # into NaNs and poison everything computed from it.
    boost_prompt = bool(prompt_rms > 0) and bool(prompt_rms < target_rms)
    if boost_prompt:
        prompt_audio = prompt_audio * target_rms / prompt_rms

    prompt_features = feature_extractor.extract(
        prompt_audio, sampling_rate=sampling_rate
    ).to(device)

    # Add a leading batch dimension and apply the feature scale.
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Pure inference: disable autograd so no graph is built or kept alive
    # by the returned waveform.
    with torch.no_grad():
        # Only the predicted features are needed; the other three returned
        # values (lengths and prompt features) are unused here.
        (
            pred_features,
            _pred_features_lens,
            _pred_prompt_features,
            _pred_prompt_features_lens,
        ) = model.sample(
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            prompt_features=prompt_features,
            prompt_features_lens=prompt_features_lens,
            speed=speed,
            t_shift=t_shift,
            duration="predict",
            num_step=num_step,
            guidance_scale=guidance_scale,
        )

        # Swap the last two axes for the vocoder and undo the feature scale.
        pred_features = pred_features.permute(0, 2, 1) / feat_scale

        wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Undo the prompt boost so the output loudness follows the original
    # prompt loudness.
    if boost_prompt:
        wav = wav * prompt_rms / target_rms

    return wav.cpu()
|
|
|
# Default sampling settings (num_step, guidance_scale) per model name.
model_defaults = dict(
    zipvoice=dict(num_step=16, guidance_scale=1.0),
    zipvoice_distill=dict(num_step=8, guidance_scale=3.0),
)
|
|
|
# Use the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the script does not crash on machines without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")

print("Loading model...")
model_config = "config.json"

# `model_config` is rebound from the config path to the parsed JSON content.
with open(model_config, "r", encoding="utf-8") as f:
    model_config = json.load(f)

token_file = "tokens.txt"

tokenizer = EspeakTokenizer(token_file=token_file, lang="vi")

# Vocabulary size and padding id are taken from the tokenizer and passed to
# the model constructor alongside the "model" section of the config.
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

model_ckpt = "iter-96000-avg-2.pt"

model = ZipVoice(
    **model_config["model"],
    **tokenizer_config,
)

load_checkpoint(filename=model_ckpt, model=model, strict=True)

model = model.to(device)
model.eval()

# No local path is given, so get_vocoder uses its pre-trained default.
vocoder = get_vocoder(None)
vocoder = vocoder.to(device)
vocoder.eval()

# Only the "vocos" feature type is supported by this script.
if model_config["feature"]["type"] == "vocos":
    feature_extractor = VocosFbank()
else:
    raise NotImplementedError(
        f"Unsupported feature type: {model_config['feature']['type']}"
    )
sampling_rate = model_config["feature"]["sampling_rate"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|