# melotts-clone / melotts_training.py
# Author: wolfofbackstreet — commit 8aa5548 "init" (verified)
# -*- coding: utf-8 -*-
"""melotts training.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1srmto1Bf7xQl7la1-5cTZOvbTnL-KWDG
"""
# Fetch `notebook_utils` module
import requests
from pathlib import Path
# Download the OpenVINO notebook helper modules (only the ones not already present).
for _helper_name in ("notebook_utils.py", "cmd_helper.py", "pip_helper.py"):
    _helper_path = Path(_helper_name)
    if not _helper_path.exists():
        r = requests.get(
            url=f"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/{_helper_name}",
            timeout=30,
        )
        # Fail loudly instead of saving an HTTP error page as an importable .py module.
        r.raise_for_status()
        _helper_path.write_text(r.text, encoding="utf-8")
# !!! have to restart session
from pathlib import Path
from cmd_helper import clone_repo
from pip_helper import pip_install
import platform
repo_dir = Path("OpenVoice")
clone_repo("https://github.com/myshell-ai/OpenVoice")

# OpenVoice's english.py imports `unidecode`; replace it with `anyascii`.
# The untouched file is kept as _orig_english.py, whose existence marks the patch as applied.
orig_english_path = repo_dir / "openvoice" / "text" / "_orig_english.py"
english_path = repo_dir / "openvoice" / "text" / "english.py"
if not orig_english_path.exists():
    english_path.rename(orig_english_path)
    data = orig_english_path.read_text(encoding="utf-8")
    data = data.replace("unidecode", "anyascii")
    english_path.write_text(data, encoding="utf-8")

# fix a problem with silero downloading and installing: pin the silero VAD version.
# Re-running is harmless: 'method="silero"' no longer occurs after the first replacement.
se_extractor_path = repo_dir / "openvoice" / "se_extractor.py"
data = se_extractor_path.read_text(encoding="utf-8")
data = data.replace('method="silero"', 'method="silero:3.0"')
se_extractor_path.write_text(data, encoding="utf-8")
# clone melotts
clone_repo("https://github.com/myshell-ai/MeloTTS")

# Packages installed without their dependency trees (to avoid version-pin conflicts).
_no_deps_packages = [
    "librosa==0.9.1",
    "pydub==0.25.1",
    "tqdm",
    "inflect==7.0.0",
    "pypinyin==0.50.0",
    "openvino>=2025.0",
]
pip_install("--no-deps", *_no_deps_packages)

# Japanese models are not converted, so the heavy Japanese-only dependencies
# (mecab-python3==1.0.9, unidic_lite==1.0.8, unidic==1.1.0, fugashi==1.3.0) are left out.
# To try Japanese, use a Python 3.10 environment on Ubuntu and add them back to this list.
_cpu_index_packages = [
    "nncf",
    "wavmark>=0.0.3",
    "faster-whisper>=0.9.0",
    "eng_to_ipa==0.0.2",
    "cn2an==0.5.22",
    "jieba==0.42.1",
    "langid==1.1.6",
    "ipywebrtc",
    "anyascii==0.3.2",
    "torch>=2.1",
    "torchaudio",
    "cached_path",
    "transformers>=4.38,<5.0",
    "num2words==0.5.12",
    "pykakasi==2.2.1",
    "g2p_en==2.1.0",
    "jamo==0.4.1",
    "gruut[de,es,fr]==2.2.3",
    "g2pkk>=0.1.1",
    "dtw-python",
    "more-itertools",
    "tiktoken",
    "tensorboard==2.16.2",
    "loguru==0.7.2",
    "nltk",
    "gradio",
]
pip_install("--extra-index-url", "https://download.pytorch.org/whl/cpu", *_cpu_index_packages)

pip_install("--no-deps", "whisper-timestamped>=1.14.2", "openai-whisper")

if platform.system() == "Darwin":
    pip_install("numpy<2.0")

# fix the problem of `module 'botocore.exceptions' has no attribute 'HTTPClientError'`
pip_install("--upgrade", "botocore")

# download nltk data
import nltk

nltk.download("averaged_perceptron_tagger_eng")

# install unidic (only needed for Japanese)
# !python -m unidic download
# remove Japanese-related module in MeloTTS to fix dependencies issue
# If you want to use Japanese, please do not modify these files
import re

