SWivid committed on
Commit
dfe1d95
·
1 Parent(s): ca08681

0.6.2: support socket_server.py with general text chunking

Browse files
Files changed (2) hide show
  1. pyproject.toml +1 -1
  2. src/f5_tts/socket_server.py +14 -6
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "f5-tts"
7
- version = "0.6.1"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
 
4
 
5
  [project]
6
  name = "f5-tts"
7
+ version = "0.6.2"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
src/f5_tts/socket_server.py CHANGED
@@ -14,19 +14,15 @@ import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
16
 
17
- import nltk
18
- from nltk.tokenize import sent_tokenize
19
-
20
  from f5_tts.model.backbones.dit import DiT
21
  from f5_tts.infer.utils_infer import (
 
22
  preprocess_ref_audio_text,
23
  load_vocoder,
24
  load_model,
25
  infer_batch_process,
26
  )
27
 
28
- nltk.download("punkt_tab")
29
-
30
  logging.basicConfig(level=logging.INFO)
31
  logger = logging.getLogger(__name__)
32
 
@@ -89,6 +85,7 @@ class TTSStreamingProcessor:
89
  self.update_reference(ref_audio, ref_text)
90
  self._warm_up()
91
  self.file_writer_thread = None
 
92
 
93
  def load_ema_model(self, ckpt_file, vocab_file, dtype):
94
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
@@ -111,6 +108,12 @@ class TTSStreamingProcessor:
111
  self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
112
  self.audio, self.sr = torchaudio.load(self.ref_audio)
113
 
 
 
 
 
 
 
114
  def _warm_up(self):
115
  logger.info("Warming up the model...")
116
  gen_text = "Warm-up text for the model."
@@ -128,7 +131,11 @@ class TTSStreamingProcessor:
128
  logger.info("Warm-up completed.")
129
 
130
  def generate_stream(self, text, conn):
131
- text_batches = sent_tokenize(text)
 
 
 
 
132
 
133
  audio_stream = infer_batch_process(
134
  (self.audio, self.sr),
@@ -172,6 +179,7 @@ def handle_client(conn, processor):
172
  while True:
173
  data = conn.recv(1024)
174
  if not data:
 
175
  break
176
  data_str = data.decode("utf-8").strip()
177
  logger.info(f"Received text: {data_str}")
 
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
16
 
 
 
 
17
  from f5_tts.model.backbones.dit import DiT
18
  from f5_tts.infer.utils_infer import (
19
+ chunk_text,
20
  preprocess_ref_audio_text,
21
  load_vocoder,
22
  load_model,
23
  infer_batch_process,
24
  )
25
 
 
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
 
85
  self.update_reference(ref_audio, ref_text)
86
  self._warm_up()
87
  self.file_writer_thread = None
88
+ self.first_package = True
89
 
90
  def load_ema_model(self, ckpt_file, vocab_file, dtype):
91
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
108
  self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
109
  self.audio, self.sr = torchaudio.load(self.ref_audio)
110
 
111
+ ref_audio_duration = self.audio.shape[-1] / self.sr
112
+ ref_text_byte_len = len(self.ref_text.encode("utf-8"))
113
+ self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
114
+ self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
115
+ self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
116
+
117
  def _warm_up(self):
118
  logger.info("Warming up the model...")
119
  gen_text = "Warm-up text for the model."
 
131
  logger.info("Warm-up completed.")
132
 
133
  def generate_stream(self, text, conn):
134
+ text_batches = chunk_text(text, max_chars=self.max_chars)
135
+ if self.first_package:
136
+ text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
137
+ text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
138
+ self.first_package = False
139
 
140
  audio_stream = infer_batch_process(
141
  (self.audio, self.sr),
 
179
  while True:
180
  data = conn.recv(1024)
181
  if not data:
182
+ processor.first_package = True
183
  break
184
  data_str = data.decode("utf-8").strip()
185
  logger.info(f"Received text: {data_str}")