SWivid committed
Commit eb1a2c4 · Parents: c2e2e16, 1c556bc

Merge branch 'main' of github.com:SWivid/F5-TTS into main

Files changed (1)
  1. gradio_app.py +94 -130
gradio_app.py CHANGED
@@ -31,19 +31,6 @@ def gpu_decorator(func):
     else:
         return func
 
-
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -71,7 +58,6 @@ cfg_strength = 2.0
 ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
-# fix_duration = 27  # None or float (duration in seconds)
 fix_duration = None
 
 
@@ -112,104 +98,37 @@ E2TTS_ema_model = load_model(
     "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
-def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
-    if len(text.encode('utf-8')) <= max_chars:
-        return [text]
-    if text[-1] not in ['。', '.', '!', '！', '?', '？']:
-        text += '.'
-
-    sentences = re.split('([。.!?！？])', text)
-    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
-
-    batches = []
-    current_batch = ""
-
-    def split_by_words(text):
-        words = text.split()
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+', text)
 
     for sentence in sentences:
-        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_batch += sentence
+        if len(current_chunk) + len(sentence) <= max_chars:
+            current_chunk += sentence + " "
         else:
-            # If adding this sentence would exceed the limit
-            if current_batch:
-                batches.append(current_batch)
-                current_batch = ""
-
-            # If the sentence itself is longer than max_chars, split it
-            if len(sentence.encode('utf-8')) > max_chars:
-                # First, try to split by colon
-                colon_parts = sentence.split(':')
-                if len(colon_parts) > 1:
-                    for part in colon_parts:
-                        if len(part.encode('utf-8')) <= max_chars:
-                            batches.append(part)
-                        else:
-                            # If colon part is still too long, split by comma
-                            comma_parts = re.split('[,，]', part)
-                            if len(comma_parts) > 1:
-                                current_comma_part = ""
-                                for comma_part in comma_parts:
-                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                        current_comma_part += comma_part + ','
-                                    else:
-                                        if current_comma_part:
-                                            batches.append(current_comma_part.rstrip(','))
-                                        current_comma_part = comma_part + ','
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                            else:
-                                # If no comma, split by words
-                                batches.extend(split_by_words(part))
-                else:
-                    # If no colon, split by comma
-                    comma_parts = re.split('[,，]', sentence)
-                    if len(comma_parts) > 1:
-                        current_comma_part = ""
-                        for comma_part in comma_parts:
-                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                current_comma_part += comma_part + ','
-                            else:
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                                current_comma_part = comma_part + ','
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                    else:
-                        # If no comma, split by words
-                        batches.extend(split_by_words(sentence))
-            else:
-                current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " "
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
 @gpu_decorator
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
+def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
     elif exp_name == "E2-TTS":
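For readers skimming the hunk above, here is a minimal, self-contained sketch of how the new chunk_text splitter behaves. The sample sentence and the 40-character limit are made up for illustration; only the splitting logic is taken from the diff.

import re

def chunk_text(text, max_chars=135):
    """Split text into chunks of at most max_chars, breaking on punctuation."""
    chunks, current_chunk = [], ""
    # Split into sentences on punctuation followed by whitespace
    sentences = re.split(r'(?<=[;:,.!?])\s+', text)
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_chars:
            current_chunk += sentence + " "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks

# Illustrative only: a short limit makes the chunking visible.
sample = "First sentence here. Second one follows, with a clause. Third sentence ends it."
for i, chunk in enumerate(chunk_text(sample, max_chars=40)):
    print(i, repr(chunk))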
@@ -267,8 +186,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-    final_wave = np.concatenate(generated_waves)
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     # Remove silence
     if remove_silence:
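As a sanity check on the cross-fade logic added above, the following self-contained NumPy sketch applies the same linear fade-out/fade-in overlap to two synthetic tones. The 24 kHz sample rate, tone frequencies, and durations are arbitrary illustration values, not settings from the app.

import numpy as np

def cross_fade(prev_wave, next_wave, cross_fade_duration, sample_rate):
    """Join two 1-D waveforms with a linear cross-fade, mirroring the diff above."""
    cross_fade_samples = int(cross_fade_duration * sample_rate)
    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
    if cross_fade_samples <= 0:
        return np.concatenate([prev_wave, next_wave])  # no overlap possible
    fade_out = np.linspace(1, 0, cross_fade_samples)
    fade_in = np.linspace(0, 1, cross_fade_samples)
    overlap = prev_wave[-cross_fade_samples:] * fade_out + next_wave[:cross_fade_samples] * fade_in
    return np.concatenate([prev_wave[:-cross_fade_samples], overlap, next_wave[cross_fade_samples:]])

# Illustrative only: two half-second tones at an assumed 24 kHz sample rate.
sr = 24000
t = np.linspace(0, 0.5, sr // 2, endpoint=False)
a = np.sin(2 * np.pi * 220 * t)
b = np.sin(2 * np.pi * 330 * t)
joined = cross_fade(a, b, cross_fade_duration=0.15, sample_rate=sr)
print(joined.shape)  # 0.15 s of overlap is removed from the total length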
@@ -294,11 +249,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     return (target_sample_rate, final_wave), spectrogram_path
 
 @gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -306,7 +257,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_segs = silence.split_on_silence(
+            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
+        )
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
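The reflowed call above uses pydub's silence utilities. Below is a small standalone sketch of the same clip-and-rejoin step, assuming a local file named ref.wav exists (the filename is hypothetical).

from pydub import AudioSegment, silence

aseg = AudioSegment.from_file("ref.wav")  # hypothetical input file
# Drop long pauses, keeping 500 ms of padding around each non-silent chunk
non_silent_segs = silence.split_on_silence(
    aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
)
non_silent_wave = AudioSegment.silent(duration=0)
for seg in non_silent_segs:
    non_silent_wave += seg
print(len(aseg), "->", len(non_silent_wave), "ms")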
@@ -332,16 +285,24 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     else:
         gr.Info("Using custom reference text...")
 
-    # Split the input text into batches
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". "):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
-    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+
+    # Use the new chunk_text function to split gen_text
+    gen_text_batches = chunk_text(gen_text, max_chars=135)
     print('ref_text', ref_text)
-    for i, gen_text in enumerate(gen_text_batches):
-        print(f'gen_text {i}', gen_text)
+    for i, batch_text in enumerate(gen_text_batches):
+        print(f'gen_text {i}', batch_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
+
 
 @gpu_decorator
 def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
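A quick illustration of the reference-text normalization introduced in the hunk above; the helper name and sample strings are hypothetical, and only the endswith logic mirrors the diff.

def ensure_trailing_period_space(ref_text):
    # Mirrors the diff: make sure the reference text ends with ". "
    if not ref_text.endswith(". "):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "
    return ref_text

for sample in ["Some reference text", "Some reference text.", "Some reference text. "]:
    print(repr(ensure_trailing_period_space(sample)))  # all print 'Some reference text. '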
@@ -446,12 +407,7 @@ with gr.Blocks() as app_tts:
        remove_silence = gr.Checkbox(
            label="Remove Silences",
            info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
-           value=True,
-       )
-       split_words_input = gr.Textbox(
-           label="Custom Split Words",
-           info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
-           lines=2,
+           value=False,
        )
        speed_slider = gr.Slider(
            label="Speed",
@@ -461,6 +417,14 @@ with gr.Blocks() as app_tts:
            step=0.1,
            info="Adjust the speed of the audio.",
        )
+       cross_fade_duration_slider = gr.Slider(
+           label="Cross-Fade Duration (s)",
+           minimum=0.0,
+           maximum=1.0,
+           value=0.15,
+           step=0.01,
+           info="Set the duration of the cross-fade between audio clips.",
+       )
        speed_slider.change(update_speed, inputs=speed_slider)
 
        audio_output = gr.Audio(label="Synthesized Audio")
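For context on how the new slider reaches the inference call, here is a stripped-down Gradio sketch (hypothetical component names, not the app's real layout) that passes a slider value into a click handler the same way cross_fade_duration_slider is wired into the click inputs in the hunk below.

import gradio as gr

def fake_generate(text, cross_fade_duration):
    # Stand-in for infer(); just echoes the inputs.
    return f"Would synthesize {len(text)} chars with a {cross_fade_duration:.2f}s cross-fade"

with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Text to Generate")
    cross_fade_slider = gr.Slider(
        label="Cross-Fade Duration (s)", minimum=0.0, maximum=1.0, value=0.15, step=0.01
    )
    generate_btn = gr.Button("Synthesize")
    result = gr.Textbox(label="Result")
    generate_btn.click(fake_generate, inputs=[text_input, cross_fade_slider], outputs=result)

# demo.launch()  # uncomment to run locally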
@@ -474,7 +438,7 @@ with gr.Blocks() as app_tts:
            gen_text_input,
            model_choice,
            remove_silence,
-           split_words_input,
+           cross_fade_duration_slider,
        ],
        outputs=[audio_output, spectrogram_output],
    )
@@ -722,7 +686,7 @@ with gr.Blocks() as app_emotional:
            ref_text = speech_types[current_emotion].get('ref_text', '')
 
            # Generate speech for this segment
-           audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
+           audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
            sr, audio_data = audio
 
            generated_audio_segments.append(audio_data)
@@ -823,4 +787,4 @@ def main(port, host, share, api):
 
 
 if __name__ == "__main__":
-    main()
+    main()