Commit 2a521e9 (unverified) · committed by Jarod Mica
Parents (2): 31e5051 bdc76f5

Merge branch 'SWivid:main' into main

Files changed (7):
  1. README.md +18 -3
  2. gradio_app.py +95 -132
  3. inference-cli.py +111 -136
  4. inference-cli.toml +1 -1
  5. model/utils.py +6 -7
  6. requirements.txt +2 -8
  7. requirements_eval.txt +5 -0
README.md CHANGED
@@ -1,16 +1,25 @@
 # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
 
+<div style="position: absolute; width: 100%;">
+<div style="position: absolute; top: 0; right: 100px;">
+<img src="https://avatars.githubusercontent.com/u/35554183?s=200&v=4" alt="Watermark" style="width: 140px; height: auto;">
+</div>
+</div>
+
 [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
 [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
 [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
 [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
 
 **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
 
-**E2 TTS**: Flat-UNet Transformer, closest reproduction.
+**E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
 
 **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
 
+### Thanks to all the contributors !
+
 ## Installation
 
 Clone the repository:
@@ -62,7 +71,7 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
 
 ## Inference
 
-To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or automatically downloaded with `inference-cli` and `gradio_app`.
+The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [⭐ Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
 
 Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
 - To avoid possible inference failures, make sure you have seen through the following instructions.
@@ -148,6 +157,12 @@ bash scripts/eval_infer_batch.sh
 
 ### Objective Evaluation
 
+Install packages for evaluation:
+
+```bash
+pip install -r requirements_eval.txt
+```
+
 **Some Notes**
 
 For faster-whisper with CUDA 11:
@@ -193,4 +208,4 @@ python scripts/eval_librispeech_test_clean.py
 ```
 ## License
 
-Our code is released under MIT License.
+Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
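The chunked batch inference the README refers to sizes each text chunk from the reference clip. A self-contained sketch (not part of the commit) of that character-budget arithmetic, using the reference text from `inference-cli.toml`; the 5-second clip duration and the 25-second window are illustrative assumptions taken from the updated `infer()` further down:

```python
# Illustrative sketch only: how many generated-text bytes fit in one chunk.
# Assumptions: a 5 s reference clip and the 25 s window used by the new infer().
ref_text = "Some call me nature, others call me mother nature. "
ref_audio_sec = 5.0   # duration of the reference clip (made-up value)
window_sec = 25.0     # generation window assumed by the updated code

bytes_per_sec = len(ref_text.encode("utf-8")) / ref_audio_sec
max_chars = int(bytes_per_sec * (window_sec - ref_audio_sec))
print(max_chars)      # rough per-chunk budget that gets passed to chunk_text()
```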
gradio_app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torchaudio
@@ -17,7 +16,6 @@ from model.utils import (
     save_spectrogram,
 )
 from transformers import pipeline
-import librosa
 import click
 import soundfile as sf
 
@@ -33,19 +31,6 @@ def gpu_decorator(func):
     else:
         return func
 
-
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -73,7 +58,6 @@ cfg_strength = 2.0
 ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
-# fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
 
@@ -114,104 +98,37 @@ E2TTS_ema_model = load_model(
     "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
-def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
-    if len(text.encode('utf-8')) <= max_chars:
-        return [text]
-    if text[-1] not in ['。', '.', '!', '!', '?', '?']:
-        text += '.'
-
-    sentences = re.split('([。.!?!?])', text)
-    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
-
-    batches = []
-    current_batch = ""
-
-    def split_by_words(text):
-        words = text.split()
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
     for sentence in sentences:
-        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_batch += sentence
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            # If adding this sentence would exceed the limit
-            if current_batch:
-                batches.append(current_batch)
-                current_batch = ""
-
-            # If the sentence itself is longer than max_chars, split it
-            if len(sentence.encode('utf-8')) > max_chars:
-                # First, try to split by colon
-                colon_parts = sentence.split(':')
-                if len(colon_parts) > 1:
-                    for part in colon_parts:
-                        if len(part.encode('utf-8')) <= max_chars:
-                            batches.append(part)
-                        else:
-                            # If colon part is still too long, split by comma
-                            comma_parts = re.split('[,,]', part)
-                            if len(comma_parts) > 1:
-                                current_comma_part = ""
-                                for comma_part in comma_parts:
-                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                        current_comma_part += comma_part + ','
-                                    else:
-                                        if current_comma_part:
-                                            batches.append(current_comma_part.rstrip(','))
-                                        current_comma_part = comma_part + ','
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                            else:
-                                # If no comma, split by words
-                                batches.extend(split_by_words(part))
-                else:
-                    # If no colon, split by comma
-                    comma_parts = re.split('[,,]', sentence)
-                    if len(comma_parts) > 1:
-                        current_comma_part = ""
-                        for comma_part in comma_parts:
-                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                current_comma_part += comma_part + ','
-                            else:
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                                current_comma_part = comma_part + ','
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                    else:
-                        # If no comma, split by words
-                        batches.extend(split_by_words(sentence))
-            else:
-                current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
 @gpu_decorator
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
+def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
     elif exp_name == "E2-TTS":
@@ -269,8 +186,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-    final_wave = np.concatenate(generated_waves)
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     # Remove silence
     if remove_silence:
@@ -296,11 +249,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     return (target_sample_rate, final_wave), spectrogram_path
 
 @gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -308,7 +257,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_segs = silence.split_on_silence(
+            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
+        )
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
@@ -334,16 +285,25 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     else:
         gr.Info("Using custom reference text...")
 
-    # Split the input text into batches
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". "):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
-    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+
+    # Use the new chunk_text function to split gen_text
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
-    for i, gen_text in enumerate(gen_text_batches):
-        print(f'gen_text {i}', gen_text)
+    for i, batch_text in enumerate(gen_text_batches):
+        print(f'gen_text {i}', batch_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
+
 
 @gpu_decorator
 def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
@@ -448,12 +408,7 @@ with gr.Blocks() as app_tts:
     remove_silence = gr.Checkbox(
         label="Remove Silences",
        info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
-        value=True,
-    )
-    split_words_input = gr.Textbox(
-        label="Custom Split Words",
-        info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
-        lines=2,
+        value=False,
     )
     speed_slider = gr.Slider(
         label="Speed",
@@ -463,6 +418,14 @@ with gr.Blocks() as app_tts:
         step=0.1,
         info="Adjust the speed of the audio.",
     )
+    cross_fade_duration_slider = gr.Slider(
+        label="Cross-Fade Duration (s)",
+        minimum=0.0,
+        maximum=1.0,
+        value=0.15,
+        step=0.01,
+        info="Set the duration of the cross-fade between audio clips.",
+    )
     speed_slider.change(update_speed, inputs=speed_slider)
 
     audio_output = gr.Audio(label="Synthesized Audio")
@@ -476,7 +439,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             model_choice,
            remove_silence,
-            split_words_input,
+            cross_fade_duration_slider,
         ],
         outputs=[audio_output, spectrogram_output],
     )
@@ -724,7 +687,7 @@ with gr.Blocks() as app_emotional:
         ref_text = speech_types[current_emotion].get('ref_text', '')
 
         # Generate speech for this segment
-        audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
+        audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
         sr, audio_data = audio
 
         generated_audio_segments.append(audio_data)
@@ -825,4 +788,4 @@ def main(port, host, share, api):
 
 
 if __name__ == "__main__":
-    main()
+    main()
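Isolated from the app code, the cross-fade stitching added to `infer_batch()` above amounts to the following NumPy operation. This is a standalone sketch, not the commit's code: the chunk contents are random stand-ins, and the 24 kHz rate is assumed to match the `vocos-mel-24khz` vocoder used by the repo.

```python
import numpy as np

# Standalone illustration of the cross-fade stitching added to infer_batch().
target_sample_rate = 24000          # assumed vocoder rate (vocos-mel-24khz)
cross_fade_duration = 0.15          # default of the new slider / keyword argument

prev_wave = np.random.randn(2 * target_sample_rate).astype(np.float32)  # fake chunk 1
next_wave = np.random.randn(3 * target_sample_rate).astype(np.float32)  # fake chunk 2

n = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
fade_out = np.linspace(1.0, 0.0, n)  # ramp the end of the previous chunk down
fade_in = np.linspace(0.0, 1.0, n)   # ramp the start of the next chunk up
overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
stitched = np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])

print(stitched.shape)  # 5 s of audio minus the 0.15 s shared overlap
```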
inference-cli.py CHANGED
@@ -1,26 +1,24 @@
+import argparse
+import codecs
 import re
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import tomli
 import torch
 import torchaudio
-import numpy as np
-import tempfile
+import tqdm
+from cached_path import cached_path
 from einops import rearrange
-from vocos import Vocos
 from pydub import AudioSegment, silence
-from model import CFM, UNetT, DiT, MMDiT
-from cached_path import cached_path
-from model.utils import (
-    load_checkpoint,
-    get_tokenizer,
-    convert_char_to_pinyin,
-    save_spectrogram,
-)
 from transformers import pipeline
-import soundfile as sf
-import tomli
-import argparse
-import tqdm
-from pathlib import Path
-import codecs
+from vocos import Vocos
+
+from model import CFM, DiT, MMDiT, UNetT
+from model.utils import (convert_char_to_pinyin, get_tokenizer,
+                         load_checkpoint, save_spectrogram)
 
 parser = argparse.ArgumentParser(
     prog="python3 inference-cli.py",
@@ -73,6 +71,11 @@ parser.add_argument(
     "--remove_silence",
     help="Remove silence.",
 )
+parser.add_argument(
+    "--load_vocoder_from_local",
+    action="store_true",
+    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
@@ -88,24 +91,23 @@ model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
+vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
 device = (
     "cuda"
     if torch.cuda.is_available()
     else "mps" if torch.backends.mps.is_available() else "cpu"
 )
-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+if args.load_vocoder_from_local:
+    print(f"Load vocos from local path {vocos_local_path}")
+    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    vocos.load_state_dict(state_dict)
+    vocos.eval()
+else:
+    print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
+    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 print(f"Using {device} device")
 
@@ -124,8 +126,9 @@ speed = 1.0
 fix_duration = None
 
 def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
+    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
+    if not Path(ckpt_path).exists():
+        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -153,103 +156,36 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
-    if len(text.encode('utf-8')) <= max_chars:
-        return [text]
-    if text[-1] not in ['。', '.', '!', '!', '?', '?']:
-        text += '.'
-
-    sentences = re.split('([。.!?!?])', text)
-    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
-
-    batches = []
-    current_batch = ""
-
-    def split_by_words(text):
-        words = text.split()
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
+
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
     for sentence in sentences:
-        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_batch += sentence
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            # If adding this sentence would exceed the limit
-            if current_batch:
-                batches.append(current_batch)
-                current_batch = ""
-
-            # If the sentence itself is longer than max_chars, split it
-            if len(sentence.encode('utf-8')) > max_chars:
-                # First, try to split by colon
-                colon_parts = sentence.split(':')
-                if len(colon_parts) > 1:
-                    for part in colon_parts:
-                        if len(part.encode('utf-8')) <= max_chars:
-                            batches.append(part)
-                        else:
-                            # If colon part is still too long, split by comma
-                            comma_parts = re.split('[,,]', part)
-                            if len(comma_parts) > 1:
-                                current_comma_part = ""
-                                for comma_part in comma_parts:
-                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                        current_comma_part += comma_part + ','
-                                    else:
-                                        if current_comma_part:
-                                            batches.append(current_comma_part.rstrip(','))
-                                        current_comma_part = comma_part + ','
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                            else:
-                                # If no comma, split by words
-                                batches.extend(split_by_words(part))
-                else:
-                    # If no colon, split by comma
-                    comma_parts = re.split('[,,]', sentence)
-                    if len(comma_parts) > 1:
-                        current_comma_part = ""
-                        for comma_part in comma_parts:
-                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                current_comma_part += comma_part + ','
-                            else:
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                                current_comma_part = comma_part + ','
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                    else:
-                        # If no comma, split by words
-                        batches.extend(split_by_words(sentence))
-            else:
-                current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
         ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
     elif model == "E2-TTS":
@@ -307,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-    final_wave = np.concatenate(generated_waves)
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     with open(wave_path, "wb") as f:
         sf.write(f.name, final_wave, target_sample_rate)
@@ -329,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
     print(spectrogram_path)
 
 
-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -341,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
             non_silent_wave += non_silent_seg
@@ -373,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     else:
         print("Using custom reference text...")
 
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
-    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
+infer(ref_audio, ref_text, gen_text, model, remove_silence)
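Pulled out for clarity, the checkpoint resolution that `load_model()` now performs boils down to "prefer a local `ckpts/` file, otherwise fetch the safetensors from the Hugging Face Hub". A minimal sketch; `resolve_ckpt` is a hypothetical helper name, while the paths and the `cached_path("hf://...")` call are the ones appearing in the diff:

```python
from pathlib import Path

from cached_path import cached_path


def resolve_ckpt(repo_name: str, exp_name: str, ckpt_step: int) -> str:
    """Sketch of the fallback now inside load_model(): local .pt first, then the HF Hub."""
    local = Path(f"ckpts/{exp_name}/model_{ckpt_step}.pt")
    if local.exists():
        return str(local)
    # cached_path downloads and caches the safetensors checkpoint from the Hub.
    return str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))


# Hypothetical call matching the CLI's F5-TTS branch:
# resolve_ckpt("F5-TTS", "F5TTS_Base", 1200000)
```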
inference-cli.toml CHANGED
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
 gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
 # File with text to generate. Ignores the text above.
 gen_file = ""
-remove_silence = true
+remove_silence = false
 output_dir = "tests"
model/utils.py CHANGED
@@ -22,12 +22,6 @@ from einops import rearrange, reduce
 
 import jieba
 from pypinyin import lazy_pinyin, Style
-import zhconv
-from zhon.hanzi import punctuation
-from jiwer import compute_measures
-
-from funasr import AutoModel
-from faster_whisper import WhisperModel
 
 from model.ecapa_tdnn import ECAPA_TDNN_SMALL
 from model.modules import MelSpec
@@ -432,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
 
 def load_asr_model(lang, ckpt_dir = ""):
     if lang == "zh":
+        from funasr import AutoModel
         model = AutoModel(
             model = os.path.join(ckpt_dir, "paraformer-zh"),
             # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
@@ -440,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
             disable_update=True,
         ) # following seed-tts setting
     elif lang == "en":
+        from faster_whisper import WhisperModel
         model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
         model = WhisperModel(model_size, device="cuda", compute_type="float16")
     return model
@@ -451,6 +447,7 @@ def run_asr_wer(args):
     rank, lang, test_set, ckpt_dir = args
 
     if lang == "zh":
+        import zhconv
         torch.cuda.set_device(rank)
     elif lang == "en":
         os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
@@ -458,10 +455,12 @@ def run_asr_wer(args):
         raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
     asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
-
+
+    from zhon.hanzi import punctuation
     punctuation_all = punctuation + string.punctuation
     wers = []
 
+    from jiwer import compute_measures
     for gen_wav, prompt_wav, truth in tqdm(test_set):
         if lang == "zh":
             res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
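The change above moves the evaluation-only dependencies (now listed in `requirements_eval.txt`) from module scope into the branches that actually use them, so the base install no longer needs them. A generic sketch of that lazy-import pattern, with illustrative names and arguments rather than the repo's own API:

```python
def transcribe(lang: str, audio_path: str):
    """Illustrative lazy-import pattern: heavy eval deps load only on first use."""
    if lang == "zh":
        from funasr import AutoModel  # pulled in only when the zh branch runs
        model = AutoModel(model="paraformer-zh", disable_update=True)
        return model.generate(input=audio_path, batch_size_s=300, disable_pbar=True)
    elif lang == "en":
        from faster_whisper import WhisperModel  # pulled in only for the en branch
        model = WhisperModel("large-v3", device="cuda", compute_type="float16")
        segments, _ = model.transcribe(audio_path)
        return list(segments)
    raise NotImplementedError("lang must be 'zh' or 'en'")
```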
requirements.txt CHANGED
@@ -5,25 +5,19 @@ datasets
 einops>=0.8.0
 einx>=0.3.0
 ema_pytorch>=0.5.2
-faster_whisper
-funasr
 gradio
 jieba
-jiwer
 librosa
 matplotlib
-numpy==1.23.5
+numpy<=1.26.4
 pydub
 pypinyin
 safetensors
 soundfile
-# torch>=2.0
-# torchaudio>=2.3.0
+tomli
 torchdiffeq
 tqdm>=4.65.0
 transformers
 vocos
 wandb
 x_transformers>=1.31.14
-zhconv
-zhon
requirements_eval.txt ADDED
@@ -0,0 +1,5 @@
+faster_whisper
+funasr
+jiwer
+zhconv
+zhon