# -*- coding: utf-8 -*-
"""OrpheusEngine
~~~~~~~~~~~~~~~~
A drop‑in replacement for the original ``orpheus_engine.py`` that fixes
all outstanding token‑streaming issues and eliminates audible clicks by
* streaming **token‑IDs** instead of partial text
* dynamically sending a *tiny* first audio chunk (3×7 codes) followed by
steady blocks (30×7)
* mapping vLLM/OpenAI token‑IDs → SNAC codes without fragile
``"<custom_token_"`` string parsing
* adding an optional fade‑in / fade‑out per chunk
* emitting a proper WAV header as the first element in the queue so that
browsers / HTML5 `<audio>` tags start playback immediately.
The API (``get_voices()``, ``set_voice()``, …) is unchanged, so you can
keep using it from RealTimeTTS.
"""
from __future__ import annotations
###############################################################################
# Standard library & 3rd‑party imports #
###############################################################################
import json
import logging
import os
import struct
import time
from queue import Queue
from typing import Generator, Iterable, List, Optional
import numpy as np
import pyaudio  # provided by RealTimeTTS[system]
import requests
import torch
from snac import SNAC, __version__ as snac_version
from RealtimeTTS.engines import BaseEngine
###############################################################################
# Constants #
###############################################################################
DEFAULT_API_URL = "http://127.0.0.1:1234"
DEFAULT_MODEL = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
DEFAULT_HEADERS = {"Content-Type": "application/json"}
DEFAULT_VOICE = "Martin"
# Audio
SAMPLE_RATE = 24_000
BITS_PER_SAMPLE = 16
AUDIO_CHANNELS = 1
# Token‑ID magic numbers (defined in the model card)
CODE_START_TOKEN_ID = 128257 # <|audio|>
CODE_REMOVE_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266 # <custom_token_?> – first usable code id
# Chunking strategy
_INITIAL_GROUPS = 3  # 3×7 = 21 codes – one 7-code SNAC frame ≈ 85 ms @ 24 kHz, so ≈ 0.26 s
_STEADY_GROUPS = 30  # 30×7 = 210 codes ≈ 2.6 s
SNAC_MODEL = os.getenv("SNAC_MODEL", "hubertsiuzdak/snac_24khz")
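# Worked example of the id arithmetic above (illustrative, mirroring the decode
# logic further below): the k-th code inside a 7-code frame is streamed as
#   token_id = CODE_TOKEN_OFFSET + k * 4096 + code
# so a streamed id of 128266 + 2*4096 + 17 corresponds to the flat code
# 2*4096 + 17, and _codes_to_audio() strips the per-position k*4096 offset
# again before handing the codes to SNAC.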
###############################################################################
# Helper functions #
###############################################################################
def _create_wav_header(sample_rate: int, bits_per_sample: int, channels: int) -> bytes:
"""Return a 44‑byte WAV/PCM header with unknown data size (0xFFFFFFFF)."""
riff_size = 0xFFFFFFFF
header = b"RIFF" + struct.pack("<I", riff_size) + b"WAVEfmt "
header += struct.pack("<IHHIIHH", 16, 1, channels, sample_rate,
sample_rate * channels * bits_per_sample // 8,
channels * bits_per_sample // 8, bits_per_sample)
header += b"data" + struct.pack("<I", 0xFFFFFFFF)
return header
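# Illustrative sanity check: the header built above is always exactly 44 bytes:
#   assert len(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS)) == 44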
def _fade_in_out(audio: np.ndarray, fade_ms: int = 50) -> np.ndarray:
"""Apply linear fade‑in/out to avoid clicks."""
if fade_ms <= 0:
return audio
fade_samples = int(SAMPLE_RATE * fade_ms / 1000)
fade_samples -= fade_samples % 2 # keep it even
if fade_samples == 0 or audio.size < 2 * fade_samples:
return audio
ramp = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
audio[:fade_samples] *= ramp
audio[-fade_samples:] *= ramp[::-1]
return audio
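# Example: with the default fade_ms=50 at 24 kHz, the first and last 1200
# samples of each chunk are ramped linearly (1200 = 24_000 * 50 // 1000).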
###############################################################################
# SNAC – lightweight wrapper #
###############################################################################
try:
    _snac_model: Optional[SNAC] = SNAC.from_pretrained(SNAC_MODEL).eval()
    _snac_model = _snac_model.to("cuda" if torch.cuda.is_available() else "cpu")
except Exception as exc:  # pragma: no cover
    logging.warning("SNAC model could not be loaded – %s", exc)
    _snac_model = None
def _codes_to_audio(codes: List[int]) -> bytes:
"""Convert a *flat* list of SNAC codes to 16‑bit PCM bytes."""
if not _snac_model or not codes:
return b""
# --- redistribute into 3 snac layers (see original paper) --------------
groups = len(codes) // 7
codes = codes[: groups * 7] # trim incomplete tail
if groups == 0:
return b""
l1, l2, l3 = [], [], []
for g in range(groups):
base = g * 7
l1.append(codes[base])
l2.append(codes[base + 1] - 4096)
l3.extend([
codes[base + 2] - 2 * 4096,
codes[base + 3] - 3 * 4096,
codes[base + 5] - 5 * 4096,
codes[base + 6] - 6 * 4096,
])
l2.append(codes[base + 4] - 4 * 4096)
    device = next(_snac_model.parameters()).device
    with torch.no_grad():
        layers = [
            torch.tensor(l1, dtype=torch.long, device=device).unsqueeze(0),
            torch.tensor(l2, dtype=torch.long, device=device).unsqueeze(0),
            torch.tensor(l3, dtype=torch.long, device=device).unsqueeze(0),
        ]
        wav = _snac_model.decode(layers).cpu().numpy().squeeze()
wav = _fade_in_out(wav)
pcm = np.clip(wav * 32767, -32768, 32767).astype(np.int16).tobytes()
return pcm
###############################################################################
# Main class #
###############################################################################
class OrpheusVoice:
def __init__(self, name: str, gender: str | None = None):
self.name = name
self.gender = gender
class OrpheusEngine(BaseEngine):
"""Realtime TTS engine using the Orpheus SNAC model via vLLM."""
_SPEAKERS = [
OrpheusVoice("Martin", "m"), OrpheusVoice("Emma", "f"),
OrpheusVoice("Luca", "m"), OrpheusVoice("Anna", "f"),
OrpheusVoice("Jakob", "m"), OrpheusVoice("Anton", "m"),
OrpheusVoice("Julian", "m"), OrpheusVoice("Jan", "m"),
OrpheusVoice("Alexander", "m"), OrpheusVoice("Emil", "m"),
OrpheusVoice("Ben", "m"), OrpheusVoice("Elias", "m"),
OrpheusVoice("Felix", "m"), OrpheusVoice("Jonas", "m"),
OrpheusVoice("Noah", "m"), OrpheusVoice("Maximilian", "m"),
OrpheusVoice("Sophie", "f"), OrpheusVoice("Marie", "f"),
OrpheusVoice("Mia", "f"), OrpheusVoice("Maria", "f"),
OrpheusVoice("Sophia", "f"), OrpheusVoice("Lina", "f"),
OrpheusVoice("Lea", "f"),
]
def _load_snac(self, model_name: str = SNAC_MODEL):
"""
Lädt den SNAC-Decoder auf CPU/GPU.
Fällt bei jedem Fehler sauber auf CPU zurück.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
snac = SNAC.from_pretrained(model_name).to(device)
if device == "cuda": # half() nur auf GPU – ältere SNAC-Versionen haben keine .half()
snac = snac.half()
snac.eval()
logging.info(f"SNAC {snac_version} loaded on {device}")
return snac
except Exception as e:
logging.exception("SNAC load failed – running with silent fallback")
return None
# ---------------------------------------------------------------------
def __init__(
self,
api_url: str = DEFAULT_API_URL,
model: str = DEFAULT_MODEL,
headers: dict = DEFAULT_HEADERS,
voice: Optional[OrpheusVoice] = None,
temperature: float = 0.6,
top_p: float = 0.9,
max_tokens: int = 1200,
repetition_penalty: float = 1.1,
debug: bool = False,
) -> None:
super().__init__()
self.api_url = api_url.rstrip("/")
self.model = model
self.headers = headers
self.voice = voice or OrpheusVoice(DEFAULT_VOICE)
self.temperature = temperature
self.top_p = top_p
self.max_tokens = max_tokens
self.repetition_penalty = repetition_penalty
self.debug = debug
self.queue: "Queue[bytes | None]" = Queue()
        global _snac_model  # _codes_to_audio() decodes with this module-level model
        self.snac = self._load_snac()  # load the SNAC decoder
        if self.snac is not None:
            _snac_model = self.snac  # keep the module-level decoder in sync
        elif _snac_model is None:  # no decoder at all → silent fallback
            logging.warning("⚠️ No SNAC – audio generation disabled.")
self.engine_name = "orpheus"
# ------------------------------------------------------------------ API
def get_stream_info(self):
return pyaudio.paInt16, AUDIO_CHANNELS, SAMPLE_RATE
def get_voices(self):
return self._SPEAKERS
    def set_voice(self, voice_name: str):
        match = next((v for v in self._SPEAKERS if v.name == voice_name), None)
        if match is None:
            raise ValueError(f"Unknown Orpheus speaker '{voice_name}'")
        self.voice = match  # reuse the registry entry so gender metadata is kept
# --------------------------------------------------------------- public
def synthesize(self, text: str) -> bool: # noqa: C901 (long)
"""Start streaming TTS for **text** – blocks until finished."""
super().synthesize(text)
self.queue.put(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS))
try:
code_stream = self._stream_snac_codes(text)
first_chunk = True
buffer: List[int] = []
sent = 0
groups_needed = _INITIAL_GROUPS
for code_id in code_stream:
buffer.append(code_id)
available = len(buffer) - sent
if available >= groups_needed * 7:
chunk_codes = buffer[sent : sent + groups_needed * 7]
sent += groups_needed * 7
pcm = _codes_to_audio(chunk_codes)
if pcm:
self.queue.put(pcm)
first_chunk = False
groups_needed = _STEADY_GROUPS
# flush remaining full groups
remaining = len(buffer) - sent
final_groups = remaining // 7
if final_groups:
pcm = _codes_to_audio(buffer[sent : sent + final_groups * 7])
if pcm:
self.queue.put(pcm)
return True
except Exception as exc: # pragma: no cover
logging.exception("OrpheusEngine: synthesis failed – %s", exc)
return False
finally:
self.queue.put(None) # close stream
# ------------------------------------------------------------ internals
def _format_prompt(self, prompt: str) -> str:
return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"
def _stream_snac_codes(self, prompt: str) -> Generator[int, None, None]:
"""Yield SNAC code‑IDs as they arrive from the model."""
payload = {
"model": self.model,
"prompt": self._format_prompt(prompt),
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"stream": True,
"skip_special_tokens": False,
"frequency_penalty": self.repetition_penalty,
}
url = f"{self.api_url}/v1/completions" # plain completion endpoint
with requests.post(url, headers=self.headers, json=payload, stream=True, timeout=600) as r:
r.raise_for_status()
started = False
for line in r.iter_lines():
if not line:
continue
if line.startswith(b"data: "):
data = line[6:].decode()
if data.strip() == "[DONE]":
break
                    try:
                        obj = json.loads(data)
                        delta = obj["choices"][0]
                        tid: Optional[int] = delta.get("token_id")  # vLLM ≥ 0.9 streams token ids
                    except Exception:
                        continue
                    if tid is None:
                        # text-only delta without a token id – nothing to map to a SNAC code,
                        # and parsing "<custom_token_…>" strings is exactly what we avoid here
                        continue
if not started:
if tid == CODE_START_TOKEN_ID:
started = True
continue
if tid == CODE_REMOVE_TOKEN_ID or tid < CODE_TOKEN_OFFSET:
continue
yield tid - CODE_TOKEN_OFFSET
# ------------------------------------------------------------------ misc
def __del__(self):
try:
self.queue.put(None)
except Exception:
pass
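###############################################################################
# Usage sketch                                                                #
###############################################################################
# Minimal, illustrative driver for manual testing – not part of the RealtimeTTS
# integration. It assumes an OpenAI-compatible server is already running at
# DEFAULT_API_URL and that the SNAC decoder could be loaded; the output file
# name "orpheus_demo.wav" is purely an example.
if __name__ == "__main__":
    engine = OrpheusEngine(debug=True)
    engine.set_voice("Martin")
    if engine.synthesize("Hallo, das ist ein kurzer Streaming-Test."):
        with open("orpheus_demo.wav", "wb") as f:
            while True:
                chunk = engine.queue.get()
                if chunk is None:  # end-of-stream sentinel pushed by synthesize()
                    break
                f.write(chunk)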