Yushen CHEN commited on
Commit
1085b73
·
unverified ·
2 Parent(s): ea90244 6e24f1e

Merge pull request #354 from kunci115/main

Browse files
Files changed (2) hide show
  1. src/f5_tts/infer/README.md +71 -1
  2. src/f5_tts/socket.py +154 -0
src/f5_tts/infer/README.md CHANGED
@@ -113,4 +113,74 @@ To test speech editing capabilities, use the following command:
113
 
114
  ```bash
115
  python src/f5_tts/infer/speech_edit.py
116
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  ```bash
115
  python src/f5_tts/infer/speech_edit.py
116
+ ```
117
+
118
+ ## Socket Realtime Client
119
+
120
+ To communicate with the socket server, you need to run
121
+ ```bash
122
+ python src/f5_tts/socket.py
123
+ ```
124
+
125
+ then create a client to communicate
126
+
127
+ ```python
128
+ import socket
129
+ import numpy as np
130
+ import asyncio
131
+ import pyaudio
132
+
133
+ async def listen_to_voice(text, server_ip='localhost', server_port=9999):
134
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
135
+ client_socket.connect((server_ip, server_port))
136
+
137
+ async def play_audio_stream():
138
+ buffer = b''
139
+ p = pyaudio.PyAudio()
140
+ stream = p.open(format=pyaudio.paFloat32,
141
+ channels=1,
142
+ rate=24000, # Ensure this matches the server's sampling rate
143
+ output=True,
144
+ frames_per_buffer=2048)
145
+
146
+ try:
147
+ while True:
148
+ chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
149
+ if not chunk: # End of stream
150
+ break
151
+ if b"END_OF_AUDIO" in chunk:
152
+ buffer += chunk.replace(b"END_OF_AUDIO", b"")
153
+ if buffer:
154
+ audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy
155
+ stream.write(audio_array.tobytes())
156
+ break
157
+ buffer += chunk
158
+ if len(buffer) >= 4096:
159
+ audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy
160
+ stream.write(audio_array.tobytes())
161
+ buffer = buffer[4096:]
162
+ finally:
163
+ stream.stop_stream()
164
+ stream.close()
165
+ p.terminate()
166
+
167
+ try:
168
+ # Send only the text to the server
169
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
170
+ await play_audio_stream()
171
+ print("Audio playback finished.")
172
+
173
+ except Exception as e:
174
+ print(f"Error in listen_to_voice: {e}")
175
+
176
+ finally:
177
+ client_socket.close()
178
+
179
+ # Example usage: Replace this with your actual server IP and port
180
+ async def main():
181
+ await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)
182
+
183
+ # Run the main async function
184
+ asyncio.run(main())
185
+ ```
186
+
src/f5_tts/socket.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socket
2
+ import struct
3
+ import torch
4
+ import torchaudio
5
+ from threading import Thread
6
+
7
+
8
+ import gc
9
+ import traceback
10
+
11
+
12
+ from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
13
+ from model.backbones.dit import DiT
14
+
15
+
16
class TTSStreamingProcessor:
    """Wraps an F5-TTS model and vocoder and streams synthesized audio as
    packed float32 chunks suitable for sending over a socket."""

    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        """Load the model and vocoder, store the reference clip, and warm up.

        Args:
            ckpt_file: path to the model checkpoint.
            vocab_file: path to the vocab file.
            ref_audio: path to the reference audio clip used for conditioning.
            ref_text: transcript of the reference audio.
            device: torch device string; defaults to CUDA when available.
            dtype: model dtype (default ``torch.float32``).
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model using the provided checkpoint and vocab files
        self.model = load_model(
            DiT,
            dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
            ckpt_file,
            vocab_file,
        ).to(self.device, dtype=dtype)

        # Load the vocoder
        self.vocoder = load_vocoder(is_local=False)

        # Streaming sample rate -- must match the client's playback rate (24 kHz)
        self.sampling_rate = 24000

        # Reference audio and text used to condition every generation
        self.ref_audio = ref_audio
        self.ref_text = ref_text

        # Run one dummy inference so the first real request has no cold-start latency
        self._warm_up()

    def _warm_up(self):
        """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
        print("Warming up the model...")
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
        audio, sr = torchaudio.load(ref_audio)
        gen_text = "Warm-up text for the model."

        # Pass the vocoder as an argument here
        infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
        print("Warm-up completed.")

    def generate_stream(self, text, play_steps_in_s=0.5):
        """Generate audio for *text* and yield it in real-time chunks.

        Args:
            text: the text to synthesize.
            play_steps_in_s: duration of each streamed chunk, in seconds.

        Yields:
            bytes: native-order float32 PCM samples, at most
            ``int(sample_rate * play_steps_in_s)`` samples per chunk.
        """
        # Preprocess the reference audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)

        # Load reference audio
        audio, sr = torchaudio.load(ref_audio)

        # Run inference for the input text
        audio_chunk, final_sample_rate, _ = infer_batch_process(
            (audio, sr), ref_text, [text], self.model, self.vocoder, device=self.device
        )

        # Break the generated audio into fixed-duration chunks and send them.
        chunk_size = int(final_sample_rate * play_steps_in_s)

        # BUG FIX: the original version re-yielded the final partial chunk a
        # second time after this loop (its trailing `len(audio_chunk) %
        # chunk_size` branch), making the client replay the last fraction of
        # a second of audio. Slicing already clamps at the sequence end, so
        # the loop alone emits every sample exactly once.
        for start in range(0, len(audio_chunk), chunk_size):
            chunk = audio_chunk[start:start + chunk_size]

            # Avoid sending empty chunks
            if len(chunk) == 0:
                break

            # Pack the samples as raw float32 for the socket client
            yield struct.pack(f"{len(chunk)}f", *chunk)
