kunci115 committed
Commit 1e3fac8 · unverified · Parent: dee0420

[add] socket.py

to play the generated audio stream over a TCP socket

Files changed (1)
  1. src/f5_tts/socket.py +155 -0
src/f5_tts/socket.py ADDED
import socket
import struct
import torch
import torchaudio
from threading import Thread

import gc
import traceback

from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
from model.backbones.dit import DiT


class TTSStreamingProcessor:
    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model using the provided checkpoint and vocab files
        self.model = load_model(
            DiT,
            dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
            ckpt_file,
            vocab_file,
        ).to(self.device, dtype=dtype)

        # Load the vocoder
        self.vocoder = load_vocoder(is_local=False)

        # Sampling rate for streaming; must match the client
        self.sampling_rate = 24000

        # Reference audio and text used for voice cloning
        self.ref_audio = ref_audio
        self.ref_text = ref_text

        # Warm up the model
        self._warm_up()

    def _warm_up(self):
        """Warm up the model with a dummy input so it is ready for real-time processing."""
        print("Warming up the model...")
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
        audio, sr = torchaudio.load(ref_audio)
        gen_text = "Warm-up text for the model."

        # Pass the vocoder as an argument here
        infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
        print("Warm-up completed.")

    def generate_stream(self, text, play_steps_in_s=0.5):
        """Generate audio for `text` and yield it in chunks of `play_steps_in_s` seconds."""
        # Preprocess the reference audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)

        # Load reference audio
        audio, sr = torchaudio.load(ref_audio)

        # Run inference for the input text
        audio_chunk, final_sample_rate, _ = infer_batch_process(
            (audio, sr), ref_text, [text], self.model, self.vocoder, device=self.device
        )

        # Slice the generated waveform into fixed-size chunks; the slice
        # clamps at the end of the array, so the final partial chunk is
        # yielded exactly once and no audio is repeated
        chunk_size = int(final_sample_rate * play_steps_in_s)
        for i in range(0, len(audio_chunk), chunk_size):
            chunk = audio_chunk[i : i + chunk_size]
            if len(chunk) == 0:
                break

            # Pack the float32 samples (native byte order) and hand them to the caller
            packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
            yield packed_audio


def handle_client(client_socket, processor):
    try:
        while True:
            # Receive text from the client
            data = client_socket.recv(1024).decode("utf-8")
            if not data:
                break

            try:
                text = data.strip()

                # Generate and stream audio chunks
                for audio_chunk in processor.generate_stream(text):
                    client_socket.sendall(audio_chunk)

                # Send end-of-audio signal
                client_socket.sendall(b"END_OF_AUDIO")

            except Exception as inner_e:
                print(f"Error during processing: {inner_e}")
                traceback.print_exc()  # Full traceback to diagnose the issue
                break

    except Exception as e:
        print(f"Error handling client: {e}")
        traceback.print_exc()
    finally:
        client_socket.close()


def start_server(host, port, processor):
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind((host, port))
    server.listen(5)
    print(f"Server listening on {host}:{port}")

    # Handle each client on its own thread
    while True:
        client_socket, addr = server.accept()
        print(f"Accepted connection from {addr}")
        client_handler = Thread(target=handle_client, args=(client_socket, processor))
        client_handler.start()


if __name__ == "__main__":
    try:
        ckpt_file = "ckpts/model/model_1096.pt"  # point this at your checkpoint
        vocab_file = ""  # add a vocab file path if needed
        ref_audio = "./tests/ref_audio/reference.wav"
        ref_text = ""

        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
            ckpt_file=ckpt_file,
            vocab_file=vocab_file,
            ref_audio=ref_audio,
            ref_text=ref_text,
            dtype=torch.float32,
        )

        # Start the server
        start_server("0.0.0.0", 9998, processor)
    except KeyboardInterrupt:
        gc.collect()
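
For completeness, a minimal playback client might look like the sketch below. It is not part of this commit; it only assumes the wire format the server above defines: the client sends one UTF-8 text message, then reads raw float32 samples (native byte order, as produced by struct.pack, mono at 24 kHz) until the literal END_OF_AUDIO sentinel arrives. PyAudio is one possible playback backend; the host, port, and sample text are placeholders matching the defaults in __main__.

import socket

import pyaudio  # third-party playback backend; any float32-capable audio sink works

MARKER = b"END_OF_AUDIO"  # sentinel the server sends after the last chunk


def play_stream(text, host="127.0.0.1", port=9998, sample_rate=24000):
    """Send text to the streaming server and play float32 chunks as they arrive."""
    pa = pyaudio.PyAudio()
    out = pa.open(format=pyaudio.paFloat32, channels=1, rate=sample_rate, output=True)
    try:
        with socket.create_connection((host, port)) as sock:
            sock.sendall(text.encode("utf-8"))
            buffer = b""
            while True:
                data = sock.recv(4096)
                if not data:  # server closed the connection
                    break
                buffer += data
                idx = buffer.find(MARKER)
                if idx != -1:
                    out.write(buffer[:idx])  # play whatever precedes the sentinel
                    break
                # Hold back len(MARKER) bytes so a sentinel split across recv()
                # calls is still detected, and round down to whole float32
                # samples (4 bytes each) so writes stay frame-aligned
                playable = (len(buffer) - len(MARKER)) // 4 * 4
                if playable > 0:
                    out.write(buffer[:playable])
                    buffer = buffer[playable:]
    finally:
        out.stop_stream()
        out.close()
        pa.terminate()


if __name__ == "__main__":
    play_stream("Hello from the streaming client.")

Two caveats worth noting: struct.pack("...f", ...) uses native byte order, so client and server are assumed to share endianness; and because the file is named socket.py, running it as a script puts its own directory on sys.path and can shadow the standard-library socket module for every import in the process, so renaming the module (for example to socket_server.py) sidesteps that.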