_melo_english_path = Path("MeloTTS/melo/text/english.py")
data = _melo_english_path.read_text(encoding="utf-8")
japanese_import = "from .japanese import distribute_phone"
# Local copy of `distribute_phone`, so english.py no longer imports the Japanese module.
# This text is pasted verbatim into english.py, so it must be valid, indented Python.
replacement_function = """
def distribute_phone(n_phone, n_word):
    phones_per_word = [0] * n_word
    for _ in range(n_phone):
        min_tasks = min(phones_per_word)
        min_index = phones_per_word.index(min_tasks)
        phones_per_word[min_index] += 1
    return phones_per_word
"""
# replace `from .japanese import distribute_phone` with the function (no-op on re-runs)
data = data.replace(japanese_import, replacement_function)
_melo_english_path.write_text(data, encoding="utf-8")

_melo_init_path = Path("MeloTTS/melo/text/__init__.py")
data = _melo_init_path.read_text(encoding="utf-8")
for _unused_bert_import in (
    "from .japanese_bert import get_bert_feature as jp_bert",
    "from .spanish_bert import get_bert_feature as sp_bert",
    "from .french_bert import get_bert_feature as fr_bert",
    "from .korean import get_bert_feature as kr_bert",
):
    data = data.replace(_unused_bert_import, "")
# Replace the lang_bert_func_map dictionary, keeping only the keys ZH, EN, and ZH_MIX_EN
pattern = re.compile(r"lang_bert_func_map\s*=\s*\{[^}]+\}", re.DOTALL)
replacement = """lang_bert_func_map = {
    "ZH": zh_bert,
    "EN": en_bert,
    "ZH_MIX_EN": zh_mix_en_bert,
}"""
data = pattern.sub(replacement, data)
_melo_init_path.write_text(data, encoding="utf-8")

# clean the modules: leave them as empty files so nothing Japanese-specific is imported
for filename in ["japanese.py", "japanese_bert.py"]:
    (_melo_init_path.parent / filename).write_text("", encoding="utf-8")
import os
import torch
import openvino as ov
import ipywidgets as widgets

# `display` is only a builtin inside IPython. Import it explicitly so the display(...) calls
# below also work when this file runs as a plain Python script.
from IPython.display import Audio, display
from notebook_utils import download_file, device_widget

# Single OpenVINO runtime object, used below to read and compile the IR models.
core = ov.Core()

from openvoice.api import ToneColorConverter, OpenVoiceBaseClass
import openvoice.se_extractor as se_extractor
from melo.api import TTS
# Root folder for all downloaded checkpoints, and the per-model folders beneath it.
CKPT_BASE_PATH = Path("checkpoints")
base_speakers_suffix = CKPT_BASE_PATH.joinpath("base_speakers", "ses")
converter_suffix = CKPT_BASE_PATH.joinpath("converter")
melotts_chinese_suffix = CKPT_BASE_PATH.joinpath("MeloTTS-Chinese")
melotts_english_suffix = CKPT_BASE_PATH.joinpath("MeloTTS-English-v3")
def download_from_hf_hub(repo_id, filename, local_dir="./"):
    """Download ``filename`` from the Hugging Face Hub repository ``repo_id`` into ``local_dir``.

    Args:
        repo_id: Hub repository id, e.g. "myshell-ai/OpenVoiceV2".
        filename: Path of the file inside the repository (may contain sub-folders).
        local_dir: Target directory (str or Path); defaults to the current directory.

    Returns:
        Whatever ``huggingface_hub.hf_hub_download`` returns (the local file path). It used to
        be discarded; returning it is backward compatible because callers ignore the result.
    """
    # Imported lazily so huggingface_hub is only needed when something is downloaded.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=Path(local_dir))
