mrfakename committed on
Commit 2afcda9 · 1 Parent(s): ceb1c72

Add Gradio app, MPS support

Files changed (4)
  1. README.md +83 -25
  2. gradio_app.py +265 -0
  3. requirements_gradio.txt +3 -0
  4. test_infer_single.py +1 -1
README.md CHANGED
@@ -1,26 +1,34 @@
1
-
2
  # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
3
 
4
  [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
5
  [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
6
- [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) \
7
- **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference. \
8
- **E2 TTS**: Flat-UNet Transformer, closest reproduction.\
 
 
 
9
  **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
10
 
11
  ## Installation
12
- Clone this repository.
 
 
13
  ```bash
14
- git clone git@github.com:SWivid/F5-TTS.git
15
  cd F5-TTS
16
  ```
17
- Install packages.
 
 
18
  ```bash
19
  pip install -r requirements.txt
20
  ```
21
 
22
  ## Prepare Dataset
 
23
  Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
 
24
  ```bash
25
  # prepare custom dataset up to your need
26
  # download corresponding dataset first, and fill in the path in scripts
@@ -33,7 +41,9 @@ python scripts/prepare_wenetspeech4tts.py
33
  ```
34
 
35
  ## Training
 
36
  Once your datasets are prepared, you can start the training process.
 
37
  ```bash
38
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
39
  # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
@@ -42,10 +52,13 @@ accelerate launch test_train.py
42
  ```
43
 
44
  ## Inference
45
- To inference with pretrained models, download the checkpoints from [🤗](https://huggingface.co/SWivid/F5-TTS).
 
46
 
47
  ### Single Inference
 
48
  You can test single inference using the following command. Before running the command, modify the config up to your need.
 
49
  ```bash
50
  # modify the config up to your need,
51
  # e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
@@ -54,14 +67,46 @@ You can test single inference using the following command. Before running the co
54
  # ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
55
  python test_infer_single.py
56
  ```
57
- ### Speech Edit
 
58
  To test speech editing capabilities, use the following command.
59
- ```
 
60
  python test_infer_single_edit.py
61
  ```
62
 
63
  ## Evaluation
 
64
  ### Prepare Test Datasets
 
65
  1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
66
  2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
67
  3. Unzip the downloaded datasets and place them in the data/ directory.
@@ -69,7 +114,9 @@ python test_infer_single_edit.py
69
  5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
70
 
71
  ### Batch Inference for Test Set
 
72
  To run batch inference for evaluations, execute the following commands:
 
73
  ```bash
74
  # batch inference for evaluations
75
  accelerate config # if not set before
@@ -77,16 +124,26 @@ bash test_infer_batch.sh
77
  ```
78
 
79
  ### Download Evaluation Model Checkpoints
 
80
  1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
81
  2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
82
  3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
83
 
84
  ### Objective Evaluation
85
- **Some Notes**\
86
- For faster-whisper with CUDA 11: \
87
- `pip install --force-reinstall ctranslate2==3.24.0`\
88
- (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:\
89
- `pip install faster-whisper==0.10.1`
90
 
91
  Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
92
  ```bash
@@ -99,14 +156,14 @@ python scripts/eval_librispeech_test_clean.py
99
 
100
  ## Acknowledgements
101
 
102
- - <a href="https://arxiv.org/abs/2406.18009">E2-TTS</a> brilliant work, simple and effective
103
- - <a href="https://arxiv.org/abs/2407.05361">Emilia</a>, <a href="https://arxiv.org/abs/2406.05763">WenetSpeech4TTS</a> valuable datasets
104
- - <a href="https://github.com/lucidrains/e2-tts-pytorch">lucidrains</a> initial CFM structure</a> with also <a href="https://github.com/bfs18">bfs18</a> for discussion</a>
105
- - <a href="https://arxiv.org/abs/2403.03206">SD3</a> & <a href="https://github.com/huggingface/diffusers">Huggingface diffusers</a> DiT and MMDiT code structure
106
- - <a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a> as ODE solver, <a href="https://huggingface.co/charactr/vocos-mel-24khz">Vocos</a> as vocoder
107
- - <a href="https://x.com/realmrfakename">mrfakename</a> huggingface space demo ~
108
- - <a href="https://github.com/modelscope/FunASR">FunASR</a>, <a href="https://github.com/SYSTRAN/faster-whisper">faster-whisper</a> & <a href="https://github.com/microsoft/UniSpeech">UniSpeech</a> for evaluation tools
109
- - <a href="https://github.com/MahmoudAshraf97/ctc-forced-aligner">ctc-forced-aligner</a> for speech edit test
110
 
111
  ## Citation
112
  ```
@@ -117,5 +174,6 @@ python scripts/eval_librispeech_test_clean.py
117
  year={2024},
118
  }
119
  ```
120
- ## LICENSE
121
- Our code is released under MIT License.
 
 
 
1
  # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
 
3
  [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
4
  [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
5
+ [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
6
+
7
+ **F5-TTS**: Diffusion Transformer with ConvNeXt V2; faster training and inference.
8
+
9
+ **E2 TTS**: Flat-UNet Transformer, closest reproduction.
10
+
11
  **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
12
 
13
  ## Installation
14
+
15
+ Clone the repository:
16
+
17
  ```bash
18
+ git clone https://github.com/SWivid/F5-TTS.git
19
  cd F5-TTS
20
  ```
21
+
22
+ Install packages:
23
+
24
  ```bash
25
  pip install -r requirements.txt
26
  ```
27
 
28
  ## Prepare Dataset
29
+
30
  Example data processing scripts are provided for Emilia and WenetSpeech4TTS; you can tailor your own, along with a matching Dataset class in `model/dataset.py` (a minimal sketch follows the example commands below).
31
+
32
  ```bash
33
  # prepare custom dataset up to your need
34
  # download corresponding dataset first, and fill in the path in scripts
 
41
  ```
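If you need a fully custom dataset, the sketch below shows one minimal shape such a loader could take before it is adapted to the repo's conventions. It is a hypothetical example, not the project's actual implementation in `model/dataset.py`: the CSV manifest format (`audio_path`, `text` columns) and the class name `CustomTTSDataset` are assumptions.

```python
# Hypothetical minimal dataset sketch; the project's own class in model/dataset.py differs.
import csv

import torchaudio
from torch.utils.data import Dataset


class CustomTTSDataset(Dataset):
    """Reads (audio_path, text) pairs from a simple CSV manifest."""

    def __init__(self, manifest_path, target_sample_rate=24000):
        with open(manifest_path, newline="", encoding="utf-8") as f:
            self.rows = list(csv.DictReader(f))  # columns: audio_path, text
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        audio, sr = torchaudio.load(row["audio_path"])
        if sr != self.target_sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.target_sample_rate)(audio)
        return {"audio": audio, "text": row["text"]}
```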
42
 
43
  ## Training
44
+
45
  Once your datasets are prepared, you can start the training process.
46
+
47
  ```bash
48
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
49
  # will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
 
52
  ```
53
 
54
  ## Inference
55
+
56
+ To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
57
 
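For reference, the `gradio_app.py` added in this commit fetches these checkpoints with `cached_path` over `hf://` URLs. A minimal sketch of doing the same by hand; the `F5TTS_Base` / `1200000` values are the ones `gradio_app.py` uses, and loading onto CPU here is just for inspection:

```python
# Fetch a pretrained checkpoint the same way gradio_app.py does.
import torch
from cached_path import cached_path

exp_name, ckpt_step = "F5TTS_Base", 1200000  # values used in gradio_app.py
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt"))
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.keys())  # load_model() reads the 'ema_model_state_dict' entry
```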
58
  ### Single Inference
59
+
60
  You can test single inference with the following command. Before running it, modify the config in `test_infer_single.py` as needed.
61
+
62
  ```bash
63
  # modify the config up to your need,
64
  # e.g. fix_duration (the total length of prompt + to_generate, currently supports up to 30s)
 
67
  # (though 'midpoint' is a 2nd-order ODE solver, it is slower than the 1st-order 'euler')
68
  python test_infer_single.py
69
  ```
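For orientation, the knobs mentioned in the comments above are plain module-level settings. The values below mirror the defaults in this commit's `gradio_app.py`; `test_infer_single.py` exposes the same kind of settings, though its exact names and values may differ:

```python
# Inference settings as defined in gradio_app.py (this commit).
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32              # number of flow steps, e.g. 16 or 32
cfg_strength = 2.0         # classifier-free guidance strength
ode_method = "euler"       # 1st-order; "midpoint" is 2nd-order but slower
sway_sampling_coef = -1.0  # Sway Sampling coefficient
speed = 1.0
fix_duration = None        # or a float: total prompt + generated length in seconds
```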
70
+ ### Speech Editing
71
+
72
  To test speech editing capabilities, use the following command.
73
+
74
+ ```bash
75
  python test_infer_single_edit.py
76
  ```
77
 
78
+ ### Gradio App
79
+
80
+ You can launch a Gradio app (web interface) to run inference through a GUI.
81
+
82
+ First, make sure you have the dependencies installed (`pip install -r requirements.txt`). Then, install the Gradio app dependencies:
83
+
84
+ ```bash
85
+ pip install -r requirements_gradio.txt
86
+ ```
87
+
88
+ After installing the dependencies, launch the app:
89
+
90
+ ```bash
91
+ python gradio_app.py
92
+ ```
93
+
94
+ You can specify the port/host:
95
+
96
+ ```bash
97
+ python gradio_app.py --port 7860 --host 0.0.0.0
98
+ ```
99
+
100
+ Or launch a share link:
101
+
102
+ ```bash
103
+ python gradio_app.py --share
104
+ ```
105
+
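Because `gradio_app.py` exposes the Blocks interface as `app`, it can also be launched from Python. Note that importing the module loads both the F5-TTS and E2-TTS checkpoints at import time, so expect a delay; the host and port values below are illustrative:

```python
# Importing gradio_app loads both model checkpoints up front.
import gradio_app

# Same call gradio_app.main() makes, with illustrative host/port values.
gradio_app.app.queue(api_open=True).launch(
    server_name="0.0.0.0", server_port=7860, share=False, show_api=True
)
```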
106
  ## Evaluation
107
+
108
  ### Prepare Test Datasets
109
+
110
  1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
111
  2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
112
  3. Unzip the downloaded datasets and place them in the data/ directory.
 
114
  5. Our filtered LibriSpeech-PC 4-10s subset is already under `data/` in this repo.
115
 
116
  ### Batch Inference for Test Set
117
+
118
  To run batch inference for evaluations, execute the following commands:
119
+
120
  ```bash
121
  # batch inference for evaluations
122
  accelerate config # if not set before
 
124
  ```
125
 
126
  ### Download Evaluation Model Checkpoints
127
+
128
  1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
129
  2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
130
  3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
131
 
132
  ### Objective Evaluation
133
+
134
+ **Some Notes**
135
+
136
+ For faster-whisper with CUDA 11:
137
+
138
+ ```bash
139
+ pip install --force-reinstall ctranslate2==3.24.0
140
+ ```
141
+
142
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
143
+
144
+ ```bash
145
+ pip install faster-whisper==0.10.1
146
+ ```
147
 
148
  Update the path to your batch-inference results, then run the WER / SIM evaluations:
149
  ```bash
 
156
 
157
  ## Acknowledgements
158
 
159
+ - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
160
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
161
+ - [lucidrains](https://github.com/lucidrains) initial CFM structure, with [bfs18](https://github.com/bfs18) for discussion
162
+ - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
163
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
164
+ - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
165
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
166
+ - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
167
 
168
  ## Citation
169
  ```
 
174
  year={2024},
175
  }
176
  ```
177
+ ## License
178
+
179
+ Our code is released under the MIT License.
gradio_app.py ADDED
@@ -0,0 +1,265 @@
1
+ import os
2
+ import re
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from ema_pytorch import EMA
10
+ from vocos import Vocos
11
+ from pydub import AudioSegment
12
+ from model import CFM, UNetT, DiT, MMDiT
13
+ from cached_path import cached_path
14
+ from model.utils import (
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
+ save_spectrogram,
18
+ )
19
+ from transformers import pipeline
20
+ import librosa
21
+ import click
22
+
23
+ device = (
24
+ "cuda"
25
+ if torch.cuda.is_available()
26
+ else "mps" if torch.backends.mps.is_available() else "cpu"
27
+ )
28
+
29
+ print(f"Using {device} device")
30
+
31
+ pipe = pipeline(
32
+ "automatic-speech-recognition",
33
+ model="openai/whisper-large-v3-turbo",
34
+ torch_dtype=torch.float16,
35
+ device=device,
36
+ )
37
+
38
+ # --------------------- Settings -------------------- #
39
+
40
+ target_sample_rate = 24000
41
+ n_mel_channels = 100
42
+ hop_length = 256
43
+ target_rms = 0.1
44
+ nfe_step = 32 # 16, 32
45
+ cfg_strength = 2.0
46
+ ode_method = "euler"
47
+ sway_sampling_coef = -1.0
48
+ speed = 1.0
49
+ # fix_duration = 27 # None or float (duration in seconds)
50
+ fix_duration = None
51
+
52
+
53
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
54
+ checkpoint = torch.load(
55
+ str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")),
56
+ map_location=device,
57
+ )
58
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
59
+ model = CFM(
60
+ transformer=model_cls(
61
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
62
+ ),
63
+ mel_spec_kwargs=dict(
64
+ target_sample_rate=target_sample_rate,
65
+ n_mel_channels=n_mel_channels,
66
+ hop_length=hop_length,
67
+ ),
68
+ odeint_kwargs=dict(
69
+ method=ode_method,
70
+ ),
71
+ vocab_char_map=vocab_char_map,
72
+ ).to(device)
73
+
74
+ ema_model = EMA(model, include_online_model=False).to(device)
75
+ ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
76
+ ema_model.copy_params_from_ema_to_model()
77
+
78
+ return ema_model, model
79
+
80
+
81
+ # load models
82
+ F5TTS_model_cfg = dict(
83
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
84
+ )
85
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
86
+
87
+ F5TTS_ema_model, F5TTS_base_model = load_model(
88
+ "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
89
+ )
90
+ E2TTS_ema_model, E2TTS_base_model = load_model(
91
+ "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
92
+ )
93
+
94
+
95
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
96
+ print(gen_text)
97
+ if len(gen_text) > 200:
98
+ raise gr.Error("Please keep your text under 200 chars.")
99
+ gr.Info("Converting audio...")
100
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
101
+ aseg = AudioSegment.from_file(ref_audio_orig)
102
+ audio_duration = len(aseg)
103
+ if audio_duration > 15000:
104
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
105
+ aseg = aseg[:15000]
106
+ aseg.export(f.name, format="wav")
107
+ ref_audio = f.name
108
+ if exp_name == "F5-TTS":
109
+ ema_model = F5TTS_ema_model
110
+ base_model = F5TTS_base_model
111
+ elif exp_name == "E2-TTS":
112
+ ema_model = E2TTS_ema_model
113
+ base_model = E2TTS_base_model
114
+
115
+ if not ref_text.strip():
116
+ gr.Info("No reference text provided, transcribing reference audio...")
117
+ ref_text = pipe(
118
+ ref_audio,
119
+ chunk_length_s=30,
120
+ batch_size=128,
121
+ generate_kwargs={"task": "transcribe"},
122
+ return_timestamps=False,
123
+ )["text"].strip()
124
+ gr.Info("Finished transcription")
125
+ else:
126
+ gr.Info("Using custom reference text...")
127
+ audio, sr = torchaudio.load(ref_audio)
128
+
129
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
130
+ if rms < target_rms:
131
+ audio = audio * target_rms / rms
132
+ if sr != target_sample_rate:
133
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
134
+ audio = resampler(audio)
135
+ audio = audio.to(device)
136
+
137
+ # Prepare the text
138
+ text_list = [ref_text + gen_text]
139
+ final_text_list = convert_char_to_pinyin(text_list)
140
+
141
+ # Calculate duration
142
+ ref_audio_len = audio.shape[-1] // hop_length
143
+ # if fix_duration is not None:
144
+ # duration = int(fix_duration * target_sample_rate / hop_length)
145
+ # else:
146
+ zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
148
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
150
+
151
+ # inference
152
+ gr.Info(f"Generating audio using {exp_name}")
153
+ with torch.inference_mode():
154
+ generated, _ = base_model.sample(
155
+ cond=audio,
156
+ text=final_text_list,
157
+ duration=duration,
158
+ steps=nfe_step,
159
+ cfg_strength=cfg_strength,
160
+ sway_sampling_coef=sway_sampling_coef,
161
+ )
162
+
163
+ generated = generated[:, ref_audio_len:, :]
164
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
165
+ gr.Info("Running vocoder")
166
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
167
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
168
+ if rms < target_rms:
169
+ generated_wave = generated_wave * rms / target_rms
170
+
171
+ # wav -> numpy
172
+ generated_wave = generated_wave.squeeze().cpu().numpy()
173
+
174
+ if remove_silence:
175
+ gr.Info("Removing audio silences... This may take a moment")
176
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
177
+ non_silent_wave = np.array([])
178
+ for interval in non_silent_intervals:
179
+ start, end = interval
180
+ non_silent_wave = np.concatenate(
181
+ [non_silent_wave, generated_wave[start:end]]
182
+ )
183
+ generated_wave = non_silent_wave
184
+
185
+ # spectrogram
186
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
187
+ spectrogram_path = tmp_spectrogram.name
188
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
189
+
190
+ return (target_sample_rate, generated_wave), spectrogram_path
191
+
192
+
193
+ with gr.Blocks() as app:
194
+ gr.Markdown(
195
+ """
196
+ # E2/F5 TTS
197
+
198
+ This is a local web UI for F5 TTS, based on the unofficial [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS). This app supports the following TTS models:
199
+
200
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
201
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
202
+
203
+ The checkpoints support English and Chinese.
204
+
205
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
206
+
207
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
208
+ """
209
+ )
210
+
211
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
212
+ gen_text_input = gr.Textbox(label="Text to Generate (max 200 chars.)", lines=4)
213
+ model_choice = gr.Radio(
214
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
215
+ )
216
+ generate_btn = gr.Button("Synthesize", variant="primary")
217
+ with gr.Accordion("Advanced Settings", open=False):
218
+ ref_text_input = gr.Textbox(
219
+ label="Reference Text",
220
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
221
+ lines=2,
222
+ )
223
+ remove_silence = gr.Checkbox(
224
+ label="Remove Silences",
225
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
226
+ value=True,
227
+ )
228
+
229
+ audio_output = gr.Audio(label="Synthesized Audio")
230
+ spectrogram_output = gr.Image(label="Spectrogram")
231
+
232
+ generate_btn.click(
233
+ infer,
234
+ inputs=[
235
+ ref_audio_input,
236
+ ref_text_input,
237
+ gen_text_input,
238
+ model_choice,
239
+ remove_silence,
240
+ ],
241
+ outputs=[audio_output, spectrogram_output],
242
+ )
243
+
244
+
245
+ @click.command()
246
+ @click.option("--port", "-p", default=None, help="Port to run the app on")
247
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
248
+ @click.option(
249
+ "--share",
250
+ "-s",
251
+ default=False,
252
+ is_flag=True,
253
+ help="Share the app via Gradio share link",
254
+ )
255
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
256
+ def main(port, host, share, api):
257
+ global app
258
+ print(f"Starting app...")
259
+ app.queue(api_open=api).launch(
260
+ server_name=host, server_port=port, share=share, show_api=api
261
+ )
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
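A note on the duration estimate in `infer()`: the output length is computed in mel frames (`hop_length` 256 at 24 kHz) by extending the reference length in proportion to the ratio of generated text length to reference text length, with Chinese pause punctuation counted as extra characters. An illustrative calculation with made-up clip length and character counts:

```python
# Illustrative arithmetic for the duration estimate used in infer() above.
hop_length, target_sample_rate, speed = 256, 24000, 1.0

ref_audio_seconds = 6.0              # hypothetical 6 s reference clip
ref_audio_len = int(ref_audio_seconds * target_sample_rate) // hop_length  # mel frames
ref_text_len, gen_text_len = 80, 160 # character counts (pause punctuation adds extra)

duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
print(duration)  # reference frames plus roughly 2x more for twice as much text
```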
requirements_gradio.txt ADDED
@@ -0,0 +1,3 @@
1
+ cached_path
2
+ pydub
3
+ click
test_infer_single.py CHANGED
@@ -14,7 +14,7 @@ from model.utils import (
14
  save_spectrogram,
15
  )
16
 
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
 
20
  # --------------------- Dataset Settings -------------------- #
 
14
  save_spectrogram,
15
  )
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
 
19
 
20
  # --------------------- Dataset Settings -------------------- #
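The same cuda → mps → cpu preference is applied in both `gradio_app.py` and the updated `test_infer_single.py`. Below is a small stand-alone sketch of the pattern; the `PYTORCH_ENABLE_MPS_FALLBACK` line is optional and only matters if the model hits operators that are not yet implemented on MPS:

```python
import os

# Optional: allow CPU fallback for ops not yet implemented on MPS.
# Set before importing torch; harmless on CUDA-only or CPU-only machines.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch


def pick_device() -> str:
    # Same cuda -> mps -> cpu preference this commit adds to both scripts.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


print(f"Using {pick_device()} device")
```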