88
+
89
+
90
def handle_client(client_socket, processor):
    """Serve one client connection: read text requests and stream back audio.

    Each received payload is treated as one text request; the synthesized
    audio is streamed chunk by chunk, terminated by an ``END_OF_AUDIO``
    marker. The socket is always closed on exit.
    """
    try:
        # One iteration per text request; an empty read means the peer hung up.
        while payload := client_socket.recv(1024):
            data = payload.decode("utf-8")
            try:
                # Stream every generated chunk, then signal completion.
                for packed in processor.generate_stream(data.strip()):
                    client_socket.sendall(packed)
                client_socket.sendall(b"END_OF_AUDIO")
            except Exception as inner_e:
                print(f"Error during processing: {inner_e}")
                traceback.print_exc()  # Print the full traceback to diagnose the issue
                break
    except Exception as e:
        print(f"Error handling client: {e}")
        traceback.print_exc()
    finally:
        client_socket.close()
119
+
120
+
121
def start_server(host, port, processor):
    """Listen for TCP connections forever, handing each client to a thread.

    Args:
        host: interface to bind, e.g. ``"0.0.0.0"``.
        port: TCP port to listen on.
        processor: shared TTSStreamingProcessor used by every client thread.
    """
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow quick restarts: without SO_REUSEADDR a restarted server can fail
    # to bind with "Address already in use" while the old socket lingers in
    # TIME_WAIT.
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    try:
        server.bind((host, port))
        server.listen(5)
        print(f"Server listening on {host}:{port}")

        while True:
            client_socket, addr = server.accept()
            print(f"Accepted connection from {addr}")
            client_handler = Thread(target=handle_client, args=(client_socket, processor))
            client_handler.start()
    finally:
        # Release the listening socket on any exit path (e.g. KeyboardInterrupt),
        # which the original code never did.
        server.close()
132
+
133
+
134
if __name__ == "__main__":
    try:
        # Model asset paths -- fill these in before running,
        # e.g. ckpt_file = "ckpts/model/model_1096.pt"
        ckpt_file = ""
        vocab_file = ""  # vocab file path, if needed
        ref_audio = ""  # reference clip, e.g. "./tests/ref_audio/reference.wav"
        ref_text = ""

        # A single shared processor instance serves every client connection.
        processor = TTSStreamingProcessor(
            ckpt_file=ckpt_file,
            vocab_file=vocab_file,
            ref_audio=ref_audio,
            ref_text=ref_text,
            dtype=torch.float32,
        )

        # Serve on all interfaces; the README client connects to port 9998.
        start_server("0.0.0.0", 9998, processor)
    except KeyboardInterrupt:
        gc.collect()