Tomtom84 committed on
Commit 8ff13a2 · verified · 1 Parent(s): 4708247

Create orpheus_engine.py

Files changed (1)
  1. engines/orpheus_engine.py +295 -0
engines/orpheus_engine.py ADDED
@@ -0,0 +1,295 @@
+ # -*- coding: utf-8 -*-
+ """OrpheusEngine
+ ~~~~~~~~~~~~~~~~
+ A drop‑in replacement for the original ``orpheus_engine.py`` that fixes
+ all outstanding token‑streaming issues and eliminates audible clicks by
+
+ * streaming **token‑IDs** instead of partial text
+ * dynamically sending a *tiny* first audio chunk (3×7 codes) followed by
+   steady blocks (30×7)
+ * mapping vLLM/OpenAI token‑IDs → SNAC codes without fragile
+   ``"<custom_token_"`` string parsing
+ * adding an optional fade‑in / fade‑out per chunk
+ * emitting a proper WAV header as the first element in the queue so that
+   browsers / HTML5 `<audio>` tags start playback immediately.
+
+ The API (``get_voices()``, ``set_voice()``, …) is unchanged, so you can
+ keep using it from RealtimeTTS.
+ """
+ from __future__ import annotations
+
+ ###############################################################################
+ # Standard library & 3rd‑party imports                                        #
+ ###############################################################################
+ import json
+ import logging
+ import struct
+ import time
+ from queue import Queue
+ from typing import Generator, Iterable, List, Optional
+
+ import numpy as np
+ import pyaudio  # provided by RealTimeTTS[system]
+ import requests
+ from RealtimeTTS.engines import BaseEngine
+
+ ###############################################################################
+ # Constants                                                                   #
+ ###############################################################################
+ DEFAULT_API_URL = "http://127.0.0.1:1234"
+ DEFAULT_MODEL = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
+ DEFAULT_HEADERS = {"Content-Type": "application/json"}
+ DEFAULT_VOICE = "Martin"
+
+ # Audio
+ SAMPLE_RATE = 24_000
+ BITS_PER_SAMPLE = 16
+ AUDIO_CHANNELS = 1
+
+ # Token‑ID magic numbers (defined in the model card)
+ CODE_START_TOKEN_ID = 128257   # <|audio|>
+ CODE_REMOVE_TOKEN_ID = 128258
+ CODE_TOKEN_OFFSET = 128266     # <custom_token_?> – first usable code id
+
+ # Chunking strategy
+ _INITIAL_GROUPS = 3    # 3×7 = 21 codes ≈ 90 ms @24 kHz
+ _STEADY_GROUPS = 30    # 30×7 = 210 codes ≈ 900 ms
+
+ ###############################################################################
+ # Helper functions                                                            #
+ ###############################################################################
+
+ def _create_wav_header(sample_rate: int, bits_per_sample: int, channels: int) -> bytes:
+     """Return a 44‑byte WAV/PCM header with unknown data size (0xFFFFFFFF)."""
+     riff_size = 0xFFFFFFFF
+     header = b"RIFF" + struct.pack("<I", riff_size) + b"WAVEfmt "
+     header += struct.pack("<IHHIIHH", 16, 1, channels, sample_rate,
+                           sample_rate * channels * bits_per_sample // 8,
+                           channels * bits_per_sample // 8, bits_per_sample)
+     header += b"data" + struct.pack("<I", 0xFFFFFFFF)
+     return header
+
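+ # Sanity note (illustrative): the header above is the standard 44‑byte PCM
+ # layout – "RIFF" + size + "WAVE" + "fmt " chunk (4‑byte size field plus
+ # 16 bytes of fields) + "data" + placeholder size, e.g.
+ #
+ #     assert len(_create_wav_header(24_000, 16, 1)) == 44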
+
+ def _fade_in_out(audio: np.ndarray, fade_ms: int = 50) -> np.ndarray:
+     """Apply linear fade‑in/out to avoid clicks."""
+     if fade_ms <= 0:
+         return audio
+     fade_samples = int(SAMPLE_RATE * fade_ms / 1000)
+     fade_samples -= fade_samples % 2  # keep it even
+     if fade_samples == 0 or audio.size < 2 * fade_samples:
+         return audio
+     ramp = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
+     audio[:fade_samples] *= ramp
+     audio[-fade_samples:] *= ramp[::-1]
+     return audio
+
+ ###############################################################################
+ # SNAC – lightweight wrapper                                                  #
+ ###############################################################################
+ try:
+     import torch
+     from snac import SNAC
+
+     _snac_model: Optional[SNAC] = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
+     _snac_model = _snac_model.to("cuda" if torch.cuda.is_available() else "cpu")
+ except Exception as exc:  # pragma: no cover
+     logging.warning("SNAC model could not be loaded – %s", exc)
+     _snac_model = None
+
+
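+ # How the flat Orpheus code stream maps onto SNAC's three codebook layers
+ # (this mirrors the redistribution performed in ``_codes_to_audio`` below;
+ # the 4096 multiples are the per‑position offsets baked into the token ids):
+ #
+ #     position in 7‑code frame:  0      1      2      3      4      5      6
+ #     SNAC layer:                L1     L2     L3     L3     L2     L3     L3
+ #     offset subtracted:         0    1·4096 2·4096 3·4096 4·4096 5·4096 6·4096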
+
+ def _codes_to_audio(codes: List[int]) -> bytes:
+     """Convert a *flat* list of SNAC codes to 16‑bit PCM bytes."""
+     if not _snac_model or not codes:
+         return b""
+
+     # --- redistribute into 3 snac layers (see original paper) --------------
+     groups = len(codes) // 7
+     codes = codes[: groups * 7]  # trim incomplete tail
+     if groups == 0:
+         return b""
+
+     l1, l2, l3 = [], [], []
+     for g in range(groups):
+         base = g * 7
+         l1.append(codes[base])
+         l2.append(codes[base + 1] - 4096)
+         l3.extend([
+             codes[base + 2] - 2 * 4096,
+             codes[base + 3] - 3 * 4096,
+             codes[base + 5] - 5 * 4096,
+             codes[base + 6] - 6 * 4096,
+         ])
+         l2.append(codes[base + 4] - 4 * 4096)
+
+     import torch
+
+     device = next(_snac_model.parameters()).device
+     with torch.no_grad():
+         layers = [
+             torch.tensor(l1, device=device).unsqueeze(0),
+             torch.tensor(l2, device=device).unsqueeze(0),
+             torch.tensor(l3, device=device).unsqueeze(0),
+         ]
+         wav = _snac_model.decode(layers).cpu().numpy().squeeze()
+
+     wav = _fade_in_out(wav)
+     pcm = np.clip(wav * 32767, -32768, 32767).astype(np.int16).tobytes()
+     return pcm
+
+ ###############################################################################
+ # Main class                                                                  #
+ ###############################################################################
+ class OrpheusVoice:
+     def __init__(self, name: str, gender: str | None = None):
+         self.name = name
+         self.gender = gender
+
+
+ class OrpheusEngine(BaseEngine):
+     """Realtime TTS engine using the Orpheus SNAC model via vLLM."""
+
+     _SPEAKERS = [
+         OrpheusVoice("Martin", "m"), OrpheusVoice("Emma", "f"),
+         OrpheusVoice("Luca", "m"), OrpheusVoice("Anna", "f"),
+         OrpheusVoice("Jakob", "m"), OrpheusVoice("Anton", "m"),
+         OrpheusVoice("Julian", "m"), OrpheusVoice("Jan", "m"),
+         OrpheusVoice("Alexander", "m"), OrpheusVoice("Emil", "m"),
+         OrpheusVoice("Ben", "m"), OrpheusVoice("Elias", "m"),
+         OrpheusVoice("Felix", "m"), OrpheusVoice("Jonas", "m"),
+         OrpheusVoice("Noah", "m"), OrpheusVoice("Maximilian", "m"),
+         OrpheusVoice("Sophie", "f"), OrpheusVoice("Marie", "f"),
+         OrpheusVoice("Mia", "f"), OrpheusVoice("Maria", "f"),
+         OrpheusVoice("Sophia", "f"), OrpheusVoice("Lina", "f"),
+         OrpheusVoice("Lea", "f"),
+     ]
+
+     # ---------------------------------------------------------------------
+     def __init__(
+         self,
+         api_url: str = DEFAULT_API_URL,
+         model: str = DEFAULT_MODEL,
+         headers: dict = DEFAULT_HEADERS,
+         voice: Optional[OrpheusVoice] = None,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         max_tokens: int = 1200,
+         repetition_penalty: float = 1.1,
+         debug: bool = False,
+     ) -> None:
+         super().__init__()
+         self.api_url = api_url.rstrip("/")
+         self.model = model
+         self.headers = headers
+         self.voice = voice or OrpheusVoice(DEFAULT_VOICE)
+         self.temperature = temperature
+         self.top_p = top_p
+         self.max_tokens = max_tokens
+         self.repetition_penalty = repetition_penalty
+         self.debug = debug
+         self.queue: "Queue[bytes | None]" = Queue()
+         self.engine_name = "orpheus"
+
+     # ------------------------------------------------------------------ API
+     def get_stream_info(self):
+         return pyaudio.paInt16, AUDIO_CHANNELS, SAMPLE_RATE
+
+     def get_voices(self):
+         return self._SPEAKERS
+
+     def set_voice(self, voice_name: str):
+         if voice_name not in {v.name for v in self._SPEAKERS}:
+             raise ValueError(f"Unknown Orpheus speaker '{voice_name}'")
+         self.voice = OrpheusVoice(voice_name)
+
+     # --------------------------------------------------------------- public
+     def synthesize(self, text: str) -> bool:  # noqa: C901 (long)
+         """Start streaming TTS for **text** – blocks until finished."""
+         super().synthesize(text)
+         self.queue.put(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS))
+
+         try:
+             code_stream = self._stream_snac_codes(text)
+             first_chunk = True
+             buffer: List[int] = []
+             sent = 0
+             groups_needed = _INITIAL_GROUPS
+
+             for code_id in code_stream:
+                 buffer.append(code_id)
+                 available = len(buffer) - sent
+                 if available >= groups_needed * 7:
+                     chunk_codes = buffer[sent : sent + groups_needed * 7]
+                     sent += groups_needed * 7
+                     pcm = _codes_to_audio(chunk_codes)
+                     if pcm:
+                         self.queue.put(pcm)
+                     first_chunk = False
+                     groups_needed = _STEADY_GROUPS
+
+             # flush remaining full groups
+             remaining = len(buffer) - sent
+             final_groups = remaining // 7
+             if final_groups:
+                 pcm = _codes_to_audio(buffer[sent : sent + final_groups * 7])
+                 if pcm:
+                     self.queue.put(pcm)
+
+             return True
+         except Exception as exc:  # pragma: no cover
+             logging.exception("OrpheusEngine: synthesis failed – %s", exc)
+             return False
+         finally:
+             self.queue.put(None)  # close stream
+
+     # ------------------------------------------------------------ internals
+     def _format_prompt(self, prompt: str) -> str:
+         return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"
+
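+     # For illustration, with the default voice the prompt sent to the server
+     # looks like:
+     #
+     #     <|audio|>Martin: Hallo, wie geht es dir?<|eot_id|>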
+     def _stream_snac_codes(self, prompt: str) -> Generator[int, None, None]:
+         """Yield SNAC code‑IDs as they arrive from the model."""
+         payload = {
+             "model": self.model,
+             "prompt": self._format_prompt(prompt),
+             "max_tokens": self.max_tokens,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "stream": True,
+             "skip_special_tokens": False,
+             "frequency_penalty": self.repetition_penalty,
+         }
+         url = f"{self.api_url}/v1/completions"  # plain completion endpoint
+         with requests.post(url, headers=self.headers, json=payload, stream=True, timeout=600) as r:
+             r.raise_for_status()
+             started = False
+             for line in r.iter_lines():
+                 if not line:
+                     continue
+                 if line.startswith(b"data: "):
+                     data = line[6:].decode()
+                     if data.strip() == "[DONE]":
+                         break
+                     try:
+                         obj = json.loads(data)
+                         delta = obj["choices"][0]
+                         tid = delta.get("token_id")  # vLLM ≥0.9 provides this
+                         if tid is None:
+                             # Without a token id the text fragment cannot be
+                             # mapped back to a SNAC code reliably – skip it.
+                             continue
+                     except Exception:
+                         continue
+
+                     if not started:
+                         if tid == CODE_START_TOKEN_ID:
+                             started = True
+                         continue
+                     if tid == CODE_REMOVE_TOKEN_ID or tid < CODE_TOKEN_OFFSET:
+                         continue
+                     yield tid - CODE_TOKEN_OFFSET
+
+     # ------------------------------------------------------------------ misc
+     def __del__(self):
+         try:
+             self.queue.put(None)
+         except Exception:
+             pass
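+
+ # Illustrative consumer sketch (not part of the engine – RealtimeTTS normally
+ # drains ``engine.queue`` itself). It assumes a PyAudio output device and
+ # shows how ``get_stream_info()`` and the queue fit together:
+ #
+ #     import pyaudio, threading
+ #     engine = OrpheusEngine()
+ #     fmt, channels, rate = engine.get_stream_info()
+ #     out = pyaudio.PyAudio().open(format=fmt, channels=channels, rate=rate, output=True)
+ #     threading.Thread(target=engine.synthesize, args=("Hallo Welt!",), daemon=True).start()
+ #     engine.queue.get()  # discard the WAV header for raw PCM playback
+ #     while (chunk := engine.queue.get()) is not None:
+ #         out.write(chunk)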