# Hugging Face Space app.py — author: reach-vb (HF Staff), commit 47e9afc ("Update app.py")
# Standard library
import io
import os
import tempfile

# Third-party
import gradio as gr
import torch
import torchaudio
from pydub import AudioSegment
from scipy.io.wavfile import write
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
# Accept the Coqui model license non-interactively. Must be set BEFORE the
# TTS() call below, otherwise the download blocks on a terms-of-service prompt.
os.environ["COQUI_TOS_AGREED"] = "1"
# Instantiating the high-level TTS API downloads the XTTS v1 checkpoint into
# the user data dir; `tts` itself is not used again — only the files it fetched.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
# Directory the download above populated (model.pth, config.json, vocab.json).
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
# Load the low-level Xtts model directly — presumably because the streaming
# entry point (inference_stream, used below) lives on Xtts, not on TTS.
model = Xtts.init_from_config(config)
model.load_checkpoint(
config,
checkpoint_path=os.path.join(model_path, "model.pth"),
vocab_path=os.path.join(model_path, "vocab.json"),
eval=True,
use_deepspeed=True  # NOTE(review): requires the deepspeed package at runtime — confirm it is in the Space's requirements
)
model.cuda()  # assumes a CUDA GPU is attached to this Space — TODO confirm
def stream_audio(synthesis_text, language="en", speaker_wav="female.wav"):
    """Stream synthesized speech for *synthesis_text* as a series of wav files.

    Each audio chunk produced by the model is written to its own wav file and
    the path is yielded, letting Gradio start playback before synthesis ends.

    Args:
        synthesis_text: Text to synthesize.
        language: Language code forwarded to the model (default "en",
            matching the original hard-coded value).
        speaker_wav: Reference audio for voice-cloning conditioning
            (default "female.wav", matching the original hard-coded value).

    Yields:
        str: Path of a wav file holding the next audio chunk (24 kHz).
    """
    # Conditioning latents are recomputed per request so each call may supply
    # a different reference speaker.
    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(
        audio_path=speaker_wav
    )
    chunks = model.inference_stream(
        synthesis_text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=10,
        overlap_wav_len=512,
    )
    # Per-request output dir: the original wrote "{i}.wav" into the CWD, so
    # two concurrent sessions would overwrite each other's chunk files.
    out_dir = tempfile.mkdtemp(prefix="xtts_stream_")
    for i, chunk in enumerate(chunks):
        print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
        out_file = os.path.join(out_dir, f"{i}.wav")
        # scipy writes the model's float tensor as an IEEE-float wav at 24 kHz.
        write(out_file, 24000, chunk.detach().cpu().numpy().squeeze())
        # Round-trip through pydub/ffmpeg re-encodes the file in place —
        # presumably to produce a PCM wav that browsers handle more reliably
        # than float wav; kept as-is to preserve the original output format.
        audio = AudioSegment.from_file(out_file)
        audio.export(out_file, format="wav")
        yield out_file
# Wire the streaming generator into a minimal Gradio UI: one text box in,
# one auto-playing streamed audio player out.
text_input = gr.Textbox()
audio_output = gr.Audio(autoplay=True, streaming=True)

demo = gr.Interface(
    fn=stream_audio,
    inputs=text_input,
    outputs=audio_output,
)

if __name__ == "__main__":
    # queue() is required for generator outputs (streaming) to work.
    demo.queue().launch(debug=True)