#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. If no local model is specified, the
required files will be automatically downloaded from HuggingFace.
Usage:
Note: If you have trouble connecting to HuggingFace,
try switching the endpoint to the mirror site:
export HF_ENDPOINT=https://hf-mirror.com
(1) Inference of a single sentence:
python3 -m zipvoice.bin.infer_zipvoice \
--model-name "zipvoice" \
--prompt-wav prompt.wav \
--prompt-text "I am a prompt." \
--text "I am a sentence." \
--res-wav-path result.wav
(2) Inference of a list of sentences:
python3 -m zipvoice.bin.infer_zipvoice \
--model-name "zipvoice" \
--test-list test.tsv \
--res-dir results
`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.
Each line of `test.tsv` is in the format of
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
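For example, a valid line of `test.tsv` could be (tab-separated, with
illustrative placeholder names and texts):
`utt_0001\tThis is the prompt transcription.\tprompts/utt_0001.wav\tText to synthesize.`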
"""
import argparse
import datetime as dt
import json
import os
from typing import Optional
import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import (
EmiliaTokenizer,
EspeakTokenizer,
LibriTTSTokenizer,
SimpleTokenizer,
)
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
PRETRAINED_MODEL = {
"zipvoice": "zipvoice/model.pt",
"zipvoice_distill": "zipvoice_distill/model.pt",
}
TOKEN_FILE = {
"zipvoice": "zipvoice/tokens.txt",
"zipvoice_distill": "zipvoice_distill/tokens.txt",
}
MODEL_CONFIG = {
"zipvoice": "zipvoice/zipvoice_base.json",
"zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
}
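# These mappings locate the official files on the HuggingFace repo. A minimal
# sketch of fetching them with hf_hub_download (imported above); note that this
# script loads local files instead:
#
#     ckpt_path = hf_hub_download(HUGGINGFACE_REPO, PRETRAINED_MODEL["zipvoice"])
#     token_path = hf_hub_download(HUGGINGFACE_REPO, TOKEN_FILE["zipvoice"])
#     config_path = hf_hub_download(HUGGINGFACE_REPO, MODEL_CONFIG["zipvoice"])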
# Use a single PyTorch thread per process to avoid CPU oversubscription.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_vocoder(vocos_local_path: Optional[str] = None):
if vocos_local_path:
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(
f"{vocos_local_path}/pytorch_model.bin",
weights_only=True,
map_location="cpu",
)
vocoder.load_state_dict(state_dict)
else:
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
return vocoder
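# Example usage of get_vocoder (the local path below is hypothetical and must
# contain Vocos' config.yaml and pytorch_model.bin):
#
#     vocoder = get_vocoder()  # downloads charactr/vocos-mel-24khz
#     vocoder = get_vocoder("/path/to/vocos-mel-24khz")  # loads a local copy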
def generate_sentence(
prompt_text: str,
prompt_wav: str,
text: str,
model: torch.nn.Module,
vocoder: torch.nn.Module,
tokenizer: EmiliaTokenizer,
feature_extractor: VocosFbank,
device: torch.device,
num_step: int = 16,
guidance_scale: float = 1.0,
speed: float = 1.0,
t_shift: float = 0.5,
target_rms: float = 0.1,
feat_scale: float = 0.1,
sampling_rate: int = 24000,
):
"""
Generate waveform of a text based on a given prompt
waveform and its transcription.
Args:
save_path (str): Path to save the generated wav.
prompt_text (str): Transcription of the prompt wav.
prompt_wav (str): Path to the prompt wav file.
text (str): Text to be synthesized into a waveform.
model (torch.nn.Module): The model used for generation.
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
feature_extractor (VocosFbank): The feature extractor used to
extract acoustic features.
device (torch.device): The device on which computations are performed.
num_step (int, optional): Number of steps for decoding. Defaults to 16.
guidance_scale (float, optional): Scale for classifier-free guidance.
Defaults to 1.0.
speed (float, optional): Speed control. Defaults to 1.0.
t_shift (float, optional): Time shift. Defaults to 0.5.
target_rms (float, optional): Target RMS for waveform normalization.
Defaults to 0.1.
feat_scale (float, optional): Scale for features.
Defaults to 0.1.
sampling_rate (int, optional): Sampling rate for the waveform.
Defaults to 24000.
Returns:
metrics (dict): Dictionary containing time and real-time
factor metrics for processing.
"""
# Convert text to tokens
tokens = tokenizer.texts_to_token_ids([text])
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
# Load and preprocess prompt wav
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
if prompt_sampling_rate != sampling_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
)
prompt_wav = resampler(prompt_wav)
    # Boost quiet prompts up to the target RMS; the inverse scaling is applied
    # to the generated waveform below.
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms
# Extract features from prompt wav
prompt_features = feature_extractor.extract(
prompt_wav, sampling_rate=sampling_rate
).to(device)
prompt_features = prompt_features.unsqueeze(0) * feat_scale
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
# Start timing
start_t = dt.datetime.now()
# Generate features
(
pred_features,
pred_features_lens,
pred_prompt_features,
pred_prompt_features_lens,
) = model.sample(
tokens=tokens,
prompt_tokens=prompt_tokens,
prompt_features=prompt_features,
prompt_features_lens=prompt_features_lens,
speed=speed,
t_shift=t_shift,
duration="predict",
num_step=num_step,
guidance_scale=guidance_scale,
)
# Postprocess predicted features
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
# Start vocoder processing
start_vocoder_t = dt.datetime.now()
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
# Calculate processing times and real-time factors
t = (dt.datetime.now() - start_t).total_seconds()
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
wav_seconds = wav.shape[-1] / sampling_rate
rtf = t / wav_seconds
rtf_no_vocoder = t_no_vocoder / wav_seconds
rtf_vocoder = t_vocoder / wav_seconds
# metrics = {
# "t": t,
# "t_no_vocoder": t_no_vocoder,
# "t_vocoder": t_vocoder,
# "wav_seconds": wav_seconds,
# "rtf": rtf,
# "rtf_no_vocoder": rtf_no_vocoder,
# "rtf_vocoder": rtf_vocoder,
# }
# Adjust wav volume if necessary
if prompt_rms < target_rms:
wav = wav * prompt_rms / target_rms
# torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
# return metrics
return wav.cpu()
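# Note: sampling needs no gradients, so the call can be wrapped in
# torch.inference_mode() to save memory (a minimal sketch, arguments elided):
#
#     with torch.inference_mode():
#         wav = generate_sentence(...)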
model_defaults = {
"zipvoice": {
"num_step": 16,
"guidance_scale": 1.0,
},
"zipvoice_distill": {
"num_step": 8,
"guidance_scale": 3.0,
},
}
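# These per-model defaults are not applied automatically below; a hypothetical
# way to use them for a chosen model:
#
#     defaults = model_defaults["zipvoice"]
#     wav = generate_sentence(
#         ..., num_step=defaults["num_step"], guidance_scale=defaults["guidance_scale"]
#     )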
# Prefer the first GPU when available; fall back to CPU otherwise.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print("Loading model...")
model_config = "config.json"
with open(model_config, "r") as f:
model_config = json.load(f)
token_file = "tokens.txt"
tokenizer = EspeakTokenizer(token_file=token_file, lang="vi")
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
model_ckpt = "iter-96000-avg-2.pt"
model = ZipVoice(
**model_config["model"],
**tokenizer_config,
)
load_checkpoint(filename=model_ckpt, model=model, strict=True)
model = model.to(device)
model.eval()
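# Load the Vocos vocoder; get_vocoder downloads "charactr/vocos-mel-24khz"
# from HuggingFace when no local path is given.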
vocoder = get_vocoder(None)
vocoder = vocoder.to(device)
vocoder.eval()
if model_config["feature"]["type"] == "vocos":
feature_extractor = VocosFbank()
else:
raise NotImplementedError(
f"Unsupported feature type: {model_config['feature']['type']}"
)
sampling_rate = model_config["feature"]["sampling_rate"]
# Example inference call (the prompt path and texts below are illustrative
# placeholders; since generate_sentence returns the waveform, save it with
# torchaudio afterwards):
#
# wav = generate_sentence(
#     prompt_text="Transcription of the prompt audio.",
#     prompt_wav="prompt.wav",
#     text="Text to be synthesized.",
#     model=model,
#     vocoder=vocoder,
#     tokenizer=tokenizer,
#     feature_extractor=feature_extractor,
#     device=device,
#     num_step=16,
#     guidance_scale=1.0,
#     speed=1.0,
#     t_shift=0.5,
#     target_rms=0.1,
#     feat_scale=0.1,
#     sampling_rate=sampling_rate,
# )
# torchaudio.save("result.wav", wav, sample_rate=sampling_rate)
# print("Done")