# -*- coding: utf-8 -*-
"""melotts training.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1srmto1Bf7xQl7la1-5cTZOvbTnL-KWDG
"""

# Fetch `notebook_utils` module
import requests
from pathlib import Path

if not Path("notebook_utils.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
    )
    open("notebook_utils.py", "w").write(r.text)

if not Path("cmd_helper.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/cmd_helper.py",
    )
    open("cmd_helper.py", "w").write(r.text)

if not Path("pip_helper.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/pip_helper.py",
    )
    open("pip_helper.py", "w").write(r.text)
# !!! have to restart session
from pathlib import Path

from cmd_helper import clone_repo
from pip_helper import pip_install

import platform

repo_dir = Path("OpenVoice")

clone_repo("https://github.com/myshell-ai/OpenVoice")

orig_english_path = Path("OpenVoice/openvoice/text/_orig_english.py")
english_path = Path("OpenVoice/openvoice/text/english.py")

if not orig_english_path.exists():
    orig_english_path = Path("OpenVoice/openvoice/text/_orig_english.py")
    english_path = Path("OpenVoice/openvoice/text/english.py")
    english_path.rename(orig_english_path)

    with orig_english_path.open("r") as f:
        data = f.read()

    data = data.replace("unidecode", "anyascii")

    with english_path.open("w") as out_f:
        out_f.write(data)
# fix a problem with silero downloading and installing
with Path("OpenVoice/openvoice/se_extractor.py").open("r") as orig_file:
    data = orig_file.read()

data = data.replace('method="silero"', 'method="silero:3.0"')

with Path("OpenVoice/openvoice/se_extractor.py").open("w") as out_f:
    out_f.write(data)
# clone MeloTTS
clone_repo("https://github.com/myshell-ai/MeloTTS")

pip_install(
    "--no-deps",
    "librosa==0.9.1",
    "pydub==0.25.1",
    "tqdm",
    "inflect==7.0.0",
    "pypinyin==0.50.0",
    "openvino>=2025.0",
)
# Since we don't convert the Japanese models, many heavy Japanese-related pip dependencies have been removed.
# If you want to try them, we recommend a Python 3.10 environment on Ubuntu and uncommenting the relevant
# lines below (see the note after this install block).
pip_install(
    "--extra-index-url",
    "https://download.pytorch.org/whl/cpu",
    # "mecab-python3==1.0.9",
    "nncf",
    "wavmark>=0.0.3",
    "faster-whisper>=0.9.0",
    "eng_to_ipa==0.0.2",
    "cn2an==0.5.22",
    "jieba==0.42.1",
    "langid==1.1.6",
    "ipywebrtc",
    "anyascii==0.3.2",
    "torch>=2.1",
    "torchaudio",
    "cached_path",
    "transformers>=4.38,<5.0",
    "num2words==0.5.12",
    # "unidic_lite==1.0.8",
    # "unidic==1.1.0",
    "pykakasi==2.2.1",
    # "fugashi==1.3.0",
    "g2p_en==2.1.0",
    "jamo==0.4.1",
    "gruut[de,es,fr]==2.2.3",
    "g2pkk>=0.1.1",
    "dtw-python",
    "more-itertools",
    "tiktoken",
    "tensorboard==2.16.2",
    "loguru==0.7.2",
    "nltk",
    "gradio",
)
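# Note on re-enabling Japanese support (an assumption based on the commented-out lines above, not verified
# here): uncomment the Japanese-related packages in the install call above (mecab-python3, unidic_lite,
# unidic, fugashi), skip the MeloTTS Japanese-module patching further below, and download the unidic
# dictionary afterwards, e.g.:
#
#     python -m unidic download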
pip_install("--no-deps", "whisper-timestamped>=1.14.2", "openai-whisper") | |
if platform.system() == "Darwin": | |
pip_install("numpy<2.0") | |
# fix the problem of `module 'botocore.exceptions' has no attribute 'HTTPClientError'` | |
pip_install("--upgrade", "botocore") | |
# donwload nltk data | |
import nltk | |
nltk.download("averaged_perceptron_tagger_eng") | |
# install unidic | |
# !python -m unidic download | |
# remove the Japanese-related module from MeloTTS to fix a dependencies issue
# If you want to use Japanese, please do not modify these files
import re

with Path("MeloTTS/melo/text/english.py").open("r", encoding="utf-8") as orig_file:
    data = orig_file.read()

japanese_import = "from .japanese import distribute_phone"
replacement_function = """
def distribute_phone(n_phone, n_word):
    phones_per_word = [0] * n_word
    for task in range(n_phone):
        min_tasks = min(phones_per_word)
        min_index = phones_per_word.index(min_tasks)
        phones_per_word[min_index] += 1
    return phones_per_word
"""
# replace `from .japanese import distribute_phone` with an inline copy of the function
data = data.replace(japanese_import, replacement_function)

with Path("MeloTTS/melo/text/english.py").open("w", encoding="utf-8") as out_f:
    out_f.write(data)
with Path("MeloTTS/melo/text/__init__.py").open("r", encoding="utf-8") as orig_file: | |
data = orig_file.read() | |
data = data.replace("from .japanese_bert import get_bert_feature as jp_bert", "") | |
data = data.replace("from .spanish_bert import get_bert_feature as sp_bert", "") | |
data = data.replace("from .french_bert import get_bert_feature as fr_bert", "") | |
data = data.replace("from .korean import get_bert_feature as kr_bert", "") | |
# Replace the lang_bert_func_map dictionary, keeping only the keys ZH, EN, and ZH_MIX_EN | |
pattern = re.compile(r"lang_bert_func_map\s*=\s*\{[^}]+\}", re.DOTALL) | |
replacement = """lang_bert_func_map = { | |
"ZH": zh_bert, | |
"EN": en_bert, | |
"ZH_MIX_EN": zh_mix_en_bert, | |
}""" | |
data = pattern.sub(replacement, data) | |
with Path("MeloTTS/melo/text/__init__.py").open("w", encoding="utf-8") as out_f: | |
out_f.write(data) | |
# clean the modules | |
for filename in ["japanese.py", "japanese_bert.py"]: | |
Path(f"MeloTTS/melo/text/{filename}").write_text("", encoding="utf-8") | |
import os

import torch
import openvino as ov
import ipywidgets as widgets
from IPython.display import Audio

from notebook_utils import download_file, device_widget

core = ov.Core()

from openvoice.api import ToneColorConverter, OpenVoiceBaseClass
import openvoice.se_extractor as se_extractor

from melo.api import TTS

CKPT_BASE_PATH = Path("checkpoints")

base_speakers_suffix = CKPT_BASE_PATH / "base_speakers" / "ses"
converter_suffix = CKPT_BASE_PATH / "converter"
melotts_chinese_suffix = CKPT_BASE_PATH / "MeloTTS-Chinese"
melotts_english_suffix = CKPT_BASE_PATH / "MeloTTS-English-v3"
def download_from_hf_hub(repo_id, filename, local_dir="./"):
    from huggingface_hub import hf_hub_download

    local_path = Path(local_dir)
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_path)


# Download OpenVoice2
download_from_hf_hub("myshell-ai/OpenVoiceV2", "converter/checkpoint.pth", CKPT_BASE_PATH)
download_from_hf_hub("myshell-ai/OpenVoiceV2", "converter/config.json", CKPT_BASE_PATH)
download_from_hf_hub("myshell-ai/OpenVoiceV2", "base_speakers/ses/en-newest.pth", CKPT_BASE_PATH)
download_from_hf_hub("myshell-ai/OpenVoiceV2", "base_speakers/ses/zh.pth", CKPT_BASE_PATH)

# Download MeloTTS
download_from_hf_hub("myshell-ai/MeloTTS-Chinese", "checkpoint.pth", melotts_chinese_suffix)
download_from_hf_hub("myshell-ai/MeloTTS-Chinese", "config.json", melotts_chinese_suffix)
download_from_hf_hub("myshell-ai/MeloTTS-English-v3", "checkpoint.pth", melotts_english_suffix)
download_from_hf_hub("myshell-ai/MeloTTS-English-v3", "config.json", melotts_english_suffix)
class OVSynthesizerTTSWrapper(torch.nn.Module):
    """
    Wrapper for the SynthesizerTrn model from MeloTTS to make it compatible with Torch-style inference.
    """

    def __init__(self, model, language):
        super().__init__()
        self.model = model
        self.language = language

    def forward(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale,
        length_scale,
        noise_scale_w,
        sdp_ratio,
    ):
        """
        Forward call to the underlying SynthesizerTrn model: the positional inputs are
        passed straight through to the model's `infer` method.
        """
        return self.model.infer(
            x,
            x_lengths,
            sid,
            tone,
            language,
            bert,
            ja_bert,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )

    def get_example_input(self):
        """
        Return a tuple of example inputs for tracing/ONNX exporting or debugging.
        The SynthesizerTrn model has proven to be very sensitive to the example_input
        used during model conversion, so the inputs below follow simple rules that
        mimic real input data.
        """

        def gen_interleaved_random_tensor(length, value_range):
            """Generate a tensor of the form [0, val, 0, val, ..., 0], val ∈ [low, high)."""
            return torch.tensor(
                [[0 if i % 2 == 0 else torch.randint(*value_range, (1,)).item() for i in range(length)]],
                dtype=torch.int64,
            ).to(pt_device)

        def gen_interleaved_fixed_tensor(length, fixed_value):
            """Generate a tensor of the form [0, val, 0, val, ..., 0] with a fixed val."""
            interleaved = [0 if i % 2 == 0 else fixed_value for i in range(length)]
            return torch.tensor([interleaved], dtype=torch.int64).to(pt_device)

        if self.language == "EN_NEWEST":
            seq_len = 73
            x_tst = gen_interleaved_random_tensor(seq_len, (14, 220))
            x_tst[:3] = 0
            x_tst[-3:] = 0
            x_tst_lengths = torch.tensor([seq_len], dtype=torch.int64).to(pt_device)
            speakers = torch.tensor([0], dtype=torch.int64).to(pt_device)  # this model has a single fixed speaker id
            tones = gen_interleaved_random_tensor(seq_len, (5, 10))
            lang_ids = gen_interleaved_fixed_tensor(seq_len, 2)  # lang_id for English
            bert = torch.randn((1, 1024, seq_len), dtype=torch.float32).to(pt_device)
            ja_bert = torch.randn(1, 768, seq_len, dtype=torch.float32).to(pt_device)
            sdp_ratio = torch.tensor(0.2).to(pt_device)
            noise_scale = torch.tensor(0.6).to(pt_device)
            noise_scale_w = torch.tensor(0.8).to(pt_device)
            length_scale = torch.tensor(1.0).to(pt_device)
        elif self.language == "ZH":
            seq_len = 37
            x_tst = gen_interleaved_random_tensor(seq_len, (7, 100))
            x_tst[:3] = 0
            x_tst[-3:] = 0
            x_tst_lengths = torch.tensor([seq_len], dtype=torch.int64).to(pt_device)
            speakers = torch.tensor([1], dtype=torch.int64).to(pt_device)  # this model has a single fixed speaker id
            tones = gen_interleaved_random_tensor(seq_len, (4, 9))
            lang_ids = gen_interleaved_fixed_tensor(seq_len, 3)  # lang_id for Chinese
            bert = torch.zeros((1, 1024, seq_len), dtype=torch.float32).to(pt_device)
            ja_bert = torch.randn(1, 768, seq_len, dtype=torch.float32).to(pt_device)
            sdp_ratio = torch.tensor(0.2).to(pt_device)
            noise_scale = torch.tensor(0.6).to(pt_device)
            noise_scale_w = torch.tensor(0.8).to(pt_device)
            length_scale = torch.tensor(1.0).to(pt_device)

        return (
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            ja_bert,
            noise_scale,
            length_scale,
            noise_scale_w,
            sdp_ratio,
        )
class OVOpenVoiceConverter(torch.nn.Module):
    def __init__(self, voice_model: OpenVoiceBaseClass):
        super().__init__()
        self.voice_model = voice_model
        for par in voice_model.model.parameters():
            par.requires_grad = False

    def get_example_input(self):
        y = torch.randn([1, 513, 238], dtype=torch.float32)
        y_lengths = torch.LongTensor([y.size(-1)])
        target_se = torch.randn(*(1, 256, 1))
        source_se = torch.randn(*(1, 256, 1))
        tau = torch.tensor(0.3)
        return (y, y_lengths, source_se, target_se, tau)

    def forward(self, y, y_lengths, sid_src, sid_tgt, tau):
        """
        Wraps the model's `voice_conversion` method with a standard `forward` call.
        """
        return self.voice_model.model.voice_conversion(y, y_lengths, sid_src, sid_tgt, tau)
pt_device = "cpu" | |
melo_tts_en_newest = TTS( | |
"EN_NEWEST", | |
pt_device, | |
use_hf=False, | |
config_path=melotts_english_suffix / "config.json", | |
ckpt_path=melotts_english_suffix / "checkpoint.pth", | |
) | |
melo_tts_zh = TTS( | |
"ZH", | |
pt_device, | |
use_hf=False, | |
config_path=melotts_chinese_suffix / "config.json", | |
ckpt_path=melotts_chinese_suffix / "checkpoint.pth", | |
) | |
tone_color_converter = ToneColorConverter(converter_suffix / "config.json", device=pt_device) | |
tone_color_converter.load_ckpt(converter_suffix / "checkpoint.pth") | |
print(f"ToneColorConverter version: {tone_color_converter.version}") | |
import nncf

IRS_PATH = Path("openvino_irs/")
EN_TTS_IR = IRS_PATH / "melo_tts_en_newest.xml"
ZH_TTS_IR = IRS_PATH / "melo_tts_zh.xml"
VOICE_CONVERTER_IR = IRS_PATH / "openvoice2_tone_conversion.xml"

paths = [EN_TTS_IR, ZH_TTS_IR, VOICE_CONVERTER_IR]
models = [
    OVSynthesizerTTSWrapper(melo_tts_en_newest.model, "EN_NEWEST"),
    OVSynthesizerTTSWrapper(melo_tts_zh.model, "ZH"),
    OVOpenVoiceConverter(tone_color_converter),
]

ov_models = []
for model, path in zip(models, paths):
    if not path.exists():
        ov_model = ov.convert_model(model, example_input=model.get_example_input())
        ov_model = nncf.compress_weights(ov_model)
        ov.save_model(ov_model, path)
    else:
        ov_model = core.read_model(path)
    ov_models.append(ov_model)

ov_en_tts, ov_zh_tts, ov_voice_conversion = ov_models
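# Each wrapper above is traced through `ov.convert_model` with its `get_example_input()`, the weights are
# compressed with `nncf.compress_weights`, and the resulting IR is cached in `openvino_irs/`; on later runs
# the IRs are read back from disk instead of being re-converted.
# (Hypothetical inspection sketch, not required by the pipeline:)
#     reloaded = core.read_model(EN_TTS_IR)
#     print([inp.get_partial_shape() for inp in reloaded.inputs])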
core = ov.Core()

device = device_widget("CPU", exclude=["NPU"])

device

REFERENCE_VOICES_PATH = f"{repo_dir}/resources/"
reference_speakers = [
    *[path for path in os.listdir(REFERENCE_VOICES_PATH) if os.path.splitext(path)[-1] == ".mp3"],
    "record_manually",
    "load_manually",
]

ref_speaker = widgets.Dropdown(
    options=reference_speakers,
    value=reference_speakers[0],
    description="reference voice from which tone color will be copied",
    disabled=False,
)

ref_speaker

OUTPUT_DIR = Path("outputs/")
OUTPUT_DIR.mkdir(exist_ok=True)

ref_speaker_path = f"{REFERENCE_VOICES_PATH}/{ref_speaker.value}"
allowed_audio_types = ".mp4,.mp3,.wav,.wma,.aac,.m4a,.m4b,.webm"

if ref_speaker.value == "record_manually":
    ref_speaker_path = OUTPUT_DIR / "custom_example_sample.webm"
    from ipywebrtc import AudioRecorder, CameraStream

    camera = CameraStream(constraints={"audio": True, "video": False})
    recorder = AudioRecorder(stream=camera, filename=ref_speaker_path, autosave=True)
    display(recorder)
elif ref_speaker.value == "load_manually":
    upload_ref = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="Select audio with reference voice",
    )
    display(upload_ref)


def save_audio(voice_source: widgets.FileUpload, out_path: str):
    with open(out_path, "wb") as output_file:
        assert len(voice_source.value) > 0, "Please select audio file"
        output_file.write(voice_source.value[0]["content"])


if ref_speaker.value == "load_manually":
    ref_speaker_path = f"{OUTPUT_DIR}/{upload_ref.value[0].name}"
    save_audio(upload_ref, ref_speaker_path)

Audio(ref_speaker_path)
# Commented out IPython magic to ensure Python compatibility.
torch_hub_local = Path("torch_hub_local/")
# %env TORCH_HOME={str(torch_hub_local.absolute())}

# second step to fix a problem with silero downloading and installing
import os
import zipfile

url = "https://github.com/snakers4/silero-vad/zipball/v3.0"

torch_hub_dir = torch_hub_local / "hub"
torch.hub.set_dir(torch_hub_dir.as_posix())

zip_filename = "v3.0.zip"
output_path = torch_hub_dir / "v3.0"
if not (torch_hub_dir / zip_filename).exists():
    download_file(url, directory=torch_hub_dir, filename=zip_filename)
    zip_ref = zipfile.ZipFile((torch_hub_dir / zip_filename).as_posix(), "r")
    zip_ref.extractall(path=output_path.as_posix())
    zip_ref.close()

v3_dirs = [d for d in output_path.iterdir() if "snakers4-silero-vad" in d.as_posix()]
if len(v3_dirs) > 0 and not (torch_hub_dir / "snakers4_silero-vad_v3.0").exists():
    os.rename(str(v3_dirs[0]), (torch_hub_dir / "snakers4_silero-vad_v3.0").as_posix())
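# With TORCH_HOME pointed at `torch_hub_local/`, extracting the silero-vad v3.0 archive and renaming it to
# `snakers4_silero-vad_v3.0` makes it look like an already-downloaded torch.hub repo, so (presumably) the
# VAD loader used by se_extractor picks up this local copy instead of fetching it from GitHub.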
en_source_newest_se = torch.load(base_speakers_suffix / "en-newest.pth")
zh_source_se = torch.load(base_speakers_suffix / "zh.pth")

target_se, audio_name = se_extractor.get_se(ref_speaker_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
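# `en_source_newest_se` / `zh_source_se` are the tone color embeddings of the base TTS speakers, while
# `target_se` is the tone color embedding extracted from the chosen reference audio (with VAD enabled);
# the converter below maps the source tone color onto the target one.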
def get_patched_infer(ov_model: ov.Model, device: str) -> callable:
    compiled_model = core.compile_model(ov_model, device)

    def infer_impl(
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale,
        length_scale,
        noise_scale_w,
        max_len=None,
        sdp_ratio=1.0,
        y=None,
        g=None,
    ):
        ov_output = compiled_model(
            (
                x,
                x_lengths,
                sid,
                tone,
                language,
                bert,
                ja_bert,
                noise_scale,
                length_scale,
                noise_scale_w,
                sdp_ratio,
            )
        )
        return (torch.tensor(ov_output[0]),)

    return infer_impl


def get_patched_voice_conversion(ov_model: ov.Model, device: str) -> callable:
    compiled_model = core.compile_model(ov_model, device)

    def voice_conversion_impl(y, y_lengths, sid_src, sid_tgt, tau):
        ov_output = compiled_model((y, y_lengths, sid_src, sid_tgt, tau))
        return (torch.tensor(ov_output[0]),)

    return voice_conversion_impl


melo_tts_en_newest.model.infer = get_patched_infer(ov_en_tts, device.value)
melo_tts_zh.model.infer = get_patched_infer(ov_zh_tts, device.value)
tone_color_converter.model.voice_conversion = get_patched_voice_conversion(ov_voice_conversion, device.value)
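# From here on, MeloTTS `model.infer` and the converter's `model.voice_conversion` run through the compiled
# OpenVINO models on the selected device, while the surrounding Python pipelines (text front-end, audio I/O)
# are unchanged; the patched callables mimic the original return format, a 1-tuple containing the output tensor.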
voice_source = widgets.Dropdown(
    options=["use TTS", "choose_manually"],
    value="use TTS",
    description="Voice source",
    disabled=False,
)

voice_source

if voice_source.value == "choose_manually":
    upload_orig_voice = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="audio whose tone will be replaced",
    )
    display(upload_orig_voice)
from IPython.display import Audio, display

if voice_source.value == "choose_manually":
    orig_voice_path = f"{OUTPUT_DIR}/{upload_orig_voice.value[0].name}"
    save_audio(upload_orig_voice, orig_voice_path)
    source_se, _ = se_extractor.get_se(orig_voice_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
else:
    en_text = """
    I love going to school by bus
    """
    # source_se = en_source_newest_se
    en_orig_voice_path = OUTPUT_DIR / "output_ov_en-newest.wav"
    print("use output_ov_en-newest.wav")
    speaker_id = 0  # speaker id for the English (EN_NEWEST) model
    melo_tts_en_newest.tts_to_file(en_text, speaker_id, en_orig_voice_path, speed=1.0)

    # zh_text says (in English): "OpenVINO is a comprehensive development toolset designed to quickly develop
    # and deploy a wide range of applications and solutions for tasks such as emulating human vision, automatic
    # speech recognition, natural language processing, recommendation systems, and more."
    zh_text = """
    OpenVINO 是一个全面的开发工具集,旨在快速开发和部署各类应用程序及解决方案,可用于模仿人类视觉、自动语音识别、自然语言处理、
    推荐系统等多种任务。
    """
    # source_se = zh_source_se
    zh_orig_voice_path = OUTPUT_DIR / "output_ov_zh.wav"
    print("use output_ov_zh.wav")
    speaker_id = 1  # speaker id for the Chinese (ZH) model
    melo_tts_zh.tts_to_file(zh_text, speaker_id, zh_orig_voice_path, speed=1.0)

    print("Playing English original voice")
    display(Audio(en_orig_voice_path))

    print("Playing Chinese original voice")
    display(Audio(zh_orig_voice_path))
tau_slider = widgets.FloatSlider(
    value=0.3,
    min=0.01,
    max=2.0,
    step=0.01,
    description="tau",
    disabled=False,
    readout_format=".2f",
)

tau_slider
from IPython.display import Audio, display

if voice_source.value == "choose_manually":
    resulting_voice_path = OUTPUT_DIR / "output_ov_cloned.wav"

    tone_color_converter.convert(
        audio_src_path=orig_voice_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=resulting_voice_path,
        tau=tau_slider.value,
        message="@MyShell",
    )

    print("Playing manually chosen cloned voice:")
    display(Audio(resulting_voice_path))
else:
    en_resulting_voice_path = OUTPUT_DIR / "output_ov_en-newest_cloned.wav"
    zh_resulting_voice_path = OUTPUT_DIR / "output_ov_zh_cloned.wav"

    tone_color_converter.convert(
        audio_src_path=en_orig_voice_path,
        src_se=en_source_newest_se,
        tgt_se=target_se,
        output_path=en_resulting_voice_path,
        tau=tau_slider.value,
        message="@MyShell",
    )
    tone_color_converter.convert(
        audio_src_path=zh_orig_voice_path,
        src_se=zh_source_se,
        tgt_se=target_se,
        output_path=zh_resulting_voice_path,
        tau=tau_slider.value,
        message="@MyShell",
    )

    print("Playing English cloned voice:")
    display(Audio(en_resulting_voice_path))

    print("Playing Chinese cloned voice:")
    display(Audio(zh_resulting_voice_path))
import gradio as gr
import langid

supported_languages = ["zh", "en"]
supported_styles = {
    "zh": ["zh_default"],
    "en": [
        "en_latest",
    ],
}
def predict_impl(
    prompt,
    style,
    audio_file_pth,
    agree,
    output_dir,
    tone_color_converter,
    en_tts_model,
    zh_tts_model,
    en_source_se,
    zh_source_se,
):
    text_hint = ""
    if not agree:
        text_hint += "[ERROR] Please accept the Terms & Condition!\n"
        gr.Warning("Please accept the Terms & Condition!")
        return (
            text_hint,
            None,
            None,
        )

    language_predicted = langid.classify(prompt)[0].strip()
    if language_predicted not in supported_languages:
        text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our supported languages: {supported_languages}\n"
        gr.Warning(f"The detected language {language_predicted} for your input text is not in our supported languages: {supported_languages}")
        return (
            text_hint,
            None,
            None,
        )

    # check the style
    if style not in supported_styles[language_predicted]:
        text_hint += (
            f"[WARNING] The style {style} is not supported for the detected language {language_predicted}. "
            f"For language {language_predicted}, we support styles: {supported_styles[language_predicted]}. "
            "Using the wrong style may result in unexpected behavior.\n"
        )
        gr.Warning(
            f"[WARNING] The style {style} is not supported for the detected language {language_predicted}. "
            f"For language {language_predicted}, we support styles: {supported_styles[language_predicted]}. "
            "Using the wrong style may result in unexpected behavior."
        )

    if len(prompt.split()) < 2:
        text_hint += "[ERROR] Please give a longer prompt text\n"
        gr.Warning("Please give a longer prompt text")
        return (
            text_hint,
            None,
            None,
        )
    if len(prompt.split()) > 50:
        text_hint += (
            "[ERROR] Text length is limited to 50 words for this demo, please try a shorter text. "
            "You can clone our open-source repo or try it on our website "
            "https://app.myshell.ai/robot-workshop/widget/174760057433406749\n"
        )
        gr.Warning(
            "Text length is limited to 50 words for this demo, please try a shorter text. "
            "You can clone our open-source repo or try it on our website "
            "https://app.myshell.ai/robot-workshop/widget/174760057433406749"
        )
        return (
            text_hint,
            None,
            None,
        )

    speaker_wav = audio_file_pth

    if language_predicted == "zh":
        tts_model = zh_tts_model
        if zh_tts_model is None:
            gr.Warning("TTS model for the Chinese language was not loaded")
            return (
                text_hint,
                None,
                None,
            )
        source_se = zh_source_se
        speaker_id = 1
    else:
        tts_model = en_tts_model
        if en_tts_model is None:
            gr.Warning("TTS model for the English language was not loaded")
            return (
                text_hint,
                None,
                None,
            )
        source_se = en_source_se
        speaker_id = 0

    # note: diffusion_conditioning is not used with hifigan (the default mode); it will be empty but still needs to be passed to model.inference
    try:
        target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
    except Exception as e:
        text_hint += f"[ERROR] Get target tone color error {str(e)}\n"
        gr.Warning(f"[ERROR] Get target tone color error {str(e)}")
        return (
            text_hint,
            None,
            None,
        )

    src_path = f"{output_dir}/tmp.wav"
    tts_model.tts_to_file(prompt, speaker_id, src_path, speed=1.0)

    if tone_color_converter is None or source_se is None:
        gr.Warning("Tone Color Converter model was not loaded")
        return (
            text_hint,
            None,
            None,
        )

    save_path = f"{output_dir}/output.wav"
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        tau=0.3,
        message=encode_message,
    )

    text_hint += "Get response successfully\n"

    return (
        text_hint,
        src_path,
        save_path,
    )
from functools import partial

predict = partial(
    predict_impl,
    output_dir=OUTPUT_DIR,
    tone_color_converter=tone_color_converter,
    en_tts_model=melo_tts_en_newest,
    zh_tts_model=melo_tts_zh,
    en_source_se=en_source_newest_se,
    zh_source_se=zh_source_se,
)
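# `partial` pre-binds the loaded models, base speaker embeddings, and output directory, so the Gradio
# callback only receives the UI inputs (prompt, style, reference audio, terms agreement).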
import sys

if "gradio_helper" in sys.modules:
    del sys.modules["gradio_helper"]

if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/openvoice/gradio_helper.py")
    open("gradio_helper.py", "w").write(r.text)

from gradio_helper import make_demo

demo = make_demo(fn=predict)

# demo.queue(max_size=1).launch(share=True, debug=True, height=1000)
demo.queue(max_size=1).launch(server_name="0.0.0.0", server_port=7860)
# try:
#     demo.queue(max_size=1).launch(debug=True, height=1000)
# except Exception:
#     demo.queue(max_size=1).launch(share=True, debug=True, height=1000)

# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/