# Spaces: Running on Zero
| import json | |
| from pathlib import Path | |
| import nltk | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| from voxtream.generator import SpeechGenerator, SpeechGeneratorConfig | |
# Load the generator settings from the JSON file shipped next to the app.
# JSON is UTF-8 by specification, so the encoding is pinned explicitly instead
# of being inherited from the platform locale.
with open("configs/generator.json", encoding="utf-8") as f:
    config = SpeechGeneratorConfig(**json.load(f))

# Loading speaker encoder
# The returned object is discarded; presumably this call only warms the
# torch.hub cache so the generator can pick the encoder up later -- TODO confirm.
torch.hub.load(
    config.spk_enc_repo,
    config.spk_enc_model,
    model_name=config.spk_enc_model_name,
    train_type=config.spk_enc_train_type,
    dataset=config.spk_enc_dataset,
    trust_repo=True,
    verbose=False,
)

# Loading NLTK packages
# raise_on_error=True turns a failed download into an exception at startup.
nltk.download("averaged_perceptron_tagger_eng", quiet=True, raise_on_error=True)
nltk.download("punkt", quiet=True, raise_on_error=True)

# Initialize speech generator
speech_generator = SpeechGenerator(config)

# Stylesheet handed to gr.Blocks(css=...) when the UI is built in main().
CUSTOM_CSS = """
/* overall width */
.gradio-container {max-width: 1100px !important}
/* stack labels tighter and even heights */
#cols .wrap > .form {gap: 10px}
#left-col, #right-col {gap: 14px}
/* make submit centered + bigger */
#submit {width: 260px; margin: 10px auto 0 auto;}
/* make clear align left and look secondary */
#clear {width: 120px;}
/* give audio a little breathing room */
audio {outline: none;}
"""
def _apply_fade_out(waveform, sample_rate, fade_len_sec=0.1):
    """Linearly ramp the tail of *waveform* down to zero, in place.

    Args:
        waveform: 1-D NumPy array of audio samples; modified in place.
        sample_rate: Samples per second, used to convert the fade duration
            into a number of samples.
        fade_len_sec: Requested fade duration in seconds.

    Returns:
        The same ``waveform`` array, for convenience.
    """
    # Clamp to the clip length: a clip shorter than the fade window would
    # otherwise make the in-place multiply fail with a shape mismatch.
    fade_len = min(int(sample_rate * fade_len_sec), len(waveform))
    # A zero-length fade must be skipped: waveform[-0:] is the whole array.
    if fade_len > 0:
        waveform[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)
    return waveform


def synthesize_fn(prompt_audio_path, prompt_text, target_text):
    """Synthesize ``target_text`` in the voice of the prompt recording.

    Args:
        prompt_audio_path: Filesystem path of the prompt audio, or a falsy
            value when nothing was provided.
        prompt_text: Transcript of the prompt audio; forwarded unchanged to
            the generator.
        target_text: Text to be spoken; a falsy value aborts the request.

    Returns:
        ``None`` when the prompt audio or target text is missing, or when the
        generator yields no frames; otherwise a ``(config.mimi_sr, waveform)``
        tuple where ``waveform`` is a float32 NumPy array with a short
        fade-out applied to its tail.
    """
    # Validate before touching the device so empty requests stay cheap.
    # NOTE(review): the UI labels the prompt transcript as required, but it is
    # not checked here -- confirm whether the generator accepts an empty one.
    if not prompt_audio_path or not target_text:
        return None

    # Move every sub-module to the GPU on first use.
    # NOTE(review): this assumes `model.device` compares equal to the string
    # "cpu"; if it is a torch.device the branch may never run -- confirm.
    if speech_generator.model.device == "cpu":
        speech_generator.model.to("cuda")
        speech_generator.mimi.to("cuda")
        speech_generator.spk_enc.to("cuda")
        speech_generator.aligner.to("cuda")

    stream = speech_generator.generate_stream(
        prompt_text=prompt_text,
        prompt_audio_path=Path(prompt_audio_path),
        text=target_text,
    )

    # The stream yields pairs; only the first element (the audio frame) is kept.
    frames = [frame for frame, _ in stream]
    if not frames:
        return None

    waveform = np.concatenate(frames).astype(np.float32)

    # Fade out
    _apply_fade_out(waveform, config.mimi_sr, fade_len_sec=0.1)

    return (config.mimi_sr, waveform)
def main():
    """Assemble the VoXtream Gradio interface and start serving it."""

    def _blank_values():
        # One value per component in `reset_targets`, in the same order.
        return (None, "", "", None)

    with gr.Blocks(css=CUSTOM_CSS, title="VoXtream") as demo:
        gr.Markdown("# VoXtream TTS demo")

        # Two side-by-side columns: prompt inputs left, target/output right.
        with gr.Row(equal_height=True, elem_id="cols"):
            with gr.Column(scale=1, elem_id="left-col"):
                ref_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Prompt audio (3-5 sec of target voice)",
                )
                ref_transcript = gr.Textbox(
                    lines=3,
                    label="Prompt transcript",
                    placeholder="Text that matches the prompt audio (Required)",
                )

            with gr.Column(scale=1, elem_id="right-col"):
                text_to_speak = gr.Textbox(
                    lines=3,
                    label="Target text",
                    placeholder="What you want the model to say",
                )
                result_audio = gr.Audio(
                    type="numpy",
                    label="Synthesized audio",
                    interactive=False,
                )

        with gr.Row():
            reset_button = gr.Button("Clear", elem_id="clear", variant="secondary")
            run_button = gr.Button("Submit", elem_id="submit", variant="primary")

        # Submit: run synthesis on the three inputs and show the result.
        synthesis_inputs = [ref_audio, ref_transcript, text_to_speak]
        run_button.click(
            fn=synthesize_fn,
            inputs=synthesis_inputs,
            outputs=result_audio,
        )

        # Clear: blank every input field and the output player.
        reset_targets = [ref_audio, ref_transcript, text_to_speak, result_audio]
        reset_button.click(
            fn=_blank_values,
            inputs=[],
            outputs=reset_targets,
        )

    demo.launch()


if __name__ == "__main__":
    main()