merging into one infer_batch_process function
- src/f5_tts/infer/README.md +8 -2
- src/f5_tts/infer/utils_infer.py +42 -155
- src/f5_tts/socket_server.py +27 -14
src/f5_tts/infer/README.md
CHANGED
@@ -144,7 +144,14 @@ python src/f5_tts/socket_server.py
 <details>
 <summary>Then create client to communicate</summary>
 
+```bash
+# If PyAudio not installed
+sudo apt-get install portaudio19-dev
+pip install pyaudio
+```
+
 ``` python
+# Create the socket_client.py
 import socket
 import asyncio
 import pyaudio
@@ -165,7 +172,6 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
     async def play_audio_stream():
         nonlocal first_chunk_time
-        buffer = b""
         p = pyaudio.PyAudio()
         stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
 
@@ -204,7 +210,7 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
 
 if __name__ == "__main__":
-    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
+    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
 
     asyncio.run(listen_to_F5TTS(text_to_send))
 ```
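For reference, the playback side of the client boils down to the loop sketched below: open a PyAudio output stream with the same parameters shown above (float32 mono at 24 kHz, 2048-frame buffers) and write raw sample bytes into it. This is a minimal sketch, not the full `socket_client.py`; the `chunks` iterable is a stand-in for whatever yields float32 numpy arrays received from the server.

```python
import numpy as np
import pyaudio


def play_float32_chunks(chunks, rate=24000):
    """Play an iterable of float32 numpy arrays through the default output device."""
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paFloat32, channels=1, rate=rate, output=True, frames_per_buffer=2048)
    try:
        for chunk in chunks:
            # PyAudio expects raw bytes; float32 samples map directly to paFloat32.
            stream.write(np.asarray(chunk, dtype=np.float32).tobytes())
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()
```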
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -390,22 +390,24 @@ def infer_process(
     print("\n")
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process(
-        (audio, sr),
-        ref_text,
-        gen_text_batches,
-        model_obj,
-        vocoder,
-        mel_spec_type=mel_spec_type,
-        progress=progress,
-        target_rms=target_rms,
-        cross_fade_duration=cross_fade_duration,
-        nfe_step=nfe_step,
-        cfg_strength=cfg_strength,
-        sway_sampling_coef=sway_sampling_coef,
-        speed=speed,
-        fix_duration=fix_duration,
-        device=device,
+    return next(
+        infer_batch_process(
+            (audio, sr),
+            ref_text,
+            gen_text_batches,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=device,
+        )
     )
 
 
@@ -428,125 +430,6 @@ def infer_batch_process(
     speed=1,
     fix_duration=None,
     device=None,
-):
-    audio, sr = ref_audio
-    if audio.shape[0] > 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    rms = torch.sqrt(torch.mean(torch.square(audio)))
-    if rms < target_rms:
-        audio = audio * target_rms / rms
-    if sr != target_sample_rate:
-        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-        audio = resampler(audio)
-    audio = audio.to(device)
-
-    generated_waves = []
-    spectrograms = []
-
-    if len(ref_text[-1].encode("utf-8")) == 1:
-        ref_text = ref_text + " "
-    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
-        # Prepare the text
-        text_list = [ref_text + gen_text]
-        final_text_list = convert_char_to_pinyin(text_list)
-
-        ref_audio_len = audio.shape[-1] // hop_length
-        if fix_duration is not None:
-            duration = int(fix_duration * target_sample_rate / hop_length)
-        else:
-            # Calculate duration
-            ref_text_len = len(ref_text.encode("utf-8"))
-            gen_text_len = len(gen_text.encode("utf-8"))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-        # inference
-        with torch.inference_mode():
-            generated, _ = model_obj.sample(
-                cond=audio,
-                text=final_text_list,
-                duration=duration,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-            )
-
-        generated = generated.to(torch.float32)
-        generated = generated[:, ref_audio_len:, :]
-        generated_mel_spec = generated.permute(0, 2, 1)
-        if mel_spec_type == "vocos":
-            generated_wave = vocoder.decode(generated_mel_spec)
-        elif mel_spec_type == "bigvgan":
-            generated_wave = vocoder(generated_mel_spec)
-        if rms < target_rms:
-            generated_wave = generated_wave * rms / target_rms
-
-        # wav -> numpy
-        generated_wave = generated_wave.squeeze().cpu().numpy()
-
-        generated_waves.append(generated_wave)
-        spectrograms.append(generated_mel_spec[0].cpu().numpy())
-
-    # Combine all generated waves with cross-fading
-    if cross_fade_duration <= 0:
-        # Simply concatenate
-        final_wave = np.concatenate(generated_waves)
-    else:
-        final_wave = generated_waves[0]
-        for i in range(1, len(generated_waves)):
-            prev_wave = final_wave
-            next_wave = generated_waves[i]
-
-            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
-            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
-            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
-
-            if cross_fade_samples <= 0:
-                # No overlap possible, concatenate
-                final_wave = np.concatenate([prev_wave, next_wave])
-                continue
-
-            # Overlapping parts
-            prev_overlap = prev_wave[-cross_fade_samples:]
-            next_overlap = next_wave[:cross_fade_samples]
-
-            # Fade out and fade in
-            fade_out = np.linspace(1, 0, cross_fade_samples)
-            fade_in = np.linspace(0, 1, cross_fade_samples)
-
-            # Cross-faded overlap
-            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
-
-            # Combine
-            new_wave = np.concatenate(
-                [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
-            )
-
-            final_wave = new_wave
-
-    # Create a combined spectrogram
-    combined_spectrogram = np.concatenate(spectrograms, axis=1)
-
-    return final_wave, target_sample_rate, combined_spectrogram
-
-
-# infer batch process for stream mode
-def infer_batch_process_stream(
-    ref_audio,
-    ref_text,
-    gen_text_batches,
-    model_obj,
-    vocoder,
-    mel_spec_type="vocos",
-    progress=None,
-    target_rms=0.1,
-    cross_fade_duration=0.15,
-    nfe_step=32,
-    cfg_strength=2.0,
-    sway_sampling_coef=-1,
-    speed=1,
-    fix_duration=None,
-    device=None,
     streaming=False,
     chunk_size=2048,
 ):
@@ -562,19 +445,18 @@ def infer_batch_process_stream(
         audio = resampler(audio)
     audio = audio.to(device)
 
-    if len(ref_text[-1].encode("utf-8")) == 1:
-        ref_text = ref_text + " "
-
     generated_waves = []
     spectrograms = []
 
+    if len(ref_text[-1].encode("utf-8")) == 1:
+        ref_text = ref_text + " "
 
+    def process_batch(gen_text):
         local_speed = speed
-        if len(gen_text) < 10:
+        if len(gen_text.encode("utf-8")) < 10:
             local_speed = 0.3
 
+        # Prepare the text
         text_list = [ref_text + gen_text]
         final_text_list = convert_char_to_pinyin(text_list)
 
@@ -582,10 +464,12 @@
         if fix_duration is not None:
             duration = int(fix_duration * target_sample_rate / hop_length)
         else:
+            # Calculate duration
             ref_text_len = len(ref_text.encode("utf-8"))
             gen_text_len = len(gen_text.encode("utf-8"))
             duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
 
+        # inference
         with torch.inference_mode():
             generated, _ = model_obj.sample(
                 cond=audio,
@@ -599,76 +483,79 @@
             generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = generated.permute(0, 2, 1)
-
-            print(f"Generated mel spectrogram shape: {generated_mel_spec.shape}")
-
             if mel_spec_type == "vocos":
                 generated_wave = vocoder.decode(generated_mel_spec)
             elif mel_spec_type == "bigvgan":
                 generated_wave = vocoder(generated_mel_spec)
-
-            print(f"Generated wave shape before RMS adjustment: {generated_wave.shape}")
-
             if rms < target_rms:
                 generated_wave = generated_wave * rms / target_rms
 
-
-
+            # wav -> numpy
             generated_wave = generated_wave.squeeze().cpu().numpy()
 
             if streaming:
                 for j in range(0, len(generated_wave), chunk_size):
                     yield generated_wave[j : j + chunk_size], target_sample_rate
-
+            else:
+                yield generated_wave, generated_mel_spec[0].cpu().numpy()
 
     if streaming:
+        for gen_text in progress.tqdm(gen_text_batches) if progress else gen_text_batches:
+            for chunk in process_batch(gen_text):
                 yield chunk
     else:
         with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
             for future in progress.tqdm(futures) if progress else futures:
                 result = future.result()
                 if result:
-                    generated_wave, generated_mel_spec = result
+                    generated_wave, generated_mel_spec = next(result)
                    generated_waves.append(generated_wave)
                    spectrograms.append(generated_mel_spec)
 
         if generated_waves:
             if cross_fade_duration <= 0:
+                # Simply concatenate
                 final_wave = np.concatenate(generated_waves)
             else:
+                # Combine all generated waves with cross-fading
                 final_wave = generated_waves[0]
                 for i in range(1, len(generated_waves)):
                     prev_wave = final_wave
                     next_wave = generated_waves[i]
 
+                    # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                     cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                     cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
 
                     if cross_fade_samples <= 0:
+                        # No overlap possible, concatenate
                         final_wave = np.concatenate([prev_wave, next_wave])
                         continue
 
+                    # Overlapping parts
                     prev_overlap = prev_wave[-cross_fade_samples:]
                     next_overlap = next_wave[:cross_fade_samples]
 
+                    # Fade out and fade in
                     fade_out = np.linspace(1, 0, cross_fade_samples)
                     fade_in = np.linspace(0, 1, cross_fade_samples)
 
+                    # Cross-faded overlap
                     cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
 
+                    # Combine
                     new_wave = np.concatenate(
                         [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                     )
 
                     final_wave = new_wave
 
+            # Create a combined spectrogram
             combined_spectrogram = np.concatenate(spectrograms, axis=1)
 
             yield final_wave, target_sample_rate, combined_spectrogram
+
         else:
             yield None, target_sample_rate, None
 
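With the two code paths merged, `infer_batch_process` is now a generator: with `streaming=True` it yields `(chunk, sample_rate)` pairs as audio becomes available, and with `streaming=False` it yields a single `(final_wave, sample_rate, combined_spectrogram)` tuple, which is why `infer_process` wraps the call in `next(...)`. A minimal usage sketch, assuming the reference audio `(audio, sr)`, `ref_text`, `gen_text_batches`, `model`, and `vocoder` have already been prepared elsewhere (e.g. via `preprocess_ref_audio_text`, `load_model`, and `load_vocoder`):

```python
from f5_tts.infer.utils_infer import infer_batch_process

# Non-streaming: the generator yields exactly one result tuple.
wave, sample_rate, spectrogram = next(
    infer_batch_process((audio, sr), ref_text, gen_text_batches, model, vocoder, streaming=False)
)

# Streaming: each yielded item is a (chunk, sample_rate) pair of float32 samples.
for chunk, sr_out in infer_batch_process(
    (audio, sr), ref_text, gen_text_batches, model, vocoder, streaming=True, chunk_size=2048
):
    ...  # play the chunk or forward it over a socket as it arrives
```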
src/f5_tts/socket_server.py
CHANGED
@@ -1,20 +1,31 @@
+import argparse
+import gc
+import logging
+import numpy as np
+import queue
 import socket
 import struct
+import threading
+import traceback
+import wave
+from importlib.resources import files
+
 import torch
 import torchaudio
-import logging
-import wave
-import numpy as np
-import argparse
-import traceback
-import gc
-import threading
-import queue
-from nltk.tokenize import sent_tokenize
-from infer.utils_infer import preprocess_ref_audio_text, load_vocoder, load_model, infer_batch_process_stream
-from model.backbones.dit import DiT
 from huggingface_hub import hf_hub_download
-
+
+import nltk
+from nltk.tokenize import sent_tokenize
+
+from f5_tts.model.backbones.dit import DiT
+from f5_tts.infer.utils_infer import (
+    preprocess_ref_audio_text,
+    load_vocoder,
+    load_model,
+    infer_batch_process,
+)
+
+nltk.download("punkt_tab")
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -103,12 +114,13 @@ class TTSStreamingProcessor:
     def _warm_up(self):
         logger.info("Warming up the model...")
         gen_text = "Warm-up text for the model."
-        for _ in infer_batch_process_stream(
+        for _ in infer_batch_process(
             (self.audio, self.sr),
             self.ref_text,
             [gen_text],
             self.model,
             self.vocoder,
+            progress=None,
             device=self.device,
             streaming=True,
         ):
@@ -118,12 +130,13 @@ class TTSStreamingProcessor:
     def generate_stream(self, text, conn):
         text_batches = sent_tokenize(text)
 
-        audio_stream = infer_batch_process_stream(
+        audio_stream = infer_batch_process(
             (self.audio, self.sr),
             self.ref_text,
             text_batches,
             self.model,
             self.vocoder,
+            progress=None,
             device=self.device,
             streaming=True,
             chunk_size=2048,
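Since `generate_stream` now consumes the same `(chunk, sample_rate)` pairs from `infer_batch_process(..., streaming=True)`, forwarding audio to the connected client reduces to serializing each chunk as it arrives. The sketch below is illustrative only, not the repository's exact send loop; it simply writes raw float32 PCM, which matches the paFloat32 / 24 kHz stream the README client opens for playback.

```python
import numpy as np


def send_audio_stream(conn, audio_stream):
    """Push streamed (chunk, sample_rate) pairs to a client socket as raw float32 PCM (sketch)."""
    for chunk, sample_rate in audio_stream:
        if chunk is None or len(chunk) == 0:
            continue
        # float32 samples map directly onto the client's pyaudio.paFloat32 input format
        conn.sendall(np.asarray(chunk, dtype=np.float32).tobytes())
```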