# (repository, file inside the repository, local target directory), downloaded in this order.
_hf_downloads = [
    # OpenVoice2: tone color converter and base speaker embeddings
    ("myshell-ai/OpenVoiceV2", "converter/checkpoint.pth", CKPT_BASE_PATH),
    ("myshell-ai/OpenVoiceV2", "converter/config.json", CKPT_BASE_PATH),
    ("myshell-ai/OpenVoiceV2", "base_speakers/ses/en-newest.pth", CKPT_BASE_PATH),
    ("myshell-ai/OpenVoiceV2", "base_speakers/ses/zh.pth", CKPT_BASE_PATH),
    # MeloTTS: Chinese and English TTS models
    ("myshell-ai/MeloTTS-Chinese", "checkpoint.pth", melotts_chinese_suffix),
    ("myshell-ai/MeloTTS-Chinese", "config.json", melotts_chinese_suffix),
    ("myshell-ai/MeloTTS-English-v3", "checkpoint.pth", melotts_english_suffix),
    ("myshell-ai/MeloTTS-English-v3", "config.json", melotts_english_suffix),
]
for _repo_id, _repo_file, _target_dir in _hf_downloads:
    download_from_hf_hub(_repo_id, _repo_file, _target_dir)
class OVSynthesizerTTSWrapper(torch.nn.Module):
    """
    Wrapper for SynthesizerTrn model from MeloTTS to make it compatible with Torch-style inference.

    ``forward`` takes the model's ``infer`` arguments positionally, in the same order as the tuple
    returned by ``get_example_input``, so the wrapper can be traced by ``ov.convert_model``.
    """

    def __init__(self, model, language):
        """
        Args:
            model: MeloTTS ``SynthesizerTrn`` instance (anything exposing ``infer``).
            language: Key selecting the example-input recipe: "EN_NEWEST" or "ZH".
        """
        super().__init__()
        self.model = model
        self.language = language

    def forward(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale,
        length_scale,
        noise_scale_w,
        sdp_ratio,
    ):
        """
        Forward call to the underlying SynthesizerTrn model's ``infer`` method.

        The seven tensor inputs are passed positionally and the four scalar controls by keyword.
        The result of ``infer`` is returned unchanged.
        """
        return self.model.infer(
            x,
            x_lengths,
            sid,
            tone,
            language,
            bert,
            ja_bert,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )

    def get_example_input(self, device=None):
        """
        Return a tuple of example inputs for tracing/ONNX exporting or debugging.

        When exporting the SynthesizerTrn function, this model has been found to be very
        sensitive to the example_input used for model transformation. Here, we have implemented
        some simple rules or considered using real input data.

        Args:
            device: Device for the returned tensors. Defaults to the device of the wrapped
                model's first parameter, or CPU if the model has no parameters.

        Returns:
            ``(x, x_lengths, sid, tone, language, bert, ja_bert, noise_scale, length_scale,
            noise_scale_w, sdp_ratio)``, matching ``forward``'s argument order.

        Raises:
            ValueError: if ``self.language`` is neither "EN_NEWEST" nor "ZH".
        """
        if device is None:
            # Follow the wrapped model instead of a module-level global defined later in the file.
            first_param = next(iter(self.model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device("cpu")

        def gen_interleaved_random_tensor(length, value_range):
            """Generate a (1, length) int64 tensor [0, val, 0, val, ..., 0], val ∈ [low, high)."""
            values = [0 if i % 2 == 0 else torch.randint(*value_range, (1,)).item() for i in range(length)]
            return torch.tensor([values], dtype=torch.int64, device=device)

        def gen_interleaved_fixed_tensor(length, fixed_value):
            """Generate a (1, length) int64 tensor [0, val, 0, val, ..., 0]."""
            values = [0 if i % 2 == 0 else fixed_value for i in range(length)]
            return torch.tensor([values], dtype=torch.int64, device=device)

        if self.language == "EN_NEWEST":
            seq_len, token_range, tone_range, speaker_id, lang_id = 73, (14, 220), (5, 10), 0, 2
        elif self.language == "ZH":
            seq_len, token_range, tone_range, speaker_id, lang_id = 37, (7, 100), (4, 9), 1, 3
        else:
            raise ValueError(f"No example-input recipe for language {self.language!r}; expected 'EN_NEWEST' or 'ZH'")

        x_tst = gen_interleaved_random_tensor(seq_len, token_range)
        # NOTE(review): x_tst has shape (1, seq_len), so these slices index the batch dimension
        # and zero *every* token, not just the first/last three (`x_tst[:, :3]` was probably
        # intended). The cached IRs were traced with this input, and the model is sensitive to
        # it, so the behavior is kept. Verify the conversion before changing it.
        x_tst[:3] = 0
        x_tst[-3:] = 0
        x_tst_lengths = torch.tensor([seq_len], dtype=torch.int64, device=device)
        # This model has only one fixed id for speakers.
        speakers = torch.tensor([speaker_id], dtype=torch.int64, device=device)
        tones = gen_interleaved_random_tensor(seq_len, tone_range)
        lang_ids = gen_interleaved_fixed_tensor(seq_len, lang_id)  # 2 = english, 3 = chinese
        if self.language == "EN_NEWEST":
            bert = torch.randn((1, 1024, seq_len), dtype=torch.float32, device=device)
        else:
            # The Chinese recipe uses an all-zero `bert` feature.
            bert = torch.zeros((1, 1024, seq_len), dtype=torch.float32, device=device)
        ja_bert = torch.randn((1, 768, seq_len), dtype=torch.float32, device=device)
        sdp_ratio = torch.tensor(0.2, device=device)
        noise_scale = torch.tensor(0.6, device=device)
        noise_scale_w = torch.tensor(0.8, device=device)
        length_scale = torch.tensor(1.0, device=device)
        return (
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            ja_bert,
            noise_scale,
            length_scale,
            noise_scale_w,
            sdp_ratio,
        )
class OVOpenVoiceConverter(torch.nn.Module):
    """Traceable wrapper that exposes OpenVoice's ``model.voice_conversion`` as ``forward``."""

    def __init__(self, voice_model: OpenVoiceBaseClass):
        super().__init__()
        self.voice_model = voice_model
        # Conversion only: freeze every weight of the wrapped model.
        for weight in voice_model.model.parameters():
            weight.requires_grad = False

    def get_example_input(self):
        """Return ``(y, y_lengths, sid_src, sid_tgt, tau)`` example tensors for tracing.

        ``y`` is random float32 of shape (1, 513, 238) and its last dimension doubles as the
        length. Both tone-color embeddings are random (1, 256, 1) tensors; ``tau`` is 0.3.
        """
        y = torch.randn([1, 513, 238], dtype=torch.float32)
        y_len = torch.LongTensor([y.size(-1)])
        tgt_embedding = torch.randn(1, 256, 1)
        src_embedding = torch.randn(1, 256, 1)
        tau_value = torch.tensor(0.3)
        return y, y_len, src_embedding, tgt_embedding, tau_value

    def forward(self, y, y_lengths, sid_src, sid_tgt, tau):
        """
        wraps the 'voice_conversion' method with forward.
        """
        converter = self.voice_model.model
        return converter.voice_conversion(y, y_lengths, sid_src, sid_tgt, tau)
pt_device = "cpu"


def _load_melo_tts(language, ckpt_dir):
    """Build a MeloTTS ``TTS`` on ``pt_device`` from ``ckpt_dir``'s config.json/checkpoint.pth."""
    return TTS(
        language,
        pt_device,
        use_hf=False,
        config_path=ckpt_dir / "config.json",
        ckpt_path=ckpt_dir / "checkpoint.pth",
    )


melo_tts_en_newest = _load_melo_tts("EN_NEWEST", melotts_english_suffix)
melo_tts_zh = _load_melo_tts("ZH", melotts_chinese_suffix)

tone_color_converter = ToneColorConverter(converter_suffix / "config.json", device=pt_device)
tone_color_converter.load_ckpt(converter_suffix / "checkpoint.pth")
print(f"ToneColorConverter version: {tone_color_converter.version}")
import nncf

IRS_PATH = Path("openvino_irs/")
EN_TTS_IR = IRS_PATH / "melo_tts_en_newest.xml"
ZH_TTS_IR = IRS_PATH / "melo_tts_zh.xml"
VOICE_CONVERTER_IR = IRS_PATH / "openvoice2_tone_conversion.xml"

paths = [EN_TTS_IR, ZH_TTS_IR, VOICE_CONVERTER_IR]
models = [
    OVSynthesizerTTSWrapper(melo_tts_en_newest.model, "EN_NEWEST"),
    OVSynthesizerTTSWrapper(melo_tts_zh.model, "ZH"),
    OVOpenVoiceConverter(tone_color_converter),
]


def _load_or_convert(torch_model, ir_path):
    """Read the cached IR at ``ir_path``, or convert, weight-compress and cache ``torch_model``."""
    if ir_path.exists():
        return core.read_model(ir_path)
    converted = ov.convert_model(torch_model, example_input=torch_model.get_example_input())
    converted = nncf.compress_weights(converted)
    ov.save_model(converted, ir_path)
    return converted


ov_models = [_load_or_convert(torch_model, ir_path) for torch_model, ir_path in zip(models, paths)]
ov_en_tts, ov_zh_tts, ov_voice_conversion = ov_models
# The `core` created right after the OpenVINO import is reused for compilation; creating a
# second ov.Core() here was redundant.
device = device_widget("CPU", exclude=["NPU"])
# Bare expression: renders the widget in a notebook, has no effect in a plain script.
device
REFERENCE_VOICES_PATH = f"{repo_dir}/resources/"
# os.listdir returns entries in arbitrary order; sort them so the dropdown order and its
# default value are stable between runs and machines.
_bundled_mp3s = sorted(name for name in os.listdir(REFERENCE_VOICES_PATH) if os.path.splitext(name)[-1] == ".mp3")
reference_speakers = [
    *_bundled_mp3s,
    "record_manually",
    "load_manually",
]
ref_speaker = widgets.Dropdown(
    options=reference_speakers,
    value=reference_speakers[0],
    description="reference voice from which tone color will be copied",
    disabled=False,
)
ref_speaker
# `display` is only a builtin inside IPython; import it so this also works as a script.
from IPython.display import display

OUTPUT_DIR = Path("outputs/")
OUTPUT_DIR.mkdir(exist_ok=True)
ref_speaker_path = f"{REFERENCE_VOICES_PATH}/{ref_speaker.value}"
allowed_audio_types = ".mp4,.mp3,.wav,.wma,.aac,.m4a,.m4b,.webm"
if ref_speaker.value == "record_manually":
    from ipywebrtc import AudioRecorder, CameraStream

    ref_speaker_path = OUTPUT_DIR / "custom_example_sample.webm"
    camera = CameraStream(constraints={"audio": True, "video": False})
    # ipywebrtc's `filename` is a text trait, so pass a str rather than a Path object.
    recorder = AudioRecorder(stream=camera, filename=str(ref_speaker_path), autosave=True)
    display(recorder)
elif ref_speaker.value == "load_manually":
    upload_ref = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="Select audio with reference voice",
    )
    display(upload_ref)
def save_audio(voice_source: "widgets.FileUpload", out_path: str):
    """Write the first file uploaded through ``voice_source`` to ``out_path``.

    Args:
        voice_source: FileUpload widget; ``value[0]["content"]`` holds the uploaded bytes.
        out_path: Destination file path (str or Path).

    Raises:
        ValueError: if nothing has been uploaded. The check runs before ``out_path`` is opened,
            so no empty file is created. A real exception is used because ``assert`` is
            stripped under ``python -O``.
    """
    uploads = voice_source.value
    if len(uploads) == 0:
        raise ValueError("Please select audio file")
    with open(out_path, "wb") as output_file:
        output_file.write(uploads[0]["content"])
if ref_speaker.value == "load_manually":
    # Copy the uploaded reference clip into OUTPUT_DIR under its original file name.
    uploaded_name = upload_ref.value[0].name
    ref_speaker_path = f"{OUTPUT_DIR}/{uploaded_name}"
    save_audio(upload_ref, ref_speaker_path)

# Bare expression: shows an audio player for the chosen reference in a notebook.
Audio(ref_speaker_path)
# Commented out IPython magic to ensure Python compatibility.
torch_hub_local = Path("torch_hub_local/")
# %env TORCH_HOME={str(torch_hub_local.absolute())}
# second step to fix a problem with silero downloading and installing:
# unpack silero-vad v3.0 into the torch.hub cache folder "snakers4_silero-vad_v3.0".
import os
import zipfile

url = "https://github.com/snakers4/silero-vad/zipball/v3.0"
torch_hub_dir = torch_hub_local / "hub"
torch.hub.set_dir(torch_hub_dir.as_posix())
zip_filename = "v3.0.zip"
output_path = torch_hub_dir / "v3.0"
silero_hub_dir = torch_hub_dir / "snakers4_silero-vad_v3.0"
if not (torch_hub_dir / zip_filename).exists():
    download_file(url, directory=torch_hub_dir, filename=zip_filename)
# Extraction is only needed until the sources are in their final location; after that,
# re-extracting on every run was wasted work.
if not silero_hub_dir.exists():
    # Context manager: the archive is closed even if extraction fails.
    with zipfile.ZipFile(torch_hub_dir / zip_filename, "r") as zip_ref:
        zip_ref.extractall(path=output_path)
    # The extracted folder is expected to contain "snakers4-silero-vad" in its name.
    v3_dirs = [d for d in output_path.iterdir() if "snakers4-silero-vad" in d.as_posix()]
    if v3_dirs:
        v3_dirs[0].rename(silero_hub_dir)
# Load the base speakers' tone-color embeddings onto the same device as the PyTorch models.
# map_location ensures tensors saved from a GPU can still be loaded on a CPU-only machine.
en_source_newest_se = torch.load(base_speakers_suffix / "en-newest.pth", map_location=pt_device)
zh_source_se = torch.load(base_speakers_suffix / "zh.pth", map_location=pt_device)
# Tone color of the reference speaker, i.e. the voice that will be cloned.
target_se, audio_name = se_extractor.get_se(ref_speaker_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
def get_pathched_infer(ov_model: ov.Model, device: str) -> callable:
    """Compile ``ov_model`` on ``device`` and return a replacement for ``SynthesizerTrn.infer``.

    The returned callable keeps ``infer``'s parameter list. ``max_len``, ``y`` and ``g`` are
    accepted for compatibility but ignored. The other arguments are fed to the compiled model
    in the order used when it was converted. Output 0 is returned as a 1-tuple holding a
    torch tensor.
    """
    compiled = core.compile_model(ov_model, device)

    def infer_impl(
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale,
        length_scale,
        noise_scale_w,
        max_len=None,
        sdp_ratio=1.0,
        y=None,
        g=None,
    ):
        ordered_inputs = (
            x,
            x_lengths,
            sid,
            tone,
            language,
            bert,
            ja_bert,
            noise_scale,
            length_scale,
            noise_scale_w,
            sdp_ratio,
        )
        first_output = compiled(ordered_inputs)[0]
        # torch.tensor copies the data out of the OpenVINO result.
        return (torch.tensor(first_output),)

    return infer_impl
def get_patched_voice_conversion(ov_model: ov.Model, device: str) -> callable:
    """Compile ``ov_model`` on ``device`` and return a replacement for ``voice_conversion``.

    The callable takes ``(y, y_lengths, sid_src, sid_tgt, tau)`` and returns output 0 of the
    compiled model as a 1-tuple holding a torch tensor.
    """
    compiled = core.compile_model(ov_model, device)

    def voice_conversion_impl(y, y_lengths, sid_src, sid_tgt, tau):
        result = compiled((y, y_lengths, sid_src, sid_tgt, tau))
        return (torch.tensor(result[0]),)

    return voice_conversion_impl
# Route MeloTTS synthesis and OpenVoice conversion through the compiled OpenVINO models.
for _tts, _ov_tts in ((melo_tts_en_newest, ov_en_tts), (melo_tts_zh, ov_zh_tts)):
    _tts.model.infer = get_pathched_infer(_ov_tts, device.value)
tone_color_converter.model.voice_conversion = get_patched_voice_conversion(ov_voice_conversion, device.value)
# Where the audio to be re-voiced comes from: synthesized by MeloTTS, or uploaded by the user.
_voice_source_options = ["use TTS", "choose_manually"]
voice_source = widgets.Dropdown(
    options=_voice_source_options,
    value=_voice_source_options[0],
    description="Voice source",
    disabled=False,
)
voice_source
if voice_source.value == "choose_manually":
    # `display` is only a builtin inside IPython; import it so this also works as a script.
    from IPython.display import display

    upload_orig_voice = widgets.FileUpload(
        accept=allowed_audio_types,
        multiple=False,
        description="audio whose tone will be replaced",
    )
    display(upload_orig_voice)
from IPython.display import Audio, display

if voice_source.value == "choose_manually":
    orig_voice_path = f"{OUTPUT_DIR}/{upload_orig_voice.value[0].name}"
    save_audio(upload_orig_voice, orig_voice_path)
    source_se, _ = se_extractor.get_se(orig_voice_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
else:
    en_text = "\nI love going to school by bus\n"
    # source_se = en_source_newest_se
    en_orig_voice_path = OUTPUT_DIR / "output_ov_en-newest.wav"
    print("use output_ov_en-newest.wav")
    speaker_id = 0  # the English model's speaker id
    melo_tts_en_newest.tts_to_file(en_text, speaker_id, en_orig_voice_path, speed=1.0)

    zh_text = (
        "\n"
        "OpenVINO 是一个全面的开发工具集,旨在快速开发和部署各类应用程序及解决方案,可用于模仿人类视觉、自动语音识别、自然语言处理、\n"
        "推荐系统等多种任务。\n"
    )
    # source_se = zh_source_se
    zh_orig_voice_path = OUTPUT_DIR / "output_ov_zh.wav"
    print("use output_ov_zh.wav")
    speaker_id = 1  # the Chinese model's single speaker uses id 1, not 0
    melo_tts_zh.tts_to_file(zh_text, speaker_id, zh_orig_voice_path, speed=1.0)

    # Playback belongs in this branch: both paths are only defined when TTS generated them.
    print("Playing English Original voice")
    display(Audio(en_orig_voice_path))
    print("Playing Chinese Original voice")
    display(Audio(zh_orig_voice_path))
# User-selectable `tau` forwarded to ToneColorConverter.convert (0.01 - 2.0, default 0.3).
tau_slider = widgets.FloatSlider(
    description="tau",
    value=0.3,
    min=0.01,
    max=2.0,
    step=0.01,
    disabled=False,
    readout_format=".2f",
)
tau_slider
from IPython.display import Audio, display


def _clone_tone(src_audio_path, src_embedding, out_audio_path):
    """Re-voice ``src_audio_path`` from ``src_embedding`` to ``target_se``, writing ``out_audio_path``."""
    tone_color_converter.convert(
        audio_src_path=src_audio_path,
        src_se=src_embedding,
        tgt_se=target_se,
        output_path=out_audio_path,
        tau=tau_slider.value,
        message="@MyShell",
    )


if voice_source.value == "choose_manually":
    resulting_voice_path = OUTPUT_DIR / "output_ov_cloned.wav"
    _clone_tone(orig_voice_path, source_se, resulting_voice_path)
    print("Playing manually chosen cloned voice:")
    display(Audio(resulting_voice_path))
else:
    en_resulting_voice_path = OUTPUT_DIR / "output_ov_en-newest_cloned.wav"
    zh_resulting_voice_path = OUTPUT_DIR / "output_ov_zh_cloned.wav"
    _clone_tone(en_orig_voice_path, en_source_newest_se, en_resulting_voice_path)
    _clone_tone(zh_orig_voice_path, zh_source_se, zh_resulting_voice_path)
    for _label, _cloned_path in (("English", en_resulting_voice_path), ("Chinese", zh_resulting_voice_path)):
        print(f"Playing {_label} cloned voice:")
        display(Audio(_cloned_path))
import gradio as gr
import langid
supported_languages = ["zh", "en"]
# Allowed style names per detected language. Every value is a list so that `style in ...`
# is a list-membership test. With a bare string it was a substring test, which accepted
# e.g. "zh" or "default" for Chinese.
supported_styles = {
    "zh": ["zh_default"],
    "en": [
        "en_latest",
    ],
}
def predict_impl(
    prompt,
    style,
    audio_file_pth,
    agree,
    output_dir,
    tone_color_converter,
    en_tts_model,
    zh_tts_model,
    en_source_se,
    zh_source_se,
):
    """Synthesize ``prompt`` with MeloTTS and re-voice it with the tone color of a reference clip.

    Args:
        prompt: Text to speak. Its language is detected with ``langid`` and must be in
            ``supported_languages``.
        style: Style name. An unsupported style only adds a warning.
        audio_file_pth: Path of the reference audio whose tone color is copied.
        agree: Whether the terms were accepted; nothing is generated otherwise.
        output_dir: Directory for ``tmp.wav``, ``output.wav`` and the tone-color extraction output.
        tone_color_converter: OpenVoice ``ToneColorConverter`` (may be None if not loaded).
        en_tts_model, zh_tts_model: MeloTTS ``TTS`` instances (may be None if not loaded).
        en_source_se, zh_source_se: Tone-color embeddings of the base TTS speakers.

    Returns:
        ``(text_hint, src_path, save_path)``: the accumulated log text plus the paths of the
        plain TTS audio and the converted audio. Both paths are None if the request is rejected.
    """
    text_hint = ""

    def fail(warning, hint=None):
        """Append ``hint`` to the log, show ``warning`` in Gradio, return the rejection tuple."""
        nonlocal text_hint
        if hint is not None:
            text_hint += hint
        gr.Warning(warning)
        return (text_hint, None, None)

    if not agree:
        return fail("Please accept the Terms & Condition!", "[ERROR] Please accept the Terms & Condition!\n")

    language_predicted = langid.classify(prompt)[0].strip()
    if language_predicted not in supported_languages:
        message = f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
        return fail(message, f"[ERROR] {message}\n")

    # check the style (a mismatch is only a warning)
    if style not in supported_styles[language_predicted]:
        message = f"[Warning] The style {style} is not supported for detected language {language_predicted}. For language {language_predicted}, we support styles: {supported_styles[language_predicted]}. Using the wrong style may result in unexpected behavior."
        text_hint += f"{message}\n"
        gr.Warning(message)

    words = prompt.split()
    # Chinese is usually written without spaces, so a whole sentence would count as one "word"
    # and be rejected as too short. For "zh", count non-whitespace characters for the minimum.
    min_length_units = len("".join(words)) if language_predicted == "zh" else len(words)
    if min_length_units < 2:
        return fail("Please give a longer prompt text", "[ERROR] Please give a longer prompt text \n")
    if len(words) > 50:
        message = "Text length limited to 50 words for this demo, please try shorter text. You can clone our open-source repo or try it on our website https://app.myshell.ai/robot-workshop/widget/174760057433406749"
        return fail(message, f"[ERROR] {message} \n")

    if language_predicted == "zh":
        tts_model, source_se, speaker_id = zh_tts_model, zh_source_se, 1
        if tts_model is None:
            return fail("TTS model for Chinese language was not loaded")
    else:
        tts_model, source_se, speaker_id = en_tts_model, en_source_se, 0
        if tts_model is None:
            return fail("TTS model for English language was not loaded")

    # Checked before get_se, which already needs the converter, and before spending time on TTS.
    if tone_color_converter is None or source_se is None:
        return fail("Tone Color Converter model was not loaded")

    try:
        target_se, _ = se_extractor.get_se(audio_file_pth, tone_color_converter, target_dir=output_dir, vad=True)
    except Exception as e:  # UI boundary: report any extraction failure instead of crashing
        message = f"[ERROR] Get target tone color error {str(e)} \n"
        return fail(message, message)

    src_path = f"{output_dir}/tmp.wav"
    tts_model.tts_to_file(prompt, speaker_id, src_path, speed=1.0)

    save_path = f"{output_dir}/output.wav"
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        tau=0.3,
        message=encode_message,
    )
    text_hint += "Get response successfully \n"
    return (
        text_hint,
        src_path,
        save_path,
    )
from functools import partial

# Pre-bind the models, embeddings and output folder. The caller supplies only the four
# leading positional parameters (prompt, style, audio_file_pth, agree).
_predict_bound_kwargs = {
    "output_dir": OUTPUT_DIR,
    "tone_color_converter": tone_color_converter,
    "en_tts_model": melo_tts_en_newest,
    "zh_tts_model": melo_tts_zh,
    "en_source_se": en_source_newest_se,
    "zh_source_se": zh_source_se,
}
predict = partial(predict_impl, **_predict_bound_kwargs)
import sys

# Forget a previously imported gradio_helper so the import below reads the file on disk.
sys.modules.pop("gradio_helper", None)

_gradio_helper_path = Path("gradio_helper.py")
if not _gradio_helper_path.exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/openvoice/gradio_helper.py",
        timeout=30,
    )
    # Fail loudly instead of saving an HTTP error page as an importable .py module.
    r.raise_for_status()
    _gradio_helper_path.write_text(r.text, encoding="utf-8")

from gradio_helper import make_demo

demo = make_demo(fn=predict)
# demo.queue(max_size=1).launch(share=True, debug=True, height=1000)
# NOTE: server_name="0.0.0.0" listens on all network interfaces of this machine.
demo.queue(max_size=1).launch(server_name="0.0.0.0", server_port=7860)
# try:
# demo.queue(max_size=1).launch(debug=True, height=1000)
# except Exception:
# demo.queue(max_size=1).launch(share=True, debug=True, height=1000)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/