|
|
|
"""OrpheusEngine |
|
~~~~~~~~~~~~~~~~ |
|
A drop‑in replacement for the original ``orpheus_engine.py`` that fixes |
|
all outstanding token‑streaming issues and eliminates audible clicks by |
|
|
|
* streaming **token‑IDs** instead of partial text |
|
* dynamically sending a *tiny* first audio chunk (3×7 codes) followed by |
|
steady blocks (30×7) |
|
* mapping vLLM/OpenAI token‑IDs → SNAC codes without fragile |
|
``"<custom_token_"`` string parsing |
|
* adding an optional fade‑in / fade‑out per chunk |
|
* emitting a proper WAV header as the first element in the queue so that |
|
browsers / HTML5 `<audio>` tags start playback immediately. |
|
|
|
The API (``get_voices()``, ``set_voice()``, …) is unchanged, so you can
keep using it from RealtimeTTS.
|
""" |
|
|
|
from __future__ import annotations |
|
from snac import SNAC, __version__ as snac_version |
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
import struct |
|
import time |
|
import os |
|
import torch |
|
from queue import Queue |
|
from typing import Generator, Iterable, List, Optional |
|
|
|
import numpy as np |
|
import pyaudio |
|
import requests |
|
from RealtimeTTS.engines import BaseEngine |
|
|
|
|
|
|
|
|
|
|
|
# -- Server / model defaults ---------------------------------------------------
DEFAULT_API_URL = "http://127.0.0.1:1234"
DEFAULT_MODEL = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
DEFAULT_HEADERS = {"Content-Type": "application/json"}
DEFAULT_VOICE = "Martin"

# -- Output PCM format ---------------------------------------------------------
SAMPLE_RATE = 24000      # Hz
BITS_PER_SAMPLE = 16
AUDIO_CHANNELS = 1       # mono

# -- Special token IDs in the model's output stream ---------------------------
CODE_START_TOKEN_ID = 128257   # audio codes are accepted only after this token
CODE_REMOVE_TOKEN_ID = 128258  # dropped whenever it appears
CODE_TOKEN_OFFSET = 128266     # token ID - offset = raw SNAC code; lower IDs are skipped

# -- Streaming chunk sizes, counted in groups of 7 SNAC codes ------------------
_INITIAL_GROUPS = 3    # small first chunk so playback can start quickly
_STEADY_GROUPS = 30    # size of every later chunk

# SNAC decoder checkpoint; can be overridden via the SNAC_MODEL environment variable.
SNAC_MODEL = os.getenv("SNAC_MODEL", "hubertsiuzdak/snac_24khz")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_wav_header(sample_rate: int, bits_per_sample: int, channels: int) -> bytes:
    """Build a 44-byte PCM WAV header for a stream of unknown length.

    Both the RIFF chunk size and the ``data`` chunk size are written as the
    0xFFFFFFFF placeholder, since the total length is not known up front.
    """
    placeholder_size = 0xFFFFFFFF
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", placeholder_size, b"WAVE",
        # fmt sub-chunk: size 16, format 1 (PCM)
        b"fmt ", 16, 1, channels, sample_rate, byte_rate, block_align, bits_per_sample,
        b"data", placeholder_size,
    )
|
|
|
|
|
def _fade_in_out(
    audio: np.ndarray,
    fade_ms: int = 50,
    sample_rate: Optional[int] = None,
) -> np.ndarray:
    """Return *audio* with a linear fade-in at the start and fade-out at the end.

    Args:
        audio: 1-D sample array (presumably float in [-1, 1] as produced by the
            SNAC decoder — confirm for other callers).
        fade_ms: length of each ramp in milliseconds; ``<= 0`` disables fading.
        sample_rate: rate used to convert ``fade_ms`` into samples; defaults to
            the module-level ``SAMPLE_RATE``.

    Returns:
        A faded float copy of *audio*. The input object itself is returned,
        unchanged, when fading is disabled, the ramp rounds to zero samples,
        or the array is shorter than both ramps together.
    """
    if fade_ms <= 0:
        return audio
    rate = SAMPLE_RATE if sample_rate is None else sample_rate
    fade_samples = int(rate * fade_ms / 1000)
    fade_samples -= fade_samples % 2  # round down to an even length
    if fade_samples == 0 or audio.size < 2 * fade_samples:
        return audio

    # Fade a float copy: the caller's buffer is never modified, and integer
    # input no longer fails the in-place float multiplication.
    faded = np.array(audio, dtype=np.result_type(audio.dtype, np.float32), copy=True)
    ramp = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
    faded[:fade_samples] *= ramp
    faded[-fade_samples:] *= ramp[::-1]
    return faded
