import os
import logging
from datetime import datetime

import torch
import soundfile as sf
import gradio as gr

from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI


def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
    """Load the model once at the beginning."""
    logging.info(f"Loading model from: {model_dir}")
    # Use the requested CUDA device when one is available; fall back to CPU so
    # the demo still starts on machines without a GPU (slower, but functional).
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{device}")
    else:
        device = torch.device("cpu")
    model = SparkTTS(model_dir, device)
    return model


def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Saving audio to: {save_dir}")

    # Treat an empty or single-character prompt transcript as no prompt text.
    if prompt_text is not None:
        prompt_text = None if len(prompt_text) <= 1 else prompt_text

    os.makedirs(save_dir, exist_ok=True)

    # Name the output file by timestamp to avoid overwriting earlier results.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )

    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")

    return save_path, model
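
# Quick non-UI usage sketch (the checkpoint path and device index below are the
# defaults assumed by this script; adjust to your setup). LEVELS_MAP_UI maps the
# 1-5 UI levels to the textual values SparkTTS.inference expects:
#
#     model = initialize_model("pretrained_models/Spark-TTS-0.5B", device=0)
#     wav_path, _ = run_tts(
#         "Hello from Spark-TTS.",
#         model,
#         gender="male",
#         pitch=LEVELS_MAP_UI[3],
#         speed=LEVELS_MAP_UI[3],
#     )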


def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
    """Gradio interface for TTS with prompt speech input."""
    # Prefer the uploaded file; fall back to the microphone recording.
    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
    # Ignore prompt transcripts that are missing or too short to be useful.
    if prompt_text is None or len(prompt_text) < 2:
        prompt_text = None
    audio_output_path, model = run_tts(
        text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
    )

    return audio_output_path, model


def voice_creation(text, model, gender, pitch, speed):
    """Gradio interface for TTS with control over voice attributes."""
    # The sliders provide integers 1-5; map them to the textual level labels
    # that SparkTTS.inference expects.
    pitch = LEVELS_MAP_UI[int(pitch)]
    speed = LEVELS_MAP_UI[int(speed)]
    audio_output_path, model = run_tts(
        text, model, gender=gender, pitch=pitch, speed=speed
    )
    return audio_output_path, model
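
# Example of a direct call outside the UI (sketch; `model` comes from
# initialize_model above, and pitch/speed are the raw 1-5 slider values):
#
#     path, _ = voice_creation("A quick test sentence.", model, "female", 4, 2)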


def build_ui(model_dir, device=0):
    with gr.Blocks() as demo:
        # Load the model once and share it across both tabs via gr.State.
        model = initialize_model(model_dir, device=device)

        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            with gr.TabItem("Voice Clone"):
                gr.Markdown("### Upload reference audio or record one")

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (optional; recommended for cloning in the same language)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        gr.State(model),
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output, gr.State(model)],
                )

            with gr.TabItem("Voice Creation"):
                gr.Markdown("### Create your own voice based on the following parameters")

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gr.State(model), gender, pitch, speed],
                    outputs=[audio_output, gr.State(model)],
                )

    return demo


if __name__ == "__main__":
    # Surface the logging.info calls above; the root logger defaults to WARNING.
    logging.basicConfig(level=logging.INFO)

    demo = build_ui(model_dir="pretrained_models/Spark-TTS-0.5B", device=0)
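    # launch() is used with Gradio's defaults here (localhost, auto-selected
    # port). It also accepts e.g. server_name="0.0.0.0" and server_port=7860 to
    # expose the demo on the local network; these are optional and not set here.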
    demo.launch()