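"""Gradio demo for MegaTTS3 voice cloning.

Downloads the ByteDance/MegaTTS3 checkpoints, spawns worker processes that run
the DiT inference pipeline, and serves a web UI that takes a reference speech
clip, its pre-extracted latent (.npy) file, and target text, and returns the
synthesized (cloned) voice.
"""
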
import multiprocessing as mp
import os
import traceback
from functools import partial

import gradio as gr
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator
import torch

from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav

# Fetch the MegaTTS3 checkpoints; huggingface-cli skips files that are already
# present in ./checkpoints, so repeated imports of this module stay cheap.
os.system('huggingface-cli download ByteDance/MegaTTS3 --local-dir ./checkpoints --repo-type model')

CUDA_AVAILABLE = torch.cuda.is_available()
infer_pipe = MegaTTS3DiTInfer(device='cuda' if CUDA_AVAILABLE else 'cpu')
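
# forward_gpu is wrapped with @spaces.GPU so that, on Hugging Face Spaces
# ZeroGPU hardware, each call is granted a GPU for up to 120 s; elsewhere the
# decorator should be a no-op and inference runs on the device chosen above.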
@spaces.GPU(duration=120)
def forward_gpu(file_content, latent_file, inp_text, time_step, p_w, t_w):
    """Preprocess the reference audio/latents and synthesize speech for inp_text."""
    resource_context = infer_pipe.preprocess(file_content, latent_file)
    wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
    return wav_bytes

def model_worker(input_queue, output_queue, device_id):
    # Each worker loops forever: pull a task, run it, and push either the
    # synthesized wav bytes or None (on failure) back to the caller.
    while True:
        task = input_queue.get()
        inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
        try:
            convert_to_wav(inp_audio_path)
            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
            cut_wav(wav_path, max_len=28)  # trim the reference clip to at most 28 s
            with open(wav_path, 'rb') as file:
                file_content = file.read()
            wav_bytes = forward_gpu(file_content, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
            output_queue.put(wav_bytes)
        except Exception as e:
            traceback.print_exc()
            print(task, str(e))
            output_queue.put(None)

def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    print("Pushing task to the input queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
    # concurrency_limit=1 on the interface keeps at most one task in flight,
    # so this blocking get() always receives the result for the task above.
    res = output_queue.get()
    if res is not None:
        return res
    print("Inference failed; see the worker traceback above.")
    return None
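
# Note: with the 'spawn' start method each worker process re-imports this
# module, so the checkpoint download check and infer_pipe construction above
# run once per worker.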
if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()

    num_workers = 1
    devices = [0]
    input_queue = mp_manager.Queue()
    output_queue = mp_manager.Queue()
    processes = []

    print("Starting workers")
    for i in range(num_workers):
        # Round-robin workers across the listed devices. model_worker does not
        # use device_id yet, so this only matters for multi-GPU extensions.
        device_id = devices[i % len(devices)] if devices is not None else None
        p = mp.Process(target=model_worker, args=(input_queue, output_queue, device_id))
        p.start()
        processes.append(p)

    api_interface = gr.Interface(
        fn=partial(main, processes=processes, input_queue=input_queue, output_queue=output_queue),
        inputs=[
            gr.Audio(type="filepath", label="Upload .wav"),
            gr.File(type="filepath", label="Upload .npy"),
            "text",
            gr.Number(label="infer timestep", value=32),
            gr.Number(label="Intelligibility Weight", value=1.4),
            gr.Number(label="Similarity Weight", value=3.0),
        ],
        outputs=[gr.Audio(label="Synthesized Audio")],
        title="MegaTTS3",
        description="Upload a speech clip as the timbre reference, "
                    "upload its pre-extracted latent (.npy) file, "
                    "enter the target text, and receive the cloned voice. "
                    "Tip: generation must complete within 120 s, so check "
                    "whether your input text is too long.",
        concurrency_limit=1,
    )
    api_interface.launch(server_name='0.0.0.0', server_port=7860, debug=True)

    for p in processes:
        p.join()
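
# Example client call (a sketch: it assumes a recent gradio_client release with
# handle_file, and that this app is reachable at localhost:7860):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860/")
#   wav_path = client.predict(
#       handle_file("ref.wav"),   # reference speech clip (timbre)
#       handle_file("ref.npy"),   # pre-extracted latent for the clip
#       "Text to synthesize",     # target text
#       32, 1.4, 3.0,             # infer timestep, p_w, t_w
#       api_name="/predict",
#   )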