Merge branch 'SWivid:main' into main
Files changed:
- README.md +18 -3
- gradio_app.py +95 -132
- inference-cli.py +111 -136
- inference-cli.toml +1 -1
- model/utils.py +6 -7
- requirements.txt +2 -8
- requirements_eval.txt +5 -0
README.md
CHANGED
@@ -1,16 +1,25 @@
 # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
 
+<div style="position: absolute; width: 100%;">
+  <div style="position: absolute; top: 0; right: 100px;">
+    <img src="https://avatars.githubusercontent.com/u/35554183?s=200&v=4" alt="Watermark" style="width: 140px; height: auto;">
+  </div>
+</div>
+
 [](https://github.com/SWivid/F5-TTS)
 [](https://arxiv.org/abs/2410.06885)
 [](https://swivid.github.io/F5-TTS/)
 [](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[](https://x-lance.sjtu.edu.cn/)
 
 **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
 
-**E2 TTS**: Flat-UNet Transformer, closest reproduction.
+**E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
 
 **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
 
+### Thanks to all the contributors !
+
 ## Installation
 
 Clone the repository:
@@ -62,7 +71,7 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
 
 ## Inference
 
-
+The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [⭐ Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
 
 Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
 - To avoid possible inference failures, make sure you have seen through the following instructions.
@@ -148,6 +157,12 @@ bash scripts/eval_infer_batch.sh
 
 ### Objective Evaluation
 
+Install packages for evaluation:
+
+```bash
+pip install -r requirements_eval.txt
+```
+
 **Some Notes**
 
 For faster-whisper with CUDA 11:
@@ -193,4 +208,4 @@ python scripts/eval_librispeech_test_clean.py
 ```
 ## License
 
-Our code is released under MIT License.
+Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
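Note: the "automatically downloaded" path above goes through `cached_path`, the same helper `inference-cli.py` uses in its `load_model`. A minimal sketch of that resolution order, assuming the `SWivid/F5-TTS` Hub layout shown in this diff:

```python
# Sketch: resolve an F5-TTS checkpoint the way load_model() in inference-cli.py does --
# prefer a local ckpts/ file, otherwise pull (and cache) it from the Hugging Face Hub.
from pathlib import Path

from cached_path import cached_path

exp_name, ckpt_step = "F5TTS_Base", 1200000           # values used in this commit
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # local .pt, if present
if not Path(ckpt_path).exists():
    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
print(ckpt_path)  # path to a cached checkpoint file, ready to load
```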
gradio_app.py
CHANGED
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torchaudio
@@ -17,7 +16,6 @@ from model.utils import (
     save_spectrogram,
 )
 from transformers import pipeline
-import librosa
 import click
 import soundfile as sf
 
@@ -33,19 +31,6 @@ def gpu_decorator(func):
     else:
         return func
 
-
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -73,7 +58,6 @@ cfg_strength = 2.0
 ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
-# fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
 
@@ -114,104 +98,37 @@ E2TTS_ema_model = load_model(
     "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
-def
-    word_batches = []
-    for word in words:
-        if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-            current_word_part += word + ' '
-        else:
-            if current_word_part:
-                # Try to find a suitable split word
-                for split_word in split_words:
-                    split_index = current_word_part.rfind(' ' + split_word + ' ')
-                    if split_index != -1:
-                        word_batches.append(current_word_part[:split_index].strip())
-                        current_word_part = current_word_part[split_index:].strip() + ' '
-                        break
-                else:
-                    # If no suitable split word found, just append the current part
-                    word_batches.append(current_word_part.strip())
-                    current_word_part = ""
-            current_word_part += word + ' '
-    if current_word_part:
-        word_batches.append(current_word_part.strip())
-    return word_batches
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
     for sentence in sentences:
-        if len(
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            colon_parts = sentence.split(':')
-            if len(colon_parts) > 1:
-                for part in colon_parts:
-                    if len(part.encode('utf-8')) <= max_chars:
-                        batches.append(part)
-                    else:
-                        # If colon part is still too long, split by comma
-                        comma_parts = re.split('[,,]', part)
-                        if len(comma_parts) > 1:
-                            current_comma_part = ""
-                            for comma_part in comma_parts:
-                                if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                    current_comma_part += comma_part + ','
-                                else:
-                                    if current_comma_part:
-                                        batches.append(current_comma_part.rstrip(','))
-                                    current_comma_part = comma_part + ','
-                            if current_comma_part:
-                                batches.append(current_comma_part.rstrip(','))
-                        else:
-                            # If no comma, split by words
-                            batches.extend(split_by_words(part))
-        else:
-            # If no colon, split by comma
-            comma_parts = re.split('[,,]', sentence)
-            if len(comma_parts) > 1:
-                current_comma_part = ""
-                for comma_part in comma_parts:
-                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                        current_comma_part += comma_part + ','
-                    else:
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                        current_comma_part = comma_part + ','
-                if current_comma_part:
-                    batches.append(current_comma_part.rstrip(','))
-            else:
-                # If no comma, split by words
-                batches.extend(split_by_words(sentence))
-    else:
-        current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
 @gpu_decorator
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
+def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
     elif exp_name == "E2-TTS":
@@ -269,8 +186,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     # Remove silence
     if remove_silence:
@@ -296,11 +249,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     return (target_sample_rate, final_wave), spectrogram_path
 
 @gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence,
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -308,7 +257,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(
+        non_silent_segs = silence.split_on_silence(
+            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
+        )
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
             non_silent_wave += non_silent_seg
@@ -334,16 +285,25 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     else:
        gr.Info("Using custom reference text...")
 
-    #
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". "):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     audio, sr = torchaudio.load(ref_audio)
-
-
+
+    # Use the new chunk_text function to split gen_text
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
-    for i,
-        print(f'gen_text {i}',
+    for i, batch_text in enumerate(gen_text_batches):
+        print(f'gen_text {i}', batch_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
+
 
 @gpu_decorator
 def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
@@ -448,12 +408,7 @@ with gr.Blocks() as app_tts:
     remove_silence = gr.Checkbox(
         label="Remove Silences",
         info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
-        value=
-    )
-    split_words_input = gr.Textbox(
-        label="Custom Split Words",
-        info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
-        lines=2,
+        value=False,
     )
     speed_slider = gr.Slider(
         label="Speed",
@@ -463,6 +418,14 @@ with gr.Blocks() as app_tts:
         step=0.1,
         info="Adjust the speed of the audio.",
     )
+    cross_fade_duration_slider = gr.Slider(
+        label="Cross-Fade Duration (s)",
+        minimum=0.0,
+        maximum=1.0,
+        value=0.15,
+        step=0.01,
+        info="Set the duration of the cross-fade between audio clips.",
+    )
     speed_slider.change(update_speed, inputs=speed_slider)
 
     audio_output = gr.Audio(label="Synthesized Audio")
@@ -476,7 +439,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             model_choice,
             remove_silence,
-
+            cross_fade_duration_slider,
         ],
         outputs=[audio_output, spectrogram_output],
     )
@@ -724,7 +687,7 @@ with gr.Blocks() as app_emotional:
             ref_text = speech_types[current_emotion].get('ref_text', '')
 
             # Generate speech for this segment
-            audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence,
+            audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
             sr, audio_data = audio
 
             generated_audio_segments.append(audio_data)
@@ -825,4 +788,4 @@ def main(port, host, share, api):
 
 
 if __name__ == "__main__":
-    main()
+    main()
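The new `cross_fade_duration` parameter above does a plain linear cross-fade between consecutive batch outputs. A standalone sketch of the same arithmetic, assuming two mono NumPy waveforms; the 24 kHz rate and 0.15 s default mirror this commit, the sine inputs are only illustration:

```python
import numpy as np

target_sample_rate = 24000      # vocoder sample rate used by the app
cross_fade_duration = 0.15      # default exposed by the new slider

t = np.linspace(0, 1, target_sample_rate, endpoint=False)
prev_wave = np.sin(2 * np.pi * 220 * t)   # stand-ins for two generated batches
next_wave = np.sin(2 * np.pi * 330 * t)

n = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
fade_out = np.linspace(1, 0, n)           # linear ramps, as in infer_batch
fade_in = np.linspace(0, 1, n)
overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
final_wave = np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])

print(len(final_wave))  # 44400: n overlapping samples shorter than naive concatenation
```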
inference-cli.py
CHANGED
@@ -1,26 +1,24 @@
+import argparse
+import codecs
 import re
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import tomli
 import torch
 import torchaudio
-import
-import
+import tqdm
+from cached_path import cached_path
 from einops import rearrange
-from vocos import Vocos
 from pydub import AudioSegment, silence
-from model import CFM, UNetT, DiT, MMDiT
-from cached_path import cached_path
-from model.utils import (
-    load_checkpoint,
-    get_tokenizer,
-    convert_char_to_pinyin,
-    save_spectrogram,
-)
 from transformers import pipeline
-
-
-import
-import
-
-import codecs
+from vocos import Vocos
+
+from model import CFM, DiT, MMDiT, UNetT
+from model.utils import (convert_char_to_pinyin, get_tokenizer,
+                         load_checkpoint, save_spectrogram)
 
 parser = argparse.ArgumentParser(
     prog="python3 inference-cli.py",
@@ -73,6 +71,11 @@ parser.add_argument(
     "--remove_silence",
     help="Remove silence.",
 )
+parser.add_argument(
+    "--load_vocoder_from_local",
+    action="store_true",
+    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
@@ -88,24 +91,23 @@ model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
+vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
 device = (
     "cuda"
     if torch.cuda.is_available()
     else "mps" if torch.backends.mps.is_available() else "cpu"
 )
-
+
+if args.load_vocoder_from_local:
+    print(f"Load vocos from local path {vocos_local_path}")
+    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    vocos.load_state_dict(state_dict)
+    vocos.eval()
+else:
+    print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
+    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 print(f"Using {device} device")
 
@@ -124,8 +126,9 @@ speed = 1.0
 fix_duration = None
 
 def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path =
-
+    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
+    if not Path(ckpt_path).exists():
+        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -153,103 +156,36 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-    current_word_part = ""
-    word_batches = []
-    for word in words:
-        if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-            current_word_part += word + ' '
-        else:
-            if current_word_part:
-                # Try to find a suitable split word
-                for split_word in split_words:
-                    split_index = current_word_part.rfind(' ' + split_word + ' ')
-                    if split_index != -1:
-                        word_batches.append(current_word_part[:split_index].strip())
-                        current_word_part = current_word_part[split_index:].strip() + ' '
-                        break
-                else:
-                    # If no suitable split word found, just append the current part
-                    word_batches.append(current_word_part.strip())
-                    current_word_part = ""
-            current_word_part += word + ' '
-    if current_word_part:
-        word_batches.append(current_word_part.strip())
-    return word_batches
+
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
     for sentence in sentences:
-        if len(
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            colon_parts = sentence.split(':')
-            if len(colon_parts) > 1:
-                for part in colon_parts:
-                    if len(part.encode('utf-8')) <= max_chars:
-                        batches.append(part)
-                    else:
-                        # If colon part is still too long, split by comma
-                        comma_parts = re.split('[,,]', part)
-                        if len(comma_parts) > 1:
-                            current_comma_part = ""
-                            for comma_part in comma_parts:
-                                if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                    current_comma_part += comma_part + ','
-                                else:
-                                    if current_comma_part:
-                                        batches.append(current_comma_part.rstrip(','))
-                                    current_comma_part = comma_part + ','
-                            if current_comma_part:
-                                batches.append(current_comma_part.rstrip(','))
-                        else:
-                            # If no comma, split by words
-                            batches.extend(split_by_words(part))
-        else:
-            # If no colon, split by comma
-            comma_parts = re.split('[,,]', sentence)
-            if len(comma_parts) > 1:
-                current_comma_part = ""
-                for comma_part in comma_parts:
-                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                        current_comma_part += comma_part + ','
-                    else:
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                        current_comma_part = comma_part + ','
-                if current_comma_part:
-                    batches.append(current_comma_part.rstrip(','))
-            else:
-                # If no comma, split by words
-                batches.extend(split_by_words(sentence))
-    else:
-        current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
-
+
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
         ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
     elif model == "E2-TTS":
@@ -307,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     with open(wave_path, "wb") as f:
         sf.write(f.name, final_wave, target_sample_rate)
@@ -329,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
     print(spectrogram_path)
 
 
-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence,
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -341,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
        non_silent_wave = AudioSegment.silent(duration=0)
        for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
@@ -373,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     else:
         print("Using custom reference text...")
 
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (
-    gen_text_batches =
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-infer(ref_audio, ref_text, gen_text, model, remove_silence
+infer(ref_audio, ref_text, gen_text, model, remove_silence)
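The `max_chars` fed into `chunk_text` above scales the per-chunk text budget by how many seconds the reference clip leaves free out of 25. A worked sketch of that formula with made-up numbers (a 6 s reference clip and a 60-byte reference transcript):

```python
# Worked example of the max_chars formula added in infer():
#   max_chars = (bytes of ref_text per second of ref audio) * (seconds left out of 25)
# The 6 s clip length and 60-byte transcript are made-up illustration values.
ref_audio_seconds = 6.0
ref_text_bytes = 60                     # i.e. len(ref_text.encode('utf-8'))

max_chars = int(ref_text_bytes / ref_audio_seconds * (25 - ref_audio_seconds))
print(max_chars)  # 190 -> each chunk covers roughly 19 s of speech at the reference speaking rate
```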
inference-cli.toml
CHANGED
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
 gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
 # File with text to generate. Ignores the text above.
 gen_file = ""
-remove_silence =
+remove_silence = false
 output_dir = "tests"
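For reference, `inference-cli.py` reads this file with `tomli` (added to requirements.txt in this commit), so `remove_silence = false` arrives as a real boolean. A minimal sketch, assuming the file sits in the working directory:

```python
import tomli

# Read the CLI config the same way inference-cli.py does (tomli wants binary mode).
with open("inference-cli.toml", "rb") as f:
    config = tomli.load(f)

print(type(config["remove_silence"]), config["remove_silence"])  # <class 'bool'> False
print(config["output_dir"])                                      # tests
```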
model/utils.py
CHANGED
@@ -22,12 +22,6 @@ from einops import rearrange, reduce
 
 import jieba
 from pypinyin import lazy_pinyin, Style
-import zhconv
-from zhon.hanzi import punctuation
-from jiwer import compute_measures
-
-from funasr import AutoModel
-from faster_whisper import WhisperModel
 
 from model.ecapa_tdnn import ECAPA_TDNN_SMALL
 from model.modules import MelSpec
@@ -432,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
 
 def load_asr_model(lang, ckpt_dir = ""):
     if lang == "zh":
+        from funasr import AutoModel
         model = AutoModel(
             model = os.path.join(ckpt_dir, "paraformer-zh"),
             # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
@@ -440,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
             disable_update=True,
         ) # following seed-tts setting
     elif lang == "en":
+        from faster_whisper import WhisperModel
         model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
         model = WhisperModel(model_size, device="cuda", compute_type="float16")
     return model
@@ -451,6 +447,7 @@ def run_asr_wer(args):
     rank, lang, test_set, ckpt_dir = args
 
     if lang == "zh":
+        import zhconv
         torch.cuda.set_device(rank)
     elif lang == "en":
         os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
@@ -458,10 +455,12 @@
         raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
     asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
-
+
+    from zhon.hanzi import punctuation
     punctuation_all = punctuation + string.punctuation
     wers = []
 
+    from jiwer import compute_measures
     for gen_wav, prompt_wav, truth in tqdm(test_set):
         if lang == "zh":
             res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
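The edits above defer the evaluation-only imports (`funasr`, `faster_whisper`, `zhconv`, `zhon.hanzi`, `jiwer`) into the functions that use them, so `model/utils.py` imports cleanly without `requirements_eval.txt` installed. A generic sketch of the deferred-import pattern, with a hypothetical `wer_zh()` helper for illustration:

```python
def wer_zh(hypothesis, reference):
    # Heavy, optional dependencies are imported only when this code path runs,
    # so importing the module itself never requires the evaluation packages.
    import zhconv                       # zh text normalization, eval-only
    from jiwer import compute_measures  # WER computation, eval-only

    hypothesis = zhconv.convert(hypothesis, "zh-cn")
    reference = zhconv.convert(reference, "zh-cn")
    return compute_measures(reference, hypothesis)["wer"]
```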
requirements.txt
CHANGED
@@ -5,25 +5,19 @@ datasets
 einops>=0.8.0
 einx>=0.3.0
 ema_pytorch>=0.5.2
-faster_whisper
-funasr
 gradio
 jieba
-jiwer
 librosa
 matplotlib
-numpy
+numpy<=1.26.4
 pydub
 pypinyin
 safetensors
 soundfile
-
-# torchaudio>=2.3.0
+tomli
 torchdiffeq
 tqdm>=4.65.0
 transformers
 vocos
 wandb
 x_transformers>=1.31.14
-zhconv
-zhon
requirements_eval.txt
ADDED
@@ -0,0 +1,5 @@
+faster_whisper
+funasr
+jiwer
+zhconv
+zhon