0.6.2 support socket_server.py with general text chunk
Files changed:
- pyproject.toml +1 -1
- src/f5_tts/socket_server.py +14 -6
pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.6.
+version = "0.6.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
src/f5_tts/socket_server.py
CHANGED
@@ -14,19 +14,15 @@ import torch
 import torchaudio
 from huggingface_hub import hf_hub_download
 
-import nltk
-from nltk.tokenize import sent_tokenize
-
 from f5_tts.model.backbones.dit import DiT
 from f5_tts.infer.utils_infer import (
+    chunk_text,
     preprocess_ref_audio_text,
     load_vocoder,
     load_model,
     infer_batch_process,
 )
 
-nltk.download("punkt_tab")
-
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
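Note: this drops the NLTK dependency (and the punkt_tab download that ran at import time) in favor of the package's own chunk_text helper. A minimal usage sketch, assuming only the signature visible in this diff, chunk_text(text, max_chars=...):

    from f5_tts.infer.utils_infer import chunk_text

    # Split text into batches of at most roughly max_chars bytes each.
    batches = chunk_text("First sentence. Second, somewhat longer sentence.", max_chars=20)
    print(batches)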
@@ -89,6 +85,7 @@ class TTSStreamingProcessor:
         self.update_reference(ref_audio, ref_text)
         self._warm_up()
         self.file_writer_thread = None
+        self.first_package = True
 
     def load_ema_model(self, ckpt_file, vocab_file, dtype):
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
@@ -111,6 +108,12 @@ class TTSStreamingProcessor:
         self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
         self.audio, self.sr = torchaudio.load(self.ref_audio)
 
+        ref_audio_duration = self.audio.shape[-1] / self.sr
+        ref_text_byte_len = len(self.ref_text.encode("utf-8"))
+        self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
+        self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
+        self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
+
     def _warm_up(self):
         logger.info("Warming up the model...")
         gen_text = "Warm-up text for the model."
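The new sizing fields budget each generated batch against an effective window of roughly 25 seconds of audio including the reference clip: bytes-per-second of speech is estimated from the reference text/audio pair, then scaled by the seconds left after the reference. A worked sketch of the arithmetic, with a hypothetical 6-second reference transcribed by 90 bytes of text:

    # Hypothetical reference clip: 6 s of audio, 90 bytes of transcript.
    ref_audio_duration = 6.0
    ref_text_byte_len = 90

    # ~15 bytes/s of speech, and 25 - 6 = 19 s of budget left to generate.
    max_chars = int(ref_text_byte_len / ref_audio_duration * (25 - ref_audio_duration))
    few_chars = int(ref_text_byte_len / ref_audio_duration * (25 - ref_audio_duration) / 2)
    min_chars = int(ref_text_byte_len / ref_audio_duration * (25 - ref_audio_duration) / 4)
    print(max_chars, few_chars, min_chars)  # 285 142 71

Despite the *_chars spelling, these are byte counts, since the reference text is measured with len(self.ref_text.encode("utf-8")).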
@@ -128,7 +131,11 @@ class TTSStreamingProcessor:
         logger.info("Warm-up completed.")
 
     def generate_stream(self, text, conn):
-        text_batches = sent_tokenize(text)
+        text_batches = chunk_text(text, max_chars=self.max_chars)
+        if self.first_package:
+            text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
+            text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
+            self.first_package = False
 
         audio_stream = infer_batch_process(
             (self.audio, self.sr),
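For the first package of a connection, the head chunk is re-split twice (half-size, then quarter-size), so the very first batch is small and the client hears audio sooner; later packages use full-size chunks. A standalone sketch of that re-splitting, using a simplified greedy whitespace chunker in place of chunk_text:

    def simple_chunk(text, max_chars):
        # Greedy word-level splitter standing in for f5_tts's chunk_text.
        batches, cur = [], ""
        for w in text.split():
            if cur and len(cur) + 1 + len(w) > max_chars:
                batches.append(cur)
                cur = w
            else:
                cur = f"{cur} {w}".strip()
        if cur:
            batches.append(cur)
        return batches

    text = "one two three four five six seven eight nine ten " * 5
    batches = simple_chunk(text, max_chars=100)
    # First package only: shrink the head chunk twice for lower latency.
    batches = simple_chunk(batches[0], max_chars=50) + batches[1:]
    batches = simple_chunk(batches[0], max_chars=25) + batches[1:]
    print([len(b) for b in batches])  # small chunks first, full-size after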
@@ -172,6 +179,7 @@ def handle_client(conn, processor):
         while True:
             data = conn.recv(1024)
             if not data:
+                processor.first_package = True
                 break
             data_str = data.decode("utf-8").strip()
             logger.info(f"Received text: {data_str}")
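Since an empty recv() marks the end of a client session, first_package is re-armed there, and the next connection again starts with the small low-latency chunks. A minimal client sketch; the host/port and the idea of half-closing the socket to signal end-of-input are illustrative assumptions, not part of this diff:

    import socket

    HOST, PORT = "localhost", 9998  # hypothetical server address

    with socket.create_connection((HOST, PORT)) as sock:
        sock.sendall("Hello from the streaming client.".encode("utf-8"))
        sock.shutdown(socket.SHUT_WR)  # empty recv() on the server resets first_package
        while True:
            chunk = sock.recv(4096)  # raw audio bytes streamed back
            if not chunk:
                break
            # feed chunk to an audio player or file writer here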