SWivid committed
Commit 615d183 · Parent(s): 5622219

add code-switch friendly synth. and a smoother silence remover

gradio_app.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import tempfile
 from einops import rearrange
 from vocos import Vocos
-from pydub import AudioSegment
+from pydub import AudioSegment, silence
 from model import CFM, UNetT, DiT, MMDiT
 from cached_path import cached_path
 from model.utils import (
@@ -111,7 +111,7 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
         current_word_part = ""
         word_batches = []
         for word in words:
-            if len(current_word_part) + len(word) + 1 <= max_chars:
+            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
                 current_word_part += word + ' '
             else:
                 if current_word_part:
@@ -132,7 +132,7 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
         return word_batches
 
     for sentence in sentences:
-        if len(current_batch) + len(sentence) <= max_chars:
+        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
             current_batch += sentence
         else:
             # If adding this sentence would exceed the limit
@@ -141,20 +141,20 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
                 current_batch = ""
 
             # If the sentence itself is longer than max_chars, split it
-            if len(sentence) > max_chars:
+            if len(sentence.encode('utf-8')) > max_chars:
                 # First, try to split by colon
                 colon_parts = sentence.split(':')
                 if len(colon_parts) > 1:
                     for part in colon_parts:
-                        if len(part) <= max_chars:
+                        if len(part.encode('utf-8')) <= max_chars:
                             batches.append(part)
                         else:
                             # If colon part is still too long, split by comma
-                            comma_parts = part.split(',')
+                            comma_parts = re.split('[,，]', part)
                             if len(comma_parts) > 1:
                                 current_comma_part = ""
                                 for comma_part in comma_parts:
-                                    if len(current_comma_part) + len(comma_part) <= max_chars:
+                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                                         current_comma_part += comma_part + ','
                                     else:
                                         if current_comma_part:
@@ -167,11 +167,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
                                 batches.extend(split_by_words(part))
             else:
                 # If no colon, split by comma
-                comma_parts = sentence.split(',')
+                comma_parts = re.split('[,，]', sentence)
                 if len(comma_parts) > 1:
                     current_comma_part = ""
                     for comma_part in comma_parts:
-                        if len(current_comma_part) + len(comma_part) <= max_chars:
+                        if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                             current_comma_part += comma_part + ','
                         else:
                             if current_comma_part:
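
Note on the byte-length batching above: len(text) counts characters, which undercounts the spoken length of CJK text, while len(text.encode('utf-8')) charges each CJK character 3 bytes, so mixed-language batches fill the budget roughly in proportion to audio duration; re.split('[,，]', ...) additionally treats the fullwidth Chinese comma as a clause boundary alongside the ASCII one. A minimal illustration (plain Python, not part of the diff):

    import re

    # One CJK character is 1 char but 3 UTF-8 bytes, so byte length tracks
    # speaking time across languages better than character length does.
    for s in ["Hello, world", "你好，世界"]:
        print(s, len(s), len(s.encode('utf-8')))
    # "Hello, world" -> 12 chars, 12 bytes
    # "你好，世界"      -> 5 chars, 15 bytes

    # Splitting on both the ASCII ',' and the fullwidth '，' keeps
    # code-switched sentences segmentable at clause boundaries.
    print(re.split('[,，]', "Hello, 你好，world"))
    # -> ['Hello', ' 你好', 'world']
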
@@ -219,8 +219,8 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     # Calculate duration
     ref_audio_len = audio.shape[-1] // hop_length
     zh_pause_punc = r"。，、；：？！"
-    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
-    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+    ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
+    gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
     duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
     # inference
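
Note on the duration estimate above: a fullwidth pause mark already contributes 3 bytes to the UTF-8 length, so the extra 3 * term makes each pause weigh 6 bytes, reserving additional mel frames for the silence it implies. A worked sketch (the text_len helper and the 500-frame reference are ours, and the punctuation set is written as an explicit character class here so each mark is counted individually):

    import re

    zh_pause_punc = r"[。，、；：？！]"

    def text_len(text):
        # UTF-8 bytes plus 3 extra per Chinese pause mark.
        return len(text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, text))

    ref_text_len = text_len("这是参考文本。")                  # 7*3 bytes + 3 = 24
    gen_text_len = text_len("Generate this, 然后生成这个。")   # 36 bytes + 3 = 39

    ref_audio_len, speed = 500, 1.0  # hypothetical reference length in mel frames
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
    # 500 + int(500 / 24 * 39) = 500 + 812 = 1312 frames total
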
@@ -244,23 +244,27 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
 
         # wav -> numpy
         generated_wave = generated_wave.squeeze().cpu().numpy()
-
-        if remove_silence:
-            non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
-            non_silent_wave = np.array([])
-            for interval in non_silent_intervals:
-                start, end = interval
-                non_silent_wave = np.concatenate(
-                    [non_silent_wave, generated_wave[start:end]]
-                )
-            generated_wave = non_silent_wave
-
+
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
     # Combine all generated waves
     final_wave = np.concatenate(generated_waves)
 
+    # Remove silence
+    if remove_silence:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+            sf.write(f.name, final_wave, target_sample_rate)
+            aseg = AudioSegment.from_file(f.name)
+            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+            non_silent_wave = AudioSegment.silent(duration=0)
+            for non_silent_seg in non_silent_segs:
+                non_silent_wave += non_silent_seg
+            aseg = non_silent_wave
+            aseg.export(f.name, format="wav")
+            final_wave, _ = torchaudio.load(f.name)
+        final_wave = final_wave.squeeze().cpu().numpy()
+
     # Create a combined spectrogram
     combined_spectrogram = np.concatenate(spectrograms, axis=1)
 
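Note on the smoother silence remover: the old librosa pass hard-cut every gap quieter than 30 dB below peak, batch by batch, which could leave the speech sounding choppy; the new pydub pass runs once over the concatenated output and only collapses pauses of at least 1 s below -50 dBFS, keeping 500 ms of padding around each cut. The app round-trips through a temporary WAV (hence the new soundfile dependency) because pydub operates on AudioSegment objects rather than numpy arrays. A standalone sketch of the same pydub pattern, with hypothetical file names:

    from pydub import AudioSegment, silence

    aseg = AudioSegment.from_file("generated.wav")  # hypothetical input path
    non_silent_segs = silence.split_on_silence(
        aseg,
        min_silence_len=1000,  # only cut pauses lasting 1 s or longer
        silence_thresh=-50,    # quieter than -50 dBFS counts as silence
        keep_silence=500,      # keep 500 ms of padding around each cut
    )
    # AudioSegment supports '+' for concatenation, so the kept chunks fold
    # back together starting from a zero-length silent segment.
    trimmed = sum(non_silent_segs, AudioSegment.silent(duration=0))
    trimmed.export("generated_trimmed.wav", format="wav")
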
@@ -270,11 +274,24 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
 
     return (target_sample_rate, final_wave), spectrogram_path
 
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
+    if not custom_split_words.strip():
+        custom_words = [word.strip() for word in custom_split_words.split(',')]
+        global SPLIT_WORDS
+        SPLIT_WORDS = custom_words
+
     print(gen_text)
+
     gr.Info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
+
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_wave = AudioSegment.silent(duration=0)
+        for non_silent_seg in non_silent_segs:
+            non_silent_wave += non_silent_seg
+        aseg = non_silent_wave
+
         audio_duration = len(aseg)
         if audio_duration > 15000:
             gr.Warning("Audio is over 15s, clipping to only first 15s.")
@@ -296,7 +313,14 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
         gr.Info("Using custom reference text...")
 
     # Split the input text into batches
-    gen_text_batches = split_text_into_batches(gen_text)
+    if len(ref_text.encode('utf-8')) == len(ref_text):
+        max_chars = 400-len(ref_text.encode('utf-8'))
+    else:
+        max_chars = 300-len(ref_text.encode('utf-8'))
+    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+    print('ref_text', ref_text)
+    for i, gen_text in enumerate(gen_text_batches):
+        print(f'gen_text {i}', gen_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
     return infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
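
Note on the batch budget: len(ref_text.encode('utf-8')) == len(ref_text) holds only for pure-ASCII reference text, so an all-English prompt gets a 400-byte budget and anything containing multibyte characters gets 300, with the reference's own bytes subtracted so that reference plus batch stays under the cap. A quick check with sample strings:

    ref_en = "Some call me nature, others call me mother nature."
    assert len(ref_en.encode('utf-8')) == len(ref_en)  # pure ASCII: 50 == 50
    max_chars = 400 - len(ref_en.encode('utf-8'))      # 350-byte batch budget

    ref_zh = "对，这就是我，万人敬仰的太乙真人。"
    assert len(ref_zh.encode('utf-8')) != len(ref_zh)  # 51 bytes vs. 17 chars
    max_chars = 300 - len(ref_zh.encode('utf-8'))      # 249-byte batch budget
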
@@ -393,15 +417,8 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     audio_output = gr.Audio(label="Synthesized Audio")
     spectrogram_output = gr.Image(label="Spectrogram")
 
-    def infer_with_custom_split(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
-        if custom_split_words:
-            custom_words = [word.strip() for word in custom_split_words.split(',')]
-            global SPLIT_WORDS
-            SPLIT_WORDS = custom_words
-        return infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence)
-
     generate_btn.click(
-        infer_with_custom_split,
+        infer,
         inputs=[
             ref_audio_input,
             ref_text_input,
@@ -412,6 +429,14 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
         ],
         outputs=[audio_output, spectrogram_output],
     )
+
+    gr.Markdown(
+        """
+        # Podcast Generation
+
+        Supported by [RootingInLoad](https://github.com/RootingInLoad)
+        """
+    )
     with gr.Tab("Podcast Generation"):
         speaker1_name = gr.Textbox(label="Speaker 1 Name")
         ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
 
model/utils.py CHANGED
@@ -294,8 +294,8 @@ def get_inference_prompt(
             # ref_audio = gt_audio
         else:
             zh_pause_punc = r"。，、；：？！"
-            ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
-            gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
+            ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
+            gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
 
             # to mel spectrogram
requirements_gradio.txt CHANGED
@@ -1,4 +1,5 @@
 cached_path
 click
 gradio
-pydub
+pydub
+soundfile
test_infer_single.py CHANGED
@@ -130,8 +130,8 @@ if fix_duration is not None:
     duration = int(fix_duration * target_sample_rate / hop_length)
 else:  # simple linear scale calcul
     zh_pause_punc = r"。，、；：？！"
-    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
-    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+    ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
+    gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
     duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
     # Inference