|
|
|
|
|
|
|
|
|
# Module-wide SNAC decoder used by ``_codes_to_audio``. Loading is best-effort:
# on any failure the decoder is ``None`` and decoding yields empty audio.
try:
    _snac_model: Optional[SNAC] = SNAC.from_pretrained(SNAC_MODEL).eval()
    # Ask torch directly for CUDA availability; the SNAC module itself has no
    # ``torch`` attribute (looking it up raised and discarded the model).
    _snac_model = _snac_model.to("cuda" if torch.cuda.is_available() else "cpu")
except Exception as exc:  # best-effort: missing weights, no network, CUDA errors, ...
    logging.warning("SNAC model could not be loaded – %s", exc)
    _snac_model = None
|
|
|
|
|
def _codes_to_audio(
    codes: List[int],
    model: Optional["SNAC"] = None,
    fade_ms: int = 50,
) -> bytes:
    """Decode a *flat* list of SNAC codes into 16-bit signed PCM bytes.

    Codes are consumed in frames of 7; a trailing partial frame is ignored.
    Position ``i`` of a frame carries an offset of ``i * 4096``, which is
    removed before the codes are distributed to SNAC's three layers:
    position 0 -> layer 1, positions 1 and 4 -> layer 2, positions
    2, 3, 5, 6 -> layer 3. A frame with any code outside ``[0, 4096)`` after
    offset removal is dropped (and logged): it would index past a
    4096-entry codebook (assumed from the 4096 offsets) and abort decoding.

    Args:
        codes: flat SNAC codes, 7 per frame.
        model: SNAC decoder to use; defaults to the module-level ``_snac_model``.
        fade_ms: per-chunk fade length passed to ``_fade_in_out``; ``<= 0``
            disables fading.

    Returns:
        PCM bytes (int16, native byte order), or ``b""`` when there is no
        complete valid frame or no decoder is available.
    """
    if not codes:
        return b""
    frame_count = len(codes) // 7
    if frame_count == 0:
        return b""

    codebook = 4096
    layer1: List[int] = []
    layer2: List[int] = []
    layer3: List[int] = []
    dropped = 0
    for start in range(0, frame_count * 7, 7):
        frame = [code - pos * codebook for pos, code in enumerate(codes[start:start + 7])]
        if not all(0 <= code < codebook for code in frame):
            dropped += 1
            continue
        layer1.append(frame[0])
        layer2.extend((frame[1], frame[4]))
        layer3.extend((frame[2], frame[3], frame[5], frame[6]))
    if dropped:
        logging.warning("Dropped %d SNAC frame(s) with out-of-range codes", dropped)
    if not layer1:
        return b""

    snac_model = model if model is not None else _snac_model
    if snac_model is None:
        return b""

    # nn.Module has no ``.device`` attribute; take it from the weights instead.
    device = next(snac_model.parameters()).device
    with torch.no_grad():
        layers = [
            torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
            for layer in (layer1, layer2, layer3)
        ]
        # ``.float()`` so half-precision models still yield float32 samples.
        wav = snac_model.decode(layers).float().cpu().numpy().squeeze()

    if fade_ms > 0:
        wav = _fade_in_out(wav, fade_ms)
    return np.clip(wav * 32767, -32768, 32767).astype(np.int16).tobytes()
