Tomtom84 committed
Commit 5ef7cfa · verified · 1 Parent(s): 4715aa2

Create orpheus_engine.py

Files changed (1): orpheus_engine.py (+373, -0)
orpheus_engine.py ADDED
@@ -0,0 +1,373 @@
+ import json
+ import time
+ import logging
+ import pyaudio
+ import requests
+ import traceback
+ import numpy as np
+ from queue import Queue
+ from typing import Optional, Union
+ from .base_engine import BaseEngine
+
+ # Default configuration values
+ DEFAULT_API_URL = "http://127.0.0.1:1234/v1/completions"
+ DEFAULT_HEADERS = {"Content-Type": "application/json"}
+ DEFAULT_MODEL = "orpheus-3b-0.1-ft"
+ DEFAULT_VOICE = "tara"
+ AVAILABLE_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
+ SAMPLE_RATE = 24000  # Specific sample rate for Orpheus
+
+ # Special token definitions for prompt formatting and token decoding
+ START_TOKEN_ID = 128259
+ END_TOKEN_IDS = [128009, 128260, 128261, 128257]
+ CUSTOM_TOKEN_PREFIX = "<custom_token_"
+
+
+ class OrpheusVoice:
+     """
+     Represents the configuration for an Orpheus voice.
+
+     Attributes:
+         name (str): The name of the voice. Expected to be one of AVAILABLE_VOICES.
+
+     Note:
+         Name validation is currently disabled, so custom voice names are accepted.
+     """
+     def __init__(self, name: str):
+         # if name not in AVAILABLE_VOICES:
+         #     raise ValueError(f"Invalid voice '{name}'. Available voices: {AVAILABLE_VOICES}")
+         self.name = name
+
+     def __repr__(self):
+         return f"OrpheusVoice(name='{self.name}')"
+
+
+ class OrpheusEngine(BaseEngine):
+     """
+     Real-time Text-to-Speech (TTS) engine for the Orpheus model via LM Studio API.
+
+     This engine supports real-time token generation, audio synthesis, and voice configuration.
+     """
+
+     def __init__(
+         self,
+         api_url: str = DEFAULT_API_URL,
+         model: str = DEFAULT_MODEL,
+         headers: dict = DEFAULT_HEADERS,
+         voice: Optional[OrpheusVoice] = None,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         max_tokens: int = 1200,
+         repetition_penalty: float = 1.1,
+         debug: bool = False
+     ):
+         """
+         Initialize the Orpheus TTS engine with the given parameters.
+
+         Args:
+             api_url (str): Endpoint URL for the LM Studio API.
+             model (str): Model name to use for synthesis.
+             headers (dict): HTTP headers for API requests.
+             voice (Optional[OrpheusVoice]): OrpheusVoice configuration. Defaults to DEFAULT_VOICE.
+             temperature (float): Sampling temperature (0-1) for text generation.
+             top_p (float): Top-p sampling parameter for controlling diversity.
+             max_tokens (int): Maximum tokens to generate per API request.
+             repetition_penalty (float): Penalty factor for repeated phrases.
+             debug (bool): Flag to enable debug output.
+         """
+         super().__init__()
+         self.api_url = api_url
+         self.model = model
+         self.headers = headers
+         self.voice = voice or OrpheusVoice(DEFAULT_VOICE)
+         self.temperature = temperature
+         self.top_p = top_p
+         self.max_tokens = max_tokens
+         self.repetition_penalty = repetition_penalty
+         self.debug = debug
+         self.queue = Queue()
+         self.post_init()
+
+     def post_init(self):
+         """Set up additional engine attributes."""
+         self.engine_name = "orpheus"
+
+     def get_stream_info(self):
+         """
+         Retrieve PyAudio stream configuration.
+
+         Returns:
+             tuple: Format, channel count, and sample rate for PyAudio.
+         """
+         return pyaudio.paInt16, 1, SAMPLE_RATE
+
+     def synthesize(self, text: str) -> bool:
+         """
+         Convert text to speech via Orpheus and stream audio data.
+
+         Args:
+             text (str): The input text to be synthesized.
+
+         Returns:
+             bool: True if synthesis was successful, False otherwise.
+         """
+         super().synthesize(text)
+
+         try:
+             # Process tokens and put generated audio chunks into the queue
+             for audio_chunk in self._token_decoder(self._generate_tokens(text)):
+                 # bail out immediately if someone called .stop()
+                 if self.stop_synthesis_event.is_set():
+                     logging.info("OrpheusEngine: synthesis stopped by user")
+                     return False
+                 # forward this chunk to the audio queue
+                 self.queue.put(audio_chunk)
+             return True
+         except Exception as e:
+             traceback.print_exc()
+             logging.error(f"Synthesis error: {e}")
+             return False
+
+     def _generate_tokens(self, prompt: str):
+         """
+         Generate a token stream using the LM Studio API.
+
+         Args:
+             prompt (str): The input text prompt.
+
+         Yields:
+             str: Each token's text as it is received from the API.
+         """
+         logging.debug(f"Generating tokens for prompt: {prompt}")
+         formatted_prompt = self._format_prompt(prompt)
+
+         payload = {
+             "model": self.model,
+             "prompt": formatted_prompt,
+             "max_tokens": self.max_tokens,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "repeat_penalty": self.repetition_penalty,
+             "stream": True
+         }
+
+         try:
+             logging.debug(f"Requesting API URL: {self.api_url} with payload: {payload} and headers: {self.headers}")
+             response = requests.post(
+                 self.api_url,
+                 headers=self.headers,
+                 json=payload,
+                 stream=True
+             )
+             response.raise_for_status()
+
+             token_counter = 0
+             start_time = time.time()  # Start timing token generation
+             for line in response.iter_lines():
+                 # stop on demand
+                 if self.stop_synthesis_event.is_set():
+                     logging.debug("OrpheusEngine: token generation aborted")
+                     break
+                 if line:
+                     line = line.decode('utf-8')
+                     if line.startswith('data: '):
+                         data_str = line[6:]
+                         if data_str.strip() == '[DONE]':
+                             break
+
+                         try:
+                             data = json.loads(data_str)
+                             if 'choices' in data and data['choices']:
+                                 token_text = data['choices'][0].get('text', '')
+                                 if token_text:
+                                     token_counter += 1
+                                     # Log the time it took to get the first token
+                                     if token_counter == 1:
+                                         elapsed = time.time() - start_time
+                                         logging.info(f"Time to first token: {elapsed:.2f} seconds")
+                                     yield token_text
+                         except json.JSONDecodeError as e:
+                             logging.error(f"Error decoding JSON: {e}")
+                             continue
+
+         except requests.RequestException as e:
+             logging.error(f"API request failed: {e}")
+
+     def _format_prompt(self, prompt: str) -> str:
+         """
+         Format the text prompt with special tokens required by Orpheus.
+
+         Args:
+             prompt (str): The raw text prompt.
+
+         Returns:
+             str: The formatted prompt including voice and termination token.
+         """
+         return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"
+
+     def _token_decoder(self, token_gen):
+         """
+         Decode tokens from the generator and convert them into audio samples.
+
+         This method aggregates tokens in a buffer and converts them into audio chunks
+         once enough tokens have been collected.
+
+         Args:
+             token_gen: Generator yielding token strings.
+
+         Yields:
+             Audio samples ready to be streamed.
+         """
+         buffer = []
+         count = 0
+
+         logging.debug("Starting token decoding from token generator.")
+         for token_text in token_gen:
+             # bail out if stop was requested
+             if self.stop_synthesis_event.is_set():
+                 logging.debug("OrpheusEngine: token decoding aborted")
+                 break
+             token = self.turn_token_into_id(token_text, count)
+             if token is not None and token > 0:
+                 buffer.append(token)
+                 count += 1
+
+                 # Process every 7 tokens after an initial threshold
+                 if count % 7 == 0 and count > 27:
+                     buffer_to_proc = buffer[-28:]
+                     audio_samples = self._convert_buffer(buffer_to_proc, count)
+                     if audio_samples is not None:
+                         yield audio_samples
+
+     def turn_token_into_id(self, token_string: str, index: int) -> Optional[int]:
+         """
+         Convert a token string to a numeric ID for audio processing.
+
+         The conversion takes into account the custom token prefix and an index-based offset.
+
+         Args:
+             token_string (str): The token text.
+             index (int): The current token index.
+
+         Returns:
+             Optional[int]: The numeric token ID or None if conversion fails.
+         """
+         token_string = token_string.strip()
+         last_token_start = token_string.rfind(CUSTOM_TOKEN_PREFIX)
+
+         if last_token_start == -1:
+             return None
+
+         last_token = token_string[last_token_start:]
+
+         if last_token.startswith(CUSTOM_TOKEN_PREFIX) and last_token.endswith(">"):
+             try:
+                 number_str = last_token[len(CUSTOM_TOKEN_PREFIX):-1]
+                 # Each of the 7 interleaved slots appears to span 4096 ids, with a base offset of 10
+                 token_id = int(number_str) - 10 - ((index % 7) * 4096)
+                 return token_id
+             except ValueError:
+                 return None
+         else:
+             return None
+
+     def _convert_buffer(self, multiframe, count: int):
+         """
+         Convert a buffer of token frames into audio samples.
+
+         This method uses an external decoder to convert the collected token frames.
+
+         Args:
+             multiframe: List of token IDs to be converted.
+             count (int): The current token count (used for conversion logic).
+
+         Returns:
+             Converted audio samples if successful; otherwise, None.
+         """
+         try:
+             from .orpheus_decoder import convert_to_audio as orpheus_convert_to_audio
+             converted = orpheus_convert_to_audio(multiframe, count)
+             if converted is None:
+                 logging.warning("Conversion returned None.")
+             return converted
+         except Exception as e:
+             logging.error(f"Failed to convert buffer to audio: {e}")
+             logging.info("Returning None after failed conversion.")
+             return None
+
+     def get_voices(self):
+         """
+         Retrieve the list of available voices.
+
+         Returns:
+             list: A list of OrpheusVoice instances for each available voice.
+         """
+         return [OrpheusVoice(name) for name in AVAILABLE_VOICES]
+
+     def set_voice(self, voice: Union[str, OrpheusVoice]):
+         """
+         Set the current voice for synthesis.
+
+         Args:
+             voice (Union[str, OrpheusVoice]): The voice name or an OrpheusVoice instance.
+                 Name validation is currently disabled, so custom voice names are accepted.
+
+         Raises:
+             TypeError: If the voice argument is neither a string nor an OrpheusVoice instance.
+         """
+         if isinstance(voice, str):
+             # if voice not in AVAILABLE_VOICES:
+             #     raise ValueError(f"Invalid voice '{voice}'")
+             self.voice = OrpheusVoice(voice)
+         elif isinstance(voice, OrpheusVoice):
+             self.voice = voice
+         else:
+             raise TypeError("Voice must be a string or an OrpheusVoice instance.")
+
+     def set_voice_parameters(self, **kwargs):
+         """
+         Update voice generation parameters.
+
+         Valid parameters include 'temperature', 'top_p', 'max_tokens', and 'repetition_penalty'.
+
+         Args:
+             **kwargs: Arbitrary keyword arguments for valid voice parameters.
+         """
+         valid_params = ['temperature', 'top_p', 'max_tokens', 'repetition_penalty']
+         for param, value in kwargs.items():
+             if param in valid_params:
+                 setattr(self, param, value)
+             elif self.debug:
+                 logging.warning(f"Ignoring invalid parameter: {param}")
+
+     def __del__(self):
+         """
+         Destructor to clean up resources.
+
+         Puts a None into the queue to signal termination of audio processing.
+         """
+         self.queue.put(None)
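
For reference, below is a minimal, hypothetical sketch of how this engine could be driven end to end. It is not part of the commit. It assumes the module is importable from its package (it relies on a relative import of BaseEngine and on stop_synthesis_event being provided there), that an LM Studio server is serving the Orpheus model at DEFAULT_API_URL, and that the decoder yields raw 16-bit PCM chunks; the import path and the playback loop are illustrative only.

import threading
from queue import Empty

import pyaudio

# Hypothetical import path -- adjust to wherever this engine module lives in your package.
from orpheus_engine import OrpheusEngine, OrpheusVoice

engine = OrpheusEngine(voice=OrpheusVoice("tara"), debug=True)

# Worked example of the id arithmetic in turn_token_into_id():
# "<custom_token_12345>" at stream index 3 -> 12345 - 10 - (3 % 7) * 4096 = 47
assert engine.turn_token_into_id("<custom_token_12345>", 3) == 47

# Open a PyAudio output stream matching the engine's reported format (16-bit mono, 24 kHz).
fmt, channels, rate = engine.get_stream_info()
pa = pyaudio.PyAudio()
stream = pa.open(format=fmt, channels=channels, rate=rate, output=True)

# synthesize() blocks while it pushes audio chunks onto engine.queue,
# so run it in a worker thread and drain the queue here.
worker = threading.Thread(target=engine.synthesize, args=("Hello from Orpheus!",))
worker.start()

while worker.is_alive() or not engine.queue.empty():
    try:
        chunk = engine.queue.get(timeout=0.1)
    except Empty:
        continue
    if chunk is None:  # __del__ uses None as a termination sentinel
        break
    stream.write(bytes(chunk))  # assumes chunks are raw 16-bit PCM

worker.join()
stream.stop_stream()
stream.close()
pa.terminate()

The exact chunk type and the BaseEngine contract are not visible in this diff, so treat the playback loop as a starting point rather than the engine's official usage.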