Tomtom84 committed on
Commit ba50109 · verified · 1 Parent(s): c6d4d5b

Create base_engines.py

Files changed (1)
  1. engines/base_engines.py +299 -0
engines/base_engines.py ADDED
@@ -0,0 +1,299 @@
+ """
+ This module defines a base framework for speech synthesis engines. It includes:
+ - A TimingInfo class to capture timing details (start, end, and word) of audio segments.
+ - A BaseEngine abstract class (using a custom metaclass) that sets up default properties and common audio processing methods (such as applying fade-ins/outs and trimming silence) along with abstract methods for voice management and synthesis.
+ """
+ 
+ import torch.multiprocessing as mp
+ from abc import ABCMeta, ABC
+ from typing import Union
+ import numpy as np
+ import shutil
+ import queue
+ 
+ class TimingInfo:
+     def __init__(self, start_time, end_time, word):
+         self.start_time = start_time
+         self.end_time = end_time
+         self.word = word
+ 
+     def __str__(self):
+         return f"Word: {self.word}, Start Time: {self.start_time}, End Time: {self.end_time}"
+ 
+ # Define a metaclass that will automatically call BaseEngine's __init__ method
+ # and also the post_init method if it exists.
+ class BaseInitMeta(ABCMeta):
+     def __call__(cls, *args, **kwargs):
+         # Create an instance of the class that this metaclass is used on.
+         instance = super().__call__(*args, **kwargs)
+ 
+         # Call the __init__ method of BaseEngine to set default properties.
+         BaseEngine.__init__(instance)
+ 
+         # If the instance has a post_init method, call it.
+         # This allows subclasses to define additional initialization steps.
+         if hasattr(instance, "post_init"):
+             instance.post_init()
+ 
+         return instance
+ 
+ 
+ # Define a base class for engines with the custom metaclass.
+ class BaseEngine(ABC, metaclass=BaseInitMeta):
+     def __init__(self):
+         self.engine_name = "unknown"
+ 
+         # Indicates whether the engine can handle generators.
+         self.can_consume_generators = False
+ 
+         # Queue to manage audio chunks for the engine.
+         self.queue = queue.Queue()
+ 
+         # Queue to manage word-level timings for the engine.
+         self.timings = queue.Queue()
+ 
+         # Callback to be called when an audio chunk is available.
+         self.on_audio_chunk = None
+ 
+         # Callback to be called when the engine starts synthesizing audio.
+         self.on_playback_start = None
+ 
+         self.stop_synthesis_event = mp.Event()
+ 
+         self.reset_audio_duration()
+ 
+     def reset_audio_duration(self):
+         """
+         Resets the audio duration to 0.
+         """
+         self.audio_duration = 0
+ 
+     def apply_fade_in(self, audio: np.ndarray, sample_rate: int = -1, fade_duration_ms: int = 15) -> np.ndarray:
+         """
+         Applies a linear fade-in over fade_duration_ms at the start of the audio.
+         """
+         sample_rate = self.verify_sample_rate(sample_rate)
+         audio = audio.copy()
+ 
+         fade_samples = int(sample_rate * fade_duration_ms / 1000)
+         if fade_samples == 0 or len(audio) < fade_samples:
+             fade_samples = len(audio)
+         fade_in = np.linspace(0.0, 1.0, fade_samples)
+         audio[:fade_samples] *= fade_in
+         return audio
+ 
+     def apply_fade_out(self, audio: np.ndarray, sample_rate: int = -1, fade_duration_ms: int = 15) -> np.ndarray:
+         """
+         Applies a linear fade-out over fade_duration_ms at the end of the audio.
+         """
+         sample_rate = self.verify_sample_rate(sample_rate)
+         audio = audio.copy()
+ 
+         fade_samples = int(sample_rate * fade_duration_ms / 1000)
+         if fade_samples == 0 or len(audio) < fade_samples:
+             fade_samples = len(audio)
+         fade_out = np.linspace(1.0, 0.0, fade_samples)
+         audio[-fade_samples:] *= fade_out
+         return audio
+ 
+     def trim_silence_start(
+         self,
+         audio_data: np.ndarray,
+         sample_rate: int = 24000,
+         silence_threshold: float = 0.01,
+         extra_ms: int = 25,
+         fade_in_ms: int = 15
+     ) -> np.ndarray:
+         """
+         Removes leading silence from audio_data, trims extra_ms further, and applies a fade-in if any trimming occurred.
+ 
+         Args:
+             audio_data (np.ndarray): The audio data to process.
+             sample_rate (int): The sample rate of the audio data.
+             silence_threshold (float): The threshold for silence detection.
+             extra_ms (int): Additional milliseconds to trim from the start.
+             fade_in_ms (int): Milliseconds for the fade-in effect.
+         """
+         sample_rate = self.verify_sample_rate(sample_rate)
+         trimmed = False
+         audio_data = audio_data.copy()
+         non_silent = np.where(np.abs(audio_data) > silence_threshold)[0]
+         if len(non_silent) > 0:
+             start_index = non_silent[0]
+             if start_index > 0:
+                 trimmed = True
+                 audio_data = audio_data[start_index:]
+ 
+         extra_samples = int(extra_ms * sample_rate / 1000)
+         if extra_samples > 0 and len(audio_data) > extra_samples:
+             audio_data = audio_data[extra_samples:]
+             trimmed = True
+ 
+         if trimmed:
+             audio_data = self.apply_fade_in(audio_data, sample_rate, fade_in_ms)
+         return audio_data
+ 
+     def trim_silence_end(
+         self,
+         audio_data: np.ndarray,
+         sample_rate: int = -1,
+         silence_threshold: float = 0.01,
+         extra_ms: int = 50,
+         fade_out_ms: int = 15
+     ) -> np.ndarray:
+         """
+         Removes trailing silence from audio_data, trims extra_ms further, and applies a fade-out if any trimming occurred.
+ 
+         Args:
+             audio_data (np.ndarray): The audio data to be trimmed.
+             sample_rate (int): The sample rate of the audio data. Default is -1.
+             silence_threshold (float): The threshold below which audio is considered silent. Default is 0.01.
+             extra_ms (int): Extra milliseconds to trim from the end of the audio. Default is 50.
+             fade_out_ms (int): Milliseconds for the fade-out effect at the end of the audio. Default is 15.
+         """
+         sample_rate = self.verify_sample_rate(sample_rate)
+         trimmed = False
+         audio_data = audio_data.copy()
+         non_silent = np.where(np.abs(audio_data) > silence_threshold)[0]
+         if len(non_silent) > 0:
+             end_index = non_silent[-1] + 1
+             if end_index < len(audio_data):
+                 trimmed = True
+                 audio_data = audio_data[:end_index]
+ 
+         extra_samples = int(extra_ms * sample_rate / 1000)
+         if extra_samples > 0 and len(audio_data) > extra_samples:
+             audio_data = audio_data[:-extra_samples]
+             trimmed = True
+ 
+         if trimmed:
+             audio_data = self.apply_fade_out(audio_data, sample_rate, fade_out_ms)
+         return audio_data
+ 
+     def verify_sample_rate(self, sample_rate: int) -> int:
+         """
+         Verifies and returns the sample rate.
+         If the sample rate is -1, it will be obtained from the engine's configuration.
+         """
+         if sample_rate == -1:
+             _, _, sample_rate = self.get_stream_info()
+         if sample_rate == -1:
+             raise ValueError("Sample rate must be provided or obtained from get_stream_info.")
+         return sample_rate
+ 
+     def _trim_silence(
+         self,
+         audio_data: np.ndarray,
+         sample_rate: int = -1,
+         silence_threshold: float = 0.005,
+         extra_start_ms: int = 15,
+         extra_end_ms: int = 15,
+         fade_in_ms: int = 10,
+         fade_out_ms: int = 10
+     ) -> np.ndarray:
+         """
+         Removes silence from both the start and end of audio_data.
+         If trimming occurs on either end, the corresponding fade is applied.
+         """
+         sample_rate = self.verify_sample_rate(sample_rate)
+ 
+         audio_data = self.trim_silence_start(
+             audio_data, sample_rate, silence_threshold, extra_start_ms, fade_in_ms
+         )
+         audio_data = self.trim_silence_end(
+             audio_data, sample_rate, silence_threshold, extra_end_ms, fade_out_ms
+         )
+         return audio_data
+ 
+     def get_stream_info(self):
+         """
+         Returns the audio stream configuration information suitable for PyAudio.
+ 
+         Returns:
+             tuple: A tuple containing the audio format, number of channels, and the sample rate.
+                 - Format (int): The format of the audio stream. pyaudio.paInt16 represents 16-bit integers.
+                 - Channels (int): The number of audio channels. 1 represents mono audio.
+                 - Sample Rate (int): The sample rate of the audio in Hz. 16000 represents a 16 kHz sample rate.
+         """
+         raise NotImplementedError(
+             "The get_stream_info method must be implemented by the derived class."
+         )
+ 
+     def synthesize(self, text: str) -> bool:
+         """
+         Synthesizes text to an audio stream.
+ 
+         The base implementation only clears the stop event; derived classes are
+         expected to perform the actual synthesis and return True on success.
+ 
+         Args:
+             text (str): Text to synthesize.
+         """
+         self.stop_synthesis_event.clear()
+ 
+     def get_voices(self):
+         """
+         Retrieves the voices available from the specific voice source.
+ 
+         This method should be overridden by the derived class to fetch the list of available voices.
+ 
+         Returns:
+             list: A list containing voice objects representing each available voice.
+         """
+         raise NotImplementedError(
+             "The get_voices method must be implemented by the derived class."
+         )
+ 
+     def set_voice(self, voice: Union[str, object]):
+         """
+         Sets the voice to be used for speech synthesis.
+ 
+         Args:
+             voice (Union[str, object]): The voice to be used for speech synthesis.
+ 
+         This method should be overridden by the derived class to set the desired voice.
+         """
+         raise NotImplementedError(
+             "The set_voice method must be implemented by the derived class."
+         )
+ 
+     def set_voice_parameters(self, **voice_parameters):
+         """
+         Sets the voice parameters to be used for speech synthesis.
+ 
+         Args:
+             **voice_parameters: The voice parameters to be used for speech synthesis.
+ 
+         This method should be overridden by the derived class to set the desired voice parameters.
+         """
+         raise NotImplementedError(
+             "The set_voice_parameters method must be implemented by the derived class."
+         )
+ 
+     def shutdown(self):
+         """
+         Shuts down the engine.
+         """
+         pass
+ 
+     def is_installed(self, lib_name: str) -> bool:
+         """
+         Check if the given library or software is installed and accessible.
+ 
+         This method uses shutil.which to determine if the given library or software is
+         installed and available in the system's PATH.
+ 
+         Args:
+             lib_name (str): Name of the library or software to check.
+ 
+         Returns:
+             bool: True if the library is installed, otherwise False.
+         """
+         return shutil.which(lib_name) is not None
+ 
+     def stop(self):
+         """
+         Stops the engine.
+         """
+         self.stop_synthesis_event.set()
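
As a usage illustration (not part of this commit), here is a minimal sketch of a concrete engine built on this base class. The SineEngine name, the engines.base_engines import path, and the 440 Hz tone are assumptions invented for the example; only BaseEngine and its hooks (post_init, get_stream_info, synthesize, on_audio_chunk, queue) come from the file above.

# Hypothetical example -- not part of the commit.
import numpy as np

from engines.base_engines import BaseEngine


class SineEngine(BaseEngine):
    """Toy engine that 'synthesizes' every request as a short 440 Hz tone."""

    def post_init(self):
        # Runs automatically after BaseEngine.__init__ thanks to BaseInitMeta,
        # so no super().__init__() bookkeeping is needed here.
        self.engine_name = "sine"

    def get_stream_info(self):
        # (format, channels, sample_rate); 8 is the value of pyaudio.paInt16.
        return 8, 1, 24000

    def synthesize(self, text: str) -> bool:
        super().synthesize(text)  # clears stop_synthesis_event
        _, _, sample_rate = self.get_stream_info()
        # Generate half a second of a 440 Hz sine as the fake audio output.
        t = np.linspace(0.0, 0.5, int(sample_rate * 0.5), endpoint=False)
        audio = (0.3 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
        # Clean up the edges with the inherited helpers.
        audio = self._trim_silence(audio, sample_rate)
        if self.stop_synthesis_event.is_set():
            return False
        if self.on_audio_chunk:
            self.on_audio_chunk(audio.tobytes())
        self.queue.put(audio.tobytes())
        return True


engine = SineEngine()           # BaseEngine.__init__ and post_init both run here
print(engine.engine_name)       # "sine"
print(engine.synthesize("hi"))  # True; one chunk is now waiting in engine.queue

Because BaseInitMeta re-runs BaseEngine.__init__ on every new instance and only then calls post_init, subclasses should put their own setup in post_init rather than in __init__, where attributes that collide with the base defaults would be overwritten.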