SWivid committed
Commit 1cec6dd
Parent: b648e8b
Files changed (3):
  1. README.md (+1 -1)
  2. gradio_app.py (+4 -6)
  3. inference-cli.py (+7 -8)
README.md CHANGED
@@ -87,7 +87,7 @@ python inference-cli.py \
 --model "E2-TTS" \
 --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
 --ref_text "对,这就是我,万人敬仰的太乙真人。" \
---gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
+--gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
 ```

 ### Gradio App
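Note: the new `--gen_text` example drops the escaped inner quotes (`\"...\"`). Presumably the nested double quotes inside an already double-quoted shell argument were too easy to mangle when copy-pasting across shells, so the quoted speech is now set off with commas instead.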
gradio_app.py CHANGED
@@ -201,7 +201,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     elif exp_name == "E2-TTS":
         ema_model = E2TTS_ema_model
 
-    audio, sr = torchaudio.load(ref_audio)
+    audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
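`infer_batch` now unpacks a pre-loaded `(waveform, sample_rate)` tuple instead of loading from a path, so the reference clip is read from disk once by the caller rather than on every call. A minimal sketch of the new calling convention, using the repo's test clip (the other variables are assumed to be set up as in `gradio_app.py`):

```python
import torchaudio

# Load the reference audio once in the caller...
audio, sr = torchaudio.load("tests/ref_audio/test_zh_1_ref_short.wav")

# ...and hand the pre-loaded tensor to infer_batch, which no longer calls
# torchaudio.load itself (ref_text etc. assumed defined as in the app):
# result = infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
```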
 
@@ -320,17 +320,15 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
         gr.Info("Using custom reference text...")
 
     # Split the input text into batches
-    if len(ref_text.encode('utf-8')) == len(ref_text) and len(gen_text.encode('utf-8')) == len(gen_text):
-        max_chars = 400-len(ref_text.encode('utf-8'))
-    else:
-        max_chars = 300-len(ref_text.encode('utf-8'))
+    audio, sr = torchaudio.load(ref_audio)
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / 24000) * (30 - audio.shape[-1] / 24000))
     gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
 
 def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
     # Split the script into speaker blocks
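The fixed 300/400-byte budgets are replaced by a duration-aware estimate: measure the reference's speech rate in UTF-8 bytes per second, then multiply by however many seconds remain before the clip hits what appears to be a 30-second model window at 24 kHz. Note the formula divides by a hard-coded 24000 rather than the loaded `sr`, so it assumes the reference is already at the model's sample rate. A minimal sketch with made-up numbers:

```python
# Duration-aware character budget from the diff; the reference length and
# transcript below are hypothetical example values.
def estimate_max_chars(ref_text: str, num_samples: int, sr: int = 24000,
                       max_clip_sec: float = 30.0) -> int:
    ref_sec = num_samples / sr                                 # reference duration (s)
    bytes_per_sec = len(ref_text.encode("utf-8")) / ref_sec    # speech-rate proxy
    return int(bytes_per_sec * (max_clip_sec - ref_sec))       # budget for gen_text

# A 5 s reference with a 40-byte transcript speaks ~8 bytes/s,
# leaving 25 s of window: 8 * 25 = 200 characters per batch.
print(estimate_max_chars("a" * 40, num_samples=5 * 24000))     # -> 200
```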
 
inference-cli.py CHANGED
@@ -47,6 +47,7 @@ parser.add_argument(
     "-s",
     "--ref_text",
     type=str,
+    default="666",
     help="Subtitle for the reference audio."
 )
 parser.add_argument(
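The `"666"` default is a sentinel for "flag not given". With the old truthiness test, an explicitly passed empty `--ref_text` (which the script appears to treat as a request to auto-transcribe the reference audio) would also fall back to the config value; comparing against a string no one would plausibly pass keeps the empty string usable. A standalone sketch of the pattern (the config dict is a stand-in for the TOML file):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--ref_text", type=str, default="666",
                    help="Subtitle for the reference audio.")

config = {"ref_text": "reference text from config.toml"}  # hypothetical fallback

# No flag on the command line -> sentinel -> fall back to the config:
args = parser.parse_args([])
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
print(ref_text)        # -> "reference text from config.toml"

# An explicit empty string now survives, unlike with a plain truthiness test:
args = parser.parse_args(["--ref_text", ""])
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
print(repr(ref_text))  # -> ''
```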
@@ -70,7 +71,7 @@ args = parser.parse_args()
 config = tomli.load(open(args.config, "rb"))
 
 ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
-ref_text = args.ref_text if args.ref_text else config["ref_text"]
+ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
 exp_name = args.model if args.model else config["model"]
@@ -243,7 +244,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
     elif exp_name == "E2-TTS":
         ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
-    audio, sr = torchaudio.load(ref_audio)
+    audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -364,17 +365,15 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
         print("Using custom reference text...")
 
     # Split the input text into batches
-    if len(ref_text.encode('utf-8')) == len(ref_text) and len(gen_text.encode('utf-8')) == len(gen_text):
-        max_chars = 400-len(ref_text.encode('utf-8'))
-    else:
-        max_chars = 300-len(ref_text.encode('utf-8'))
+    audio, sr = torchaudio.load(ref_audio)
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / 24000) * (30 - audio.shape[-1] / 24000))
     gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
-    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
+    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
 
 
 infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
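This is the same `max_chars` rework as in `gradio_app.py` (see the sketch there). The progress message additionally gains "loading models...", which fits the CLI: here `infer_batch` loads the model checkpoint itself via `load_model(...)`, rather than using a pre-loaded global as the Gradio app does.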
 