SWivid committed
Commit 656197c · 1 Parent(s): 8a0c9c4

merging into one infer_batch_process function

src/f5_tts/infer/README.md CHANGED
@@ -144,7 +144,14 @@ python src/f5_tts/socket_server.py
 <details>
 <summary>Then create client to communicate</summary>
 
+```bash
+# If PyAudio not installed
+sudo apt-get install portaudio19-dev
+pip install pyaudio
+```
+
 ``` python
+# Create the socket_client.py
 import socket
 import asyncio
 import pyaudio
@@ -165,7 +172,6 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
     async def play_audio_stream():
         nonlocal first_chunk_time
-        buffer = b""
         p = pyaudio.PyAudio()
         stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
 
@@ -204,7 +210,7 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
 
 if __name__ == "__main__":
-    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
+    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
 
     asyncio.run(listen_to_F5TTS(text_to_send))
 ```
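
Not part of the commit: before wiring up the socket client, it can help to confirm PyAudio can open an output device at all. The sketch below only reuses the stream parameters already shown in the client snippet above (float32, mono, 24 kHz, 2048-frame buffer); the silence-playback check itself is an added suggestion, not something from the diff.

```python
# Standalone sanity check (assumes PyAudio installed per the bash snippet above).
# Opens an output stream with the same settings the socket client uses and plays
# one buffer of silence to verify local playback before involving the server.
import numpy as np
import pyaudio

p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
stream.write(np.zeros(2048, dtype=np.float32).tobytes())  # 2048 samples of silence
stream.stop_stream()
stream.close()
p.terminate()
```
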
src/f5_tts/infer/utils_infer.py CHANGED
@@ -390,22 +390,24 @@ def infer_process(
     print("\n")
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process(
-        (audio, sr),
-        ref_text,
-        gen_text_batches,
-        model_obj,
-        vocoder,
-        mel_spec_type=mel_spec_type,
-        progress=progress,
-        target_rms=target_rms,
-        cross_fade_duration=cross_fade_duration,
-        nfe_step=nfe_step,
-        cfg_strength=cfg_strength,
-        sway_sampling_coef=sway_sampling_coef,
-        speed=speed,
-        fix_duration=fix_duration,
-        device=device,
+    return next(
+        infer_batch_process(
+            (audio, sr),
+            ref_text,
+            gen_text_batches,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=device,
+        )
     )
 
 
@@ -428,125 +430,6 @@ def infer_batch_process(
     speed=1,
     fix_duration=None,
     device=None,
-):
-    audio, sr = ref_audio
-    if audio.shape[0] > 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    rms = torch.sqrt(torch.mean(torch.square(audio)))
-    if rms < target_rms:
-        audio = audio * target_rms / rms
-    if sr != target_sample_rate:
-        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-        audio = resampler(audio)
-    audio = audio.to(device)
-
-    generated_waves = []
-    spectrograms = []
-
-    if len(ref_text[-1].encode("utf-8")) == 1:
-        ref_text = ref_text + " "
-    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
-        # Prepare the text
-        text_list = [ref_text + gen_text]
-        final_text_list = convert_char_to_pinyin(text_list)
-
-        ref_audio_len = audio.shape[-1] // hop_length
-        if fix_duration is not None:
-            duration = int(fix_duration * target_sample_rate / hop_length)
-        else:
-            # Calculate duration
-            ref_text_len = len(ref_text.encode("utf-8"))
-            gen_text_len = len(gen_text.encode("utf-8"))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-        # inference
-        with torch.inference_mode():
-            generated, _ = model_obj.sample(
-                cond=audio,
-                text=final_text_list,
-                duration=duration,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-            )
-
-            generated = generated.to(torch.float32)
-            generated = generated[:, ref_audio_len:, :]
-            generated_mel_spec = generated.permute(0, 2, 1)
-            if mel_spec_type == "vocos":
-                generated_wave = vocoder.decode(generated_mel_spec)
-            elif mel_spec_type == "bigvgan":
-                generated_wave = vocoder(generated_mel_spec)
-            if rms < target_rms:
-                generated_wave = generated_wave * rms / target_rms
-
-            # wav -> numpy
-            generated_wave = generated_wave.squeeze().cpu().numpy()
-
-            generated_waves.append(generated_wave)
-            spectrograms.append(generated_mel_spec[0].cpu().numpy())
-
-    # Combine all generated waves with cross-fading
-    if cross_fade_duration <= 0:
-        # Simply concatenate
-        final_wave = np.concatenate(generated_waves)
-    else:
-        final_wave = generated_waves[0]
-        for i in range(1, len(generated_waves)):
-            prev_wave = final_wave
-            next_wave = generated_waves[i]
-
-            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
-            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
-            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
-
-            if cross_fade_samples <= 0:
-                # No overlap possible, concatenate
-                final_wave = np.concatenate([prev_wave, next_wave])
-                continue
-
-            # Overlapping parts
-            prev_overlap = prev_wave[-cross_fade_samples:]
-            next_overlap = next_wave[:cross_fade_samples]
-
-            # Fade out and fade in
-            fade_out = np.linspace(1, 0, cross_fade_samples)
-            fade_in = np.linspace(0, 1, cross_fade_samples)
-
-            # Cross-faded overlap
-            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
-
-            # Combine
-            new_wave = np.concatenate(
-                [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
-            )
-
-            final_wave = new_wave
-
-    # Create a combined spectrogram
-    combined_spectrogram = np.concatenate(spectrograms, axis=1)
-
-    return final_wave, target_sample_rate, combined_spectrogram
-
-
-# infer batch process for stream mode
-def infer_batch_process_stream(
-    ref_audio,
-    ref_text,
-    gen_text_batches,
-    model_obj,
-    vocoder,
-    mel_spec_type="vocos",
-    progress=None,
-    target_rms=0.1,
-    cross_fade_duration=0.15,
-    nfe_step=32,
-    cfg_strength=2.0,
-    sway_sampling_coef=-1,
-    speed=1,
-    fix_duration=None,
-    device=None,
     streaming=False,
     chunk_size=2048,
 ):
@@ -562,19 +445,18 @@ def infer_batch_process_stream(
         audio = resampler(audio)
     audio = audio.to(device)
 
-    if len(ref_text[-1].encode("utf-8")) == 1:
-        ref_text = ref_text + " "
-
     generated_waves = []
     spectrograms = []
 
-    def process_batch(i, gen_text):
-        print(f"Generating audio for batch {i + 1}/{len(gen_text_batches)}: {gen_text}")
-
+    if len(ref_text[-1].encode("utf-8")) == 1:
+        ref_text = ref_text + " "
+
+    def process_batch(gen_text):
         local_speed = speed
-        if len(gen_text) < 10:
+        if len(gen_text.encode("utf-8")) < 10:
             local_speed = 0.3
 
+        # Prepare the text
         text_list = [ref_text + gen_text]
         final_text_list = convert_char_to_pinyin(text_list)
 
@@ -582,10 +464,12 @@ def infer_batch_process_stream(
         if fix_duration is not None:
             duration = int(fix_duration * target_sample_rate / hop_length)
         else:
+            # Calculate duration
             ref_text_len = len(ref_text.encode("utf-8"))
             gen_text_len = len(gen_text.encode("utf-8"))
             duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
 
+        # inference
         with torch.inference_mode():
             generated, _ = model_obj.sample(
                 cond=audio,
@@ -599,76 +483,79 @@ def infer_batch_process_stream(
             generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = generated.permute(0, 2, 1)
-
-            print(f"Generated mel spectrogram shape: {generated_mel_spec.shape}")
-
             if mel_spec_type == "vocos":
                 generated_wave = vocoder.decode(generated_mel_spec)
             elif mel_spec_type == "bigvgan":
                 generated_wave = vocoder(generated_mel_spec)
-
-            print(f"Generated wave shape before RMS adjustment: {generated_wave.shape}")
-
             if rms < target_rms:
                 generated_wave = generated_wave * rms / target_rms
 
-            print(f"Generated wave shape after RMS adjustment: {generated_wave.shape}")
-
+            # wav -> numpy
             generated_wave = generated_wave.squeeze().cpu().numpy()
 
             if streaming:
                 for j in range(0, len(generated_wave), chunk_size):
                     yield generated_wave[j : j + chunk_size], target_sample_rate
-
-            return generated_wave, generated_mel_spec[0].cpu().numpy()
+            else:
+                yield generated_wave, generated_mel_spec[0].cpu().numpy()
 
     if streaming:
-        for i, gen_text in enumerate(progress.tqdm(gen_text_batches) if progress else gen_text_batches):
-            for chunk in process_batch(i, gen_text):
+        for gen_text in progress.tqdm(gen_text_batches) if progress else gen_text_batches:
+            for chunk in process_batch(gen_text):
                 yield chunk
     else:
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
+            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
             for future in progress.tqdm(futures) if progress else futures:
                 result = future.result()
                 if result:
-                    generated_wave, generated_mel_spec = result
+                    generated_wave, generated_mel_spec = next(result)
                     generated_waves.append(generated_wave)
                     spectrograms.append(generated_mel_spec)
 
         if generated_waves:
             if cross_fade_duration <= 0:
+                # Simply concatenate
                 final_wave = np.concatenate(generated_waves)
             else:
+                # Combine all generated waves with cross-fading
                 final_wave = generated_waves[0]
                 for i in range(1, len(generated_waves)):
                     prev_wave = final_wave
                     next_wave = generated_waves[i]
 
+                    # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                     cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                     cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
 
                     if cross_fade_samples <= 0:
+                        # No overlap possible, concatenate
                         final_wave = np.concatenate([prev_wave, next_wave])
                         continue
 
+                    # Overlapping parts
                     prev_overlap = prev_wave[-cross_fade_samples:]
                     next_overlap = next_wave[:cross_fade_samples]
 
+                    # Fade out and fade in
                     fade_out = np.linspace(1, 0, cross_fade_samples)
                     fade_in = np.linspace(0, 1, cross_fade_samples)
 
+                    # Cross-faded overlap
                     cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
 
+                    # Combine
                     new_wave = np.concatenate(
                         [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                     )
 
                     final_wave = new_wave
 
+            # Create a combined spectrogram
             combined_spectrogram = np.concatenate(spectrograms, axis=1)
 
             yield final_wave, target_sample_rate, combined_spectrogram
+
         else:
             yield None, target_sample_rate, None
 
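
After this merge, `infer_batch_process` is a generator in both modes: with `streaming=False` it yields a single `(wave, sample_rate, spectrogram)` tuple, which is why `infer_process` now wraps the call in `next(...)`, and with `streaming=True` it yields `(chunk, sample_rate)` pairs. A minimal consumption sketch, assuming a model, vocoder, and preprocessed reference audio/text are already prepared (their setup is elided here):

```python
# Sketch only, not part of the diff: `model`, `vocoder`, `ref_audio`, `ref_sr`,
# `ref_text`, and `gen_text_batches` are placeholders assumed to be prepared
# beforehand (e.g. via load_model, load_vocoder, preprocess_ref_audio_text).
from f5_tts.infer.utils_infer import infer_batch_process

# Non-streaming: the generator yields exactly one (wave, sample_rate, spectrogram) tuple.
wave, sample_rate, spectrogram = next(
    infer_batch_process((ref_audio, ref_sr), ref_text, gen_text_batches, model, vocoder)
)

# Streaming: iterate over (chunk, sample_rate) pairs as they are produced.
for chunk, sample_rate in infer_batch_process(
    (ref_audio, ref_sr),
    ref_text,
    gen_text_batches,
    model,
    vocoder,
    streaming=True,
    chunk_size=2048,
):
    ...  # e.g. forward the float32 chunk to a socket or an audio device
```
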
src/f5_tts/socket_server.py CHANGED
@@ -1,20 +1,31 @@
+import argparse
+import gc
+import logging
+import numpy as np
+import queue
 import socket
 import struct
+import threading
+import traceback
+import wave
+from importlib.resources import files
+
 import torch
 import torchaudio
-import logging
-import wave
-import numpy as np
-import argparse
-import traceback
-import gc
-import threading
-import queue
-from nltk.tokenize import sent_tokenize
-from infer.utils_infer import preprocess_ref_audio_text, load_vocoder, load_model, infer_batch_process_stream
-from model.backbones.dit import DiT
 from huggingface_hub import hf_hub_download
-from importlib.resources import files
+
+import nltk
+from nltk.tokenize import sent_tokenize
+
+from f5_tts.model.backbones.dit import DiT
+from f5_tts.infer.utils_infer import (
+    preprocess_ref_audio_text,
+    load_vocoder,
+    load_model,
+    infer_batch_process,
+)
+
+nltk.download("punkt_tab")
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -103,12 +114,13 @@ class TTSStreamingProcessor:
     def _warm_up(self):
         logger.info("Warming up the model...")
         gen_text = "Warm-up text for the model."
-        for _ in infer_batch_process_stream(
+        for _ in infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            [gen_text],
            self.model,
            self.vocoder,
+            progress=None,
            device=self.device,
            streaming=True,
        ):
@@ -118,12 +130,13 @@ class TTSStreamingProcessor:
     def generate_stream(self, text, conn):
         text_batches = sent_tokenize(text)
 
-        audio_stream = infer_batch_process_stream(
+        audio_stream = infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            text_batches,
            self.model,
            self.vocoder,
+            progress=None,
            device=self.device,
            streaming=True,
            chunk_size=2048,
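
The server batches incoming text by sentence before calling `infer_batch_process`, which is why `nltk` and its `punkt_tab` tokenizer data are pulled in at module import. A small standalone sketch of just that batching step (the sample text is arbitrary, not from the repo):

```python
# Sketch of the sentence batching performed in generate_stream() above.
import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt_tab")  # one-time tokenizer data download, as in socket_server.py

text = "First sentence to synthesize. Second sentence to synthesize."
text_batches = sent_tokenize(text)  # each sentence becomes one generation batch
print(text_batches)  # ['First sentence to synthesize.', 'Second sentence to synthesize.']
```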