""" |
|
This script generates speech with our pre-trained ZipVoice-Dialog or |
|
ZipVoice-Dialog-Stereo models. If no local model is specified, |
|
required files will be downloaded automatically from HuggingFace.
|
|
|
Usage: |
|
|
|
Note: If you have trouble connecting to HuggingFace,
|
try switching the endpoint to a mirror site:
|
export HF_ENDPOINT=https://hf-mirror.com |
|
|
|
python3 -m zipvoice.bin.infer_zipvoice_dialog \ |
|
--model-name "zipvoice_dialog" \ |
|
--test-list test.tsv \ |
|
--res-dir results |
|
|
|
`--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`, |
|
which generate mono and stereo dialogues, respectively. |
|
|
|
Each line of `test.tsv` is in the format of a merged conversation:
|
'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' |
|
or a split conversation:
|
'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription} |
|
\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}' |
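
For example, a merged-conversation line could look like this
(tab-separated; the file names, paths, and text are illustrative):

'dialog_1\t[S1]Hi.[S2]Hello.\tprompts/dialog.wav\t[S1]How are you?[S2]Great.'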
|
""" |
|
|
|
import argparse |
|
import datetime as dt |
|
import json |
|
import os |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import safetensors.torch |
|
import torch |
|
import torchaudio |
|
from huggingface_hub import hf_hub_download |
|
from lhotse.utils import fix_random_seed |
|
from vocos import Vocos |
|
|
|
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo |
|
from zipvoice.tokenizer.tokenizer import DialogTokenizer |
|
from zipvoice.utils.checkpoint import load_checkpoint |
|
from zipvoice.utils.common import AttributeDict |
|
from zipvoice.utils.feature import VocosFbank |
|
|
|
HUGGINGFACE_REPO = "k2-fsa/ZipVoice" |
|
PRETRAINED_MODEL = { |
|
"zipvoice_dialog": "zipvoice_dialog/model.pt", |
|
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.pt", |
|
} |
|
TOKEN_FILE = { |
|
"zipvoice_dialog": "zipvoice_dialog/tokens.txt", |
|
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/tokens.txt", |
|
} |
|
MODEL_CONFIG = { |
|
"zipvoice_dialog": "zipvoice_dialog/model.json", |
|
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.json", |
|
} |
|
|
|
|
|
def get_parser(): |
|
parser = argparse.ArgumentParser( |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter |
|
) |
|
|
|
parser.add_argument( |
|
"--model-name", |
|
type=str, |
|
default="zipvoice_dialog", |
|
choices=["zipvoice_dialog", "zipvoice_dialog_stereo"], |
|
help="The model used for inference", |
|
) |
|
|
|
parser.add_argument( |
|
"--checkpoint", |
|
type=str, |
|
default=None, |
|
help="The model checkpoint. " |
|
"Will download pre-trained checkpoint from huggingface if not specified.", |
|
) |
|
|
|
parser.add_argument( |
|
"--model-config", |
|
type=str, |
|
default=None, |
|
help="The model configuration file. " |
|
"Will download model.json from huggingface if not specified.", |
|
) |
|
|
|
parser.add_argument( |
|
"--vocoder-path", |
|
type=str, |
|
default=None, |
|
help="The vocoder checkpoint. " |
|
"Will download pre-trained vocoder from huggingface if not specified.", |
|
) |
|
|
|
parser.add_argument( |
|
"--token-file", |
|
type=str, |
|
default=None, |
|
        help="The file that maps tokens to ids, "

        "which is a text file with '{token}\t{token_id}' per line. "

        "Will download tokens.txt from huggingface if not specified.",
|
) |
|
|
|
parser.add_argument( |
|
"--test-list", |
|
type=str, |
|
default=None, |
|
        help="The list of prompt speech, prompt transcription, "

        "and text to synthesize, in the format of a merged conversation: "

        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "

        "or a split conversation: "
|
"'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}" |
|
"\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.", |
|
) |
|
|
|
parser.add_argument( |
|
"--res-dir", |
|
type=str, |
|
default="results", |
|
        help="""

        Directory to save the generated wavs,

        used when --test-list is provided.

        """,
|
) |
|
|
|
parser.add_argument( |
|
"--guidance-scale", |
|
type=float, |
|
default=1.5, |
|
help="The scale of classifier-free guidance during inference.", |
|
) |
|
|
|
parser.add_argument( |
|
"--num-step", |
|
type=int, |
|
default=16, |
|
help="The number of sampling steps.", |
|
) |
|
|
|
parser.add_argument( |
|
"--feat-scale", |
|
type=float, |
|
default=0.1, |
|
        help="The scale factor of the fbank features.",
|
) |
|
|
|
parser.add_argument( |
|
"--speed", |
|
type=float, |
|
default=1.0, |
|
        help="Controls speech speed: 1.0 is normal, >1.0 speeds up.",
|
) |
|
|
|
parser.add_argument( |
|
"--t-shift", |
|
type=float, |
|
default=0.5, |
|
        help="Shift t towards smaller values when t_shift < 1.0.",
|
) |
|
|
|
parser.add_argument( |
|
"--target-rms", |
|
type=float, |
|
default=0.1, |
|
        help="Target RMS value for speech normalization; set to 0 to disable.",
|
) |
|
|
|
parser.add_argument( |
|
"--seed", |
|
type=int, |
|
default=666, |
|
help="Random seed", |
|
) |
|
|
|
parser.add_argument( |
|
"--silence-wav", |
|
type=str, |
|
default="assets/silence.wav", |
|
help="Path of the silence wav file, used in two-channel generation " |
|
"with single-channel prompts", |
|
) |
|
|
|
return parser |
|
|
|
|
|
def get_vocoder(vocos_local_path: Optional[str] = None): |
|
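    # A local Vocos directory is assumed to contain config.yaml and
    # pytorch_model.bin, mirroring the charactr/vocos-mel-24khz HF repo layout.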
if vocos_local_path: |
|
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") |
|
state_dict = torch.load( |
|
f"{vocos_local_path}/pytorch_model.bin", |
|
weights_only=True, |
|
map_location="cpu", |
|
) |
|
vocoder.load_state_dict(state_dict) |
|
else: |
|
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") |
|
return vocoder |
|
|
|
|
|
def generate_sentence( |
|
save_path: str, |
|
prompt_text: str, |
|
prompt_wav: Union[str, List[str]], |
|
text: str, |
|
model: torch.nn.Module, |
|
vocoder: torch.nn.Module, |
|
tokenizer: DialogTokenizer, |
|
feature_extractor: VocosFbank, |
|
device: torch.device, |
|
num_step: int = 16, |
|
guidance_scale: float = 1.0, |
|
speed: float = 1.0, |
|
t_shift: float = 0.5, |
|
target_rms: float = 0.1, |
|
feat_scale: float = 0.1, |
|
sampling_rate: int = 24000, |
|
): |
|
""" |
|
    Generate the waveform of a text based on a given prompt
|
waveform and its transcription. |
|
|
|
Args: |
|
save_path (str): Path to save the generated wav. |
|
prompt_text (str): Transcription of the prompt wav. |
|
        prompt_wav (Union[str, List[str]]): Path(s) to the prompt wav file(s):

            either one wav file containing the merged conversation or two

            wav files, one for each speaker.
|
text (str): Text to be synthesized into a waveform. |
|
model (torch.nn.Module): The model used for generation. |
|
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms. |
|
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens. |
|
feature_extractor (VocosFbank): The feature extractor used to |
|
extract acoustic features. |
|
device (torch.device): The device on which computations are performed. |
|
num_step (int, optional): Number of steps for decoding. Defaults to 16. |
|
guidance_scale (float, optional): Scale for classifier-free guidance. |
|
Defaults to 1.0. |
|
speed (float, optional): Speed control. Defaults to 1.0. |
|
t_shift (float, optional): Time shift. Defaults to 0.5. |
|
target_rms (float, optional): Target RMS for waveform normalization. |
|
Defaults to 0.1. |
|
feat_scale (float, optional): Scale for features. |
|
Defaults to 0.1. |
|
sampling_rate (int, optional): Sampling rate for the waveform. |
|
Defaults to 24000. |
|
Returns: |
|
metrics (dict): Dictionary containing time and real-time |
|
factor metrics for processing. |
|
""" |
|
|
|
tokens = tokenizer.texts_to_token_ids([text]) |
|
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) |
|
|
|
|
|
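    # Accept either one merged-conversation prompt wav or a list of two
    # per-speaker prompt wavs.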
if isinstance(prompt_wav, str): |
|
prompt_wav = [ |
|
prompt_wav, |
|
] |
|
else: |
|
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str) |
|
|
|
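    # Load each prompt wav (replacing the path in the list with the waveform)
    # and resample it to the model's sampling rate if needed.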
loaded_prompt_wavs = prompt_wav |
|
for i in range(len(prompt_wav)): |
|
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i]) |
|
if prompt_sampling_rate != sampling_rate: |
|
resampler = torchaudio.transforms.Resample( |
|
orig_freq=prompt_sampling_rate, new_freq=sampling_rate |
|
) |
|
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i]) |
|
|
|
if len(loaded_prompt_wavs) == 1: |
|
prompt_wav = loaded_prompt_wavs[0] |
|
else: |
|
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1) |
|
|
|
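    # Boost quiet prompts up to target_rms; the generated wav is scaled back
    # to the prompt's original loudness below.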
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) |
|
if prompt_rms < target_rms: |
|
prompt_wav = prompt_wav * target_rms / prompt_rms |
|
|
|
|
|
prompt_features = feature_extractor.extract( |
|
prompt_wav, sampling_rate=sampling_rate |
|
).to(device) |
|
|
|
prompt_features = prompt_features.unsqueeze(0) * feat_scale |
|
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) |
|
|
|
|
|
start_t = dt.datetime.now() |
|
|
|
|
|
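    # Run flow-matching sampling with classifier-free guidance;
    # duration="predict" lets the model infer the output length.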
( |
|
pred_features, |
|
pred_features_lens, |
|
pred_prompt_features, |
|
pred_prompt_features_lens, |
|
) = model.sample( |
|
tokens=tokens, |
|
prompt_tokens=prompt_tokens, |
|
prompt_features=prompt_features, |
|
prompt_features_lens=prompt_features_lens, |
|
speed=speed, |
|
t_shift=t_shift, |
|
duration="predict", |
|
num_step=num_step, |
|
guidance_scale=guidance_scale, |
|
) |
|
|
|
|
|
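    # Undo the feature scaling and reshape to (batch, feat_dim, time),
    # the layout expected by the vocoder.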
pred_features = pred_features.permute(0, 2, 1) / feat_scale |
|
|
|
|
|
start_vocoder_t = dt.datetime.now() |
|
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) |
|
|
|
|
|
t = (dt.datetime.now() - start_t).total_seconds() |
|
t_no_vocoder = (start_vocoder_t - start_t).total_seconds() |
|
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() |
|
wav_seconds = wav.shape[-1] / sampling_rate |
|
rtf = t / wav_seconds |
|
rtf_no_vocoder = t_no_vocoder / wav_seconds |
|
rtf_vocoder = t_vocoder / wav_seconds |
|
metrics = { |
|
"t": t, |
|
"t_no_vocoder": t_no_vocoder, |
|
"t_vocoder": t_vocoder, |
|
"wav_seconds": wav_seconds, |
|
"rtf": rtf, |
|
"rtf_no_vocoder": rtf_no_vocoder, |
|
"rtf_vocoder": rtf_vocoder, |
|
} |
|
|
|
|
|
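    # Restore the prompt's original loudness if it was normalized above.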
if prompt_rms < target_rms: |
|
wav = wav * prompt_rms / target_rms |
|
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate) |
|
|
|
return metrics |
|
|
|
|
|
def generate_sentence_stereo( |
|
save_path: str, |
|
prompt_text: str, |
|
prompt_wav: Union[str, List[str]], |
|
text: str, |
|
model: torch.nn.Module, |
|
vocoder: torch.nn.Module, |
|
tokenizer: DialogTokenizer, |
|
feature_extractor: VocosFbank, |
|
device: torch.device, |
|
num_step: int = 16, |
|
guidance_scale: float = 1.0, |
|
speed: float = 1.0, |
|
t_shift: float = 0.5, |
|
target_rms: float = 0.1, |
|
feat_scale: float = 0.1, |
|
sampling_rate: int = 24000, |
|
silence_wav: Optional[str] = None, |
|
): |
|
""" |
|
    Generate the waveform of a text based on a given prompt
|
waveform and its transcription. |
|
|
|
Args: |
|
save_path (str): Path to save the generated wav. |
|
prompt_text (str): Transcription of the prompt wav. |
|
        prompt_wav (Union[str, List[str]]): Path(s) to the prompt wav file(s):

            either one wav file containing the merged conversation or two

            wav files, one for each speaker.
|
text (str): Text to be synthesized into a waveform. |
|
model (torch.nn.Module): The model used for generation. |
|
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms. |
|
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens. |
|
feature_extractor (VocosFbank): The feature extractor used to |
|
extract acoustic features. |
|
device (torch.device): The device on which computations are performed. |
|
num_step (int, optional): Number of steps for decoding. Defaults to 16. |
|
guidance_scale (float, optional): Scale for classifier-free guidance. |
|
Defaults to 1.0. |
|
speed (float, optional): Speed control. Defaults to 1.0. |
|
t_shift (float, optional): Time shift. Defaults to 0.5. |
|
target_rms (float, optional): Target RMS for waveform normalization. |
|
Defaults to 0.1. |
|
feat_scale (float, optional): Scale for features. |
|
Defaults to 0.1. |
|
sampling_rate (int, optional): Sampling rate for the waveform. |
|
Defaults to 24000. |
|
        silence_wav (str, optional): Path of the silence wav file, used in

            two-channel generation with single-channel prompts.
|
Returns: |
|
metrics (dict): Dictionary containing time and real-time |
|
factor metrics for processing. |
|
""" |
|
|
|
tokens = tokenizer.texts_to_token_ids([text]) |
|
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) |
|
|
|
|
|
if isinstance(prompt_wav, str): |
|
prompt_wav = [ |
|
prompt_wav, |
|
] |
|
else: |
|
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str) |
|
|
|
loaded_prompt_wavs = prompt_wav |
|
for i in range(len(prompt_wav)): |
|
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i]) |
|
if prompt_sampling_rate != sampling_rate: |
|
resampler = torchaudio.transforms.Resample( |
|
orig_freq=prompt_sampling_rate, new_freq=sampling_rate |
|
) |
|
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i]) |
|
|
|
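    # Assemble a stereo prompt: a merged prompt must already be two-channel;
    # two stereo prompts are concatenated in time; two mono prompts are placed
    # on opposite channels over a silence bed.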
if len(loaded_prompt_wavs) == 1: |
|
assert ( |
|
loaded_prompt_wavs[0].size(0) == 2 |
|
), "Merged prompt wav must be stereo for stereo dialogue generation" |
|
prompt_wav = loaded_prompt_wavs[0] |
|
|
|
else: |
|
assert len(loaded_prompt_wavs) == 2 |
|
if loaded_prompt_wavs[0].size(0) == 2: |
|
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1) |
|
else: |
|
assert loaded_prompt_wavs[0].size(0) == 1 |
|
silence_wav, silence_sampling_rate = torchaudio.load(silence_wav) |
|
assert silence_sampling_rate == sampling_rate |
|
prompt_wav = silence_wav[ |
|
:, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1) |
|
] |
|
prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0] |
|
prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1] |
|
|
|
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) |
|
if prompt_rms < target_rms: |
|
prompt_wav = prompt_wav * target_rms / prompt_rms |
|
|
|
|
|
prompt_features = feature_extractor.extract( |
|
prompt_wav, sampling_rate=sampling_rate |
|
).to(device) |
|
|
|
prompt_features = prompt_features.unsqueeze(0) * feat_scale |
|
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) |
|
|
|
|
|
start_t = dt.datetime.now() |
|
|
|
|
|
( |
|
pred_features, |
|
pred_features_lens, |
|
pred_prompt_features, |
|
pred_prompt_features_lens, |
|
) = model.sample( |
|
tokens=tokens, |
|
prompt_tokens=prompt_tokens, |
|
prompt_features=prompt_features, |
|
prompt_features_lens=prompt_features_lens, |
|
speed=speed, |
|
t_shift=t_shift, |
|
duration="predict", |
|
num_step=num_step, |
|
guidance_scale=guidance_scale, |
|
) |
|
|
|
|
|
pred_features = pred_features.permute(0, 2, 1) / feat_scale |
|
|
|
|
|
start_vocoder_t = dt.datetime.now() |
|
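    # The stereo model stacks the two channels' fbank features along the
    # feature axis; decode each half separately, then stack them as the
    # left/right output channels.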
feat_dim = pred_features.size(1) // 2 |
|
wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1) |
|
wav_right = ( |
|
vocoder.decode(pred_features[:, feat_dim : feat_dim * 2]) |
|
.squeeze(1) |
|
.clamp(-1, 1) |
|
) |
|
|
|
wav = torch.cat([wav_left, wav_right], dim=0) |
|
|
|
|
|
t = (dt.datetime.now() - start_t).total_seconds() |
|
t_no_vocoder = (start_vocoder_t - start_t).total_seconds() |
|
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() |
|
wav_seconds = wav.shape[-1] / sampling_rate |
|
rtf = t / wav_seconds |
|
rtf_no_vocoder = t_no_vocoder / wav_seconds |
|
rtf_vocoder = t_vocoder / wav_seconds |
|
metrics = { |
|
"t": t, |
|
"t_no_vocoder": t_no_vocoder, |
|
"t_vocoder": t_vocoder, |
|
"wav_seconds": wav_seconds, |
|
"rtf": rtf, |
|
"rtf_no_vocoder": rtf_no_vocoder, |
|
"rtf_vocoder": rtf_vocoder, |
|
} |
|
|
|
|
|
if prompt_rms < target_rms: |
|
wav = wav * prompt_rms / target_rms |
|
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate) |
|
|
|
return metrics |
|
|
|
|
|
def generate_list( |
|
model_name: str, |
|
res_dir: str, |
|
test_list: str, |
|
model: torch.nn.Module, |
|
vocoder: torch.nn.Module, |
|
tokenizer: DialogTokenizer, |
|
feature_extractor: VocosFbank, |
|
device: torch.device, |
|
num_step: int = 16, |
|
guidance_scale: float = 1.5, |
|
speed: float = 1.0, |
|
t_shift: float = 0.5, |
|
target_rms: float = 0.1, |
|
feat_scale: float = 0.1, |
|
sampling_rate: int = 24000, |
|
silence_wav: Optional[str] = None, |
|
): |
|
total_t = [] |
|
total_t_no_vocoder = [] |
|
total_t_vocoder = [] |
|
total_wav_seconds = [] |
|
|
|
with open(test_list, "r") as fr: |
|
lines = fr.readlines() |
|
|
|
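    # Each test-list line is tab-separated: 6 fields describe a split
    # conversation (per-speaker prompts), 4 fields a merged conversation.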
for i, line in enumerate(lines): |
|
items = line.strip().split("\t") |
|
if len(items) == 6: |
|
( |
|
wav_name, |
|
prompt_text_1, |
|
prompt_text_2, |
|
prompt_wav_1, |
|
prompt_wav_2, |
|
text, |
|
) = items |
|
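            # Join the per-speaker prompt transcriptions with the [S1]/[S2]
            # speaker tags expected by the dialogue tokenizer.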
prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}" |
|
prompt_wav = [prompt_wav_1, prompt_wav_2] |
|
elif len(items) == 4: |
|
wav_name, prompt_text, prompt_wav, text = items |
|
else: |
|
raise ValueError(f"Invalid line: {line}") |
|
assert text.startswith("[S1]") |
|
|
|
save_path = f"{res_dir}/{wav_name}.wav" |
|
|
|
if model_name == "zipvoice_dialog": |
|
|
|
metrics = generate_sentence( |
|
save_path=save_path, |
|
prompt_text=prompt_text, |
|
prompt_wav=prompt_wav, |
|
text=text, |
|
model=model, |
|
vocoder=vocoder, |
|
tokenizer=tokenizer, |
|
feature_extractor=feature_extractor, |
|
device=device, |
|
num_step=num_step, |
|
guidance_scale=guidance_scale, |
|
speed=speed, |
|
t_shift=t_shift, |
|
target_rms=target_rms, |
|
feat_scale=feat_scale, |
|
sampling_rate=sampling_rate, |
|
) |
|
else: |
|
assert model_name == "zipvoice_dialog_stereo" |
|
metrics = generate_sentence_stereo( |
|
save_path=save_path, |
|
prompt_text=prompt_text, |
|
prompt_wav=prompt_wav, |
|
text=text, |
|
model=model, |
|
vocoder=vocoder, |
|
tokenizer=tokenizer, |
|
feature_extractor=feature_extractor, |
|
device=device, |
|
num_step=num_step, |
|
guidance_scale=guidance_scale, |
|
speed=speed, |
|
t_shift=t_shift, |
|
target_rms=target_rms, |
|
feat_scale=feat_scale, |
|
sampling_rate=sampling_rate, |
|
silence_wav=silence_wav, |
|
) |
|
|
|
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") |
|
total_t.append(metrics["t"]) |
|
total_t_no_vocoder.append(metrics["t_no_vocoder"]) |
|
total_t_vocoder.append(metrics["t_vocoder"]) |
|
total_wav_seconds.append(metrics["wav_seconds"]) |
|
|
|
print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}") |
|
print( |
|
f"Average RTF w/o vocoder: " |
|
f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}" |
|
) |
|
print( |
|
f"Average RTF vocoder: " |
|
f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}" |
|
) |
|
|
|
|
|
@torch.inference_mode() |
|
def main(): |
|
parser = get_parser() |
|
args = parser.parse_args() |
|
|
|
params = AttributeDict() |
|
params.update(vars(args)) |
|
fix_random_seed(params.seed) |
|
|
|
assert ( |
|
params.test_list is not None |
|
), "For inference, please provide prompts and text with '--test-list'" |
|
|
|
if torch.cuda.is_available(): |
|
params.device = torch.device("cuda", 0) |
|
elif torch.backends.mps.is_available(): |
|
params.device = torch.device("mps") |
|
else: |
|
params.device = torch.device("cpu") |
|
|
|
print("Loading model...") |
|
if params.model_config is None: |
|
model_config = hf_hub_download( |
|
HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name] |
|
) |
|
else: |
|
model_config = params.model_config |
|
|
|
with open(model_config, "r") as f: |
|
model_config = json.load(f) |
|
|
|
if params.token_file is None: |
|
token_file = hf_hub_download( |
|
HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name] |
|
) |
|
else: |
|
token_file = params.token_file |
|
|
|
tokenizer = DialogTokenizer(token_file=token_file) |
|
|
|
tokenizer_config = { |
|
"vocab_size": tokenizer.vocab_size, |
|
"pad_id": tokenizer.pad_id, |
|
"spk_a_id": tokenizer.spk_a_id, |
|
"spk_b_id": tokenizer.spk_b_id, |
|
} |
|
if params.checkpoint is None: |
|
model_ckpt = hf_hub_download( |
|
HUGGINGFACE_REPO, |
|
filename=PRETRAINED_MODEL[params.model_name], |
|
) |
|
else: |
|
model_ckpt = params.checkpoint |
|
|
|
if params.model_name == "zipvoice_dialog": |
|
model = ZipVoiceDialog( |
|
**model_config["model"], |
|
**tokenizer_config, |
|
) |
|
else: |
|
assert params.model_name == "zipvoice_dialog_stereo" |
|
model = ZipVoiceDialogStereo( |
|
**model_config["model"], |
|
**tokenizer_config, |
|
) |
|
|
|
if model_ckpt.endswith(".safetensors"): |
|
safetensors.torch.load_model(model, model_ckpt) |
|
elif model_ckpt.endswith(".pt"): |
|
load_checkpoint(filename=model_ckpt, model=model, strict=True) |
|
else: |
|
raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}") |
|
|
|
model = model.to(params.device) |
|
model.eval() |
|
|
|
vocoder = get_vocoder(params.vocoder_path) |
|
vocoder = vocoder.to(params.device) |
|
vocoder.eval() |
|
|
|
if model_config["feature"]["type"] == "vocos": |
|
if params.model_name == "zipvoice_dialog": |
|
num_channels = 1 |
|
else: |
|
assert params.model_name == "zipvoice_dialog_stereo" |
|
num_channels = 2 |
|
feature_extractor = VocosFbank(num_channels=num_channels) |
|
else: |
|
raise NotImplementedError( |
|
f"Unsupported feature type: {model_config['feature']['type']}" |
|
) |
|
params.sampling_rate = model_config["feature"]["sampling_rate"] |
|
|
|
print("Start generating...") |
|
os.makedirs(params.res_dir, exist_ok=True) |
|
generate_list( |
|
model_name=params.model_name, |
|
res_dir=params.res_dir, |
|
test_list=params.test_list, |
|
model=model, |
|
vocoder=vocoder, |
|
tokenizer=tokenizer, |
|
feature_extractor=feature_extractor, |
|
device=params.device, |
|
num_step=params.num_step, |
|
guidance_scale=params.guidance_scale, |
|
speed=params.speed, |
|
t_shift=params.t_shift, |
|
target_rms=params.target_rms, |
|
feat_scale=params.feat_scale, |
|
sampling_rate=params.sampling_rate, |
|
silence_wav=params.silence_wav, |
|
) |
|
print("Done") |
|
|
|
|
|
if __name__ == "__main__": |
|
torch.set_num_threads(1) |
|
torch.set_num_interop_threads(1) |
|
main() |
|
|