|
|
|
|
|
|
|
|
|
class OrpheusVoice:
    """A named Orpheus speaker with an optional gender tag (e.g. "m" / "f")."""

    def __init__(self, name: str, gender: Optional[str] = None):
        self.name = name
        self.gender = gender

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, gender={self.gender!r})"
|
|
|
|
|
class OrpheusEngine(BaseEngine):
    """Realtime TTS engine using the Orpheus SNAC model via vLLM.

    Text is sent to an OpenAI-compatible ``/v1/completions`` endpoint; the
    streamed Orpheus audio tokens are converted to SNAC codes, decoded to
    16-bit mono PCM at ``SAMPLE_RATE`` and put on ``self.queue``.
    """

    _SPEAKERS = [
        OrpheusVoice("Martin", "m"), OrpheusVoice("Emma", "f"),
        OrpheusVoice("Luca", "m"), OrpheusVoice("Anna", "f"),
        OrpheusVoice("Jakob", "m"), OrpheusVoice("Anton", "m"),
        OrpheusVoice("Julian", "m"), OrpheusVoice("Jan", "m"),
        OrpheusVoice("Alexander", "m"), OrpheusVoice("Emil", "m"),
        OrpheusVoice("Ben", "m"), OrpheusVoice("Elias", "m"),
        OrpheusVoice("Felix", "m"), OrpheusVoice("Jonas", "m"),
        OrpheusVoice("Noah", "m"), OrpheusVoice("Maximilian", "m"),
        OrpheusVoice("Sophie", "f"), OrpheusVoice("Marie", "f"),
        OrpheusVoice("Mia", "f"), OrpheusVoice("Maria", "f"),
        OrpheusVoice("Sophia", "f"), OrpheusVoice("Lina", "f"),
        OrpheusVoice("Lea", "f"),
    ]

    # Token ID of ``<custom_token_0>``: ``<custom_token_N>`` in streamed text
    # maps to token ID ``_CUSTOM_TOKEN_BASE_ID + N``. Assumes the upstream
    # Orpheus tokenizer layout (CODE_TOKEN_OFFSET == base + 10) — confirm if
    # the model/tokenizer changes.
    _CUSTOM_TOKEN_BASE_ID = 128256

    def _load_snac(self, model_name: str = SNAC_MODEL):
        """Load the SNAC decoder onto the GPU if CUDA is available, else the CPU.

        On CUDA the weights are converted to half precision. Any failure is
        logged and ``None`` is returned instead of raising.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            snac = SNAC.from_pretrained(model_name).to(device)
            if device == "cuda":
                snac = snac.half()
            snac.eval()
            logging.info("SNAC %s loaded on %s", snac_version, device)
            return snac
        except Exception:
            logging.exception("SNAC load failed – running with silent fallback")
            return None

    def __init__(
        self,
        api_url: str = DEFAULT_API_URL,
        model: str = DEFAULT_MODEL,
        headers: dict = DEFAULT_HEADERS,
        voice: Optional[OrpheusVoice] = None,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_tokens: int = 1200,
        repetition_penalty: float = 1.1,
        debug: bool = False,
        emit_wav_header: bool = True,
    ) -> None:
        """Configure the engine and load the SNAC decoder.

        Args:
            api_url: base URL of the completion server (trailing ``/`` stripped).
            model: model name sent with every completion request.
            headers: HTTP headers for requests; copied, so the shared module
                default is never mutated through an instance.
            voice: an ``OrpheusVoice`` or a speaker name from ``_SPEAKERS``;
                defaults to ``DEFAULT_VOICE``.
            temperature, top_p, max_tokens: sampling settings forwarded to the server.
            repetition_penalty: forwarded as the ``repetition_penalty`` request field.
            debug: log stream events that cannot be parsed.
            emit_wav_header: put a streaming WAV header on the queue at the start
                of every ``synthesize`` call. Disable when the consumer plays the
                queue as raw PCM (see ``get_stream_info``), where the header
                bytes would otherwise be heard as a click.

        Raises:
            ValueError: if *voice* is a name that is not a known speaker.
        """
        super().__init__()
        self.api_url = api_url.rstrip("/")
        self.model = model
        self.headers = dict(headers or {})
        if isinstance(voice, str):
            voice = self._find_speaker(voice)
        self.voice = voice or OrpheusVoice(DEFAULT_VOICE)
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.debug = debug
        self.emit_wav_header = emit_wav_header
        self.queue: "Queue[bytes | None]" = Queue()
        # NOTE(review): chunks are decoded by the module-level ``_codes_to_audio``,
        # which uses its own SNAC instance; ``self.snac`` is not used for
        # decoding inside this class.
        self.snac = self._load_snac()
        if self.snac is None:
            logging.warning("⚠️ No SNAC – audio generation disabled.")
        self.engine_name = "orpheus"

    def get_stream_info(self):
        """Return ``(pyaudio format, channels, sample rate)``: 16-bit PCM output."""
        return pyaudio.paInt16, AUDIO_CHANNELS, SAMPLE_RATE

    def get_voices(self):
        """Return the built-in speaker catalogue."""
        return self._SPEAKERS

    @classmethod
    def _find_speaker(cls, name: str) -> OrpheusVoice:
        """Return the catalogue entry called *name*; raise ValueError if unknown."""
        for speaker in cls._SPEAKERS:
            if speaker.name == name:
                return speaker
        raise ValueError(f"Unknown Orpheus speaker '{name}'")

    def set_voice(self, voice_name: str | OrpheusVoice):
        """Select a speaker by name or by an ``OrpheusVoice`` (e.g. from ``get_voices()``).

        Raises:
            ValueError: if the speaker is not in the catalogue.
        """
        name = voice_name.name if isinstance(voice_name, OrpheusVoice) else voice_name
        self.voice = self._find_speaker(name)

    def synthesize(self, text: str) -> bool:
        """Stream speech for *text* onto ``self.queue``; blocks until finished.

        Queue contents for one call, in order: a WAV header (only when
        ``emit_wav_header``), PCM chunks — the first covering
        ``_INITIAL_GROUPS`` groups of 7 codes for a quick start, later ones
        ``_STEADY_GROUPS`` groups — then the leftover complete groups, and
        finally ``None`` (always, even after an error).

        Returns:
            True if the stream was consumed completely, False if an exception
            occurred (it is logged, not re-raised).
        """
        super().synthesize(text)
        if self.emit_wav_header:
            self.queue.put(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS))

        try:
            pending: List[int] = []
            chunk_size = _INITIAL_GROUPS * 7
            for code in self._stream_snac_codes(text):
                pending.append(code)
                if len(pending) < chunk_size:
                    continue
                chunk, pending = pending[:chunk_size], pending[chunk_size:]
                # Switch to the larger steady block once audio was actually queued.
                if self._queue_pcm(chunk):
                    chunk_size = _STEADY_GROUPS * 7

            # Flush the complete 7-code groups left over; a partial group is dropped.
            complete = len(pending) - len(pending) % 7
            if complete:
                self._queue_pcm(pending[:complete])
            return True
        except Exception as exc:
            logging.exception("OrpheusEngine: synthesis failed – %s", exc)
            return False
        finally:
            self.queue.put(None)

    def _queue_pcm(self, codes: List[int]) -> bool:
        """Decode *codes* via ``_codes_to_audio`` and queue the PCM; True if any was produced."""
        pcm = _codes_to_audio(codes)
        if pcm:
            self.queue.put(pcm)
        return bool(pcm)

    def _format_prompt(self, prompt: str) -> str:
        """Build the completion prompt ``<|audio|>{voice}: {prompt}<|eot_id|>``."""
        return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"

    def _stream_snac_codes(self, prompt: str) -> Generator[int, None, None]:
        """Yield raw SNAC codes (token ID minus ``CODE_TOKEN_OFFSET``) as they stream in.

        Tokens are ignored until ``CODE_START_TOKEN_ID`` has been seen (the
        marker itself included); afterwards ``CODE_REMOVE_TOKEN_ID`` and any ID
        below ``CODE_TOKEN_OFFSET`` are skipped.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload = {
            "model": self.model,
            "prompt": self._format_prompt(prompt),
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": True,
            # Keeps ``<custom_token_N>`` markers in the streamed text, which the
            # text-based token recovery below relies on.
            "skip_special_tokens": False,
            # Multiplicative repetition penalty (1.0 = off), a vLLM extra sampling
            # parameter. Sending it as the additive OpenAI ``frequency_penalty``
            # applied a much stronger, differently-defined penalty.
            "repetition_penalty": self.repetition_penalty,
        }
        url = f"{self.api_url}/v1/completions"
        started = False
        with requests.post(url, headers=self.headers, json=payload, stream=True, timeout=600) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line.startswith(b"data: "):
                    continue  # blank keep-alive lines and non-data SSE fields
                data = line[6:].decode()
                if data.strip() == "[DONE]":
                    break
                for tid in self._token_ids_from_event(data):
                    if not started:
                        started = tid == CODE_START_TOKEN_ID
                        continue
                    if tid == CODE_REMOVE_TOKEN_ID or tid < CODE_TOKEN_OFFSET:
                        continue
                    yield tid - CODE_TOKEN_OFFSET

    def _token_ids_from_event(self, data: str) -> List[int]:
        """Return the token IDs carried by one SSE ``data:`` payload.

        Preference order: a ``token_ids`` list on the choice, a single
        ``token_id``, then ``<custom_token_N>`` markers parsed from ``text``.
        Events that carry no tokens or cannot be parsed yield ``[]``.
        """
        import re

        try:
            choice = json.loads(data)["choices"][0]
            token_ids = choice.get("token_ids")
            if token_ids:
                return [int(t) for t in token_ids]
            token_id = choice.get("token_id")
            if token_id is not None:
                return [int(token_id)]
            text = choice.get("text") or ""
            return [
                self._CUSTOM_TOKEN_BASE_ID + int(number)
                for number in re.findall(r"<custom_token_(\d+)>", text)
            ]
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            if self.debug:
                logging.debug("OrpheusEngine: skipping unparsable stream event %r", data)
            return []

    def __del__(self):
        """Best-effort: put a ``None`` sentinel on the queue when collected."""
        try:
            self.queue.put(None)
        except Exception:
            pass
|
|