# dev-mode-realtts-orpheus / engines/orpheus_engine.py
# Last change: commit 28190c5 by Tomtom84 ("Update engines/orpheus_engine.py")
# -*- coding: utf-8 -*-
"""OrpheusEngine
~~~~~~~~~~~~~~~~
A drop‑in replacement for the original ``orpheus_engine.py`` that fixes
all outstanding token‑streaming issues and eliminates audible clicks by
* streaming **token‑IDs** instead of partial text
* dynamically sending a *tiny* first audio chunk (3×7 codes) followed by
steady blocks (30×7)
* mapping vLLM/OpenAI token‑IDs → SNAC codes without fragile
``"<custom_token_"`` string parsing
* adding an optional fade‑in / fade‑out per chunk
* emitting a proper WAV header as the first element in the queue so that
browsers / HTML5 `<audio>` tags start playback immediately.
The API (``get_voices()``, ``set_voice()``, …) is unchanged, so you can
keep using it from RealTimeTTS.
"""
from __future__ import annotations

###############################################################################
# Standard library & 3rd‑party imports #
###############################################################################
import json
import logging
import os
import re
import struct
import time
from queue import Queue
from typing import Generator, Iterable, List, Optional

import numpy as np
import pyaudio # provided by RealTimeTTS[system]
import requests
import torch
from snac import SNAC, __version__ as snac_version

from RealtimeTTS.engines import BaseEngine
###############################################################################
# Constants #
###############################################################################
DEFAULT_API_URL = "http://127.0.0.1:1234"
DEFAULT_MODEL = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
DEFAULT_HEADERS = {"Content-Type": "application/json"}
DEFAULT_VOICE = "Martin"
# Audio
SAMPLE_RATE = 24_000
BITS_PER_SAMPLE = 16
AUDIO_CHANNELS = 1
# Token‑ID magic numbers (defined in the model card)
CODE_START_TOKEN_ID = 128257 # <|audio|>
CODE_REMOVE_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266 # <custom_token_?> – first usable code id
# Chunking strategy
_INITIAL_GROUPS = 3 # 3×7 = 21 codes ≈ 90 ms @24 kHz
_STEADY_GROUPS = 30 # 30×7 = 210 codes ≈ 900 ms
SNAC_MODEL = os.getenv("SNAC_MODEL", "hubertsiuzdak/snac_24khz")
###############################################################################
# Helper functions #
###############################################################################
def _create_wav_header(sample_rate: int, bits_per_sample: int, channels: int) -> bytes:
"""Return a 44‑byte WAV/PCM header with unknown data size (0xFFFFFFFF)."""
riff_size = 0xFFFFFFFF
header = b"RIFF" + struct.pack("<I", riff_size) + b"WAVEfmt "
header += struct.pack("<IHHIIHH", 16, 1, channels, sample_rate,
sample_rate * channels * bits_per_sample // 8,
channels * bits_per_sample // 8, bits_per_sample)
header += b"data" + struct.pack("<I", 0xFFFFFFFF)
return header
def _fade_in_out(audio: np.ndarray, fade_ms: int = 50) -> np.ndarray:
"""Apply linear fade‑in/out to avoid clicks."""
if fade_ms <= 0:
return audio
fade_samples = int(SAMPLE_RATE * fade_ms / 1000)
fade_samples -= fade_samples % 2 # keep it even
if fade_samples == 0 or audio.size < 2 * fade_samples:
return audio
ramp = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
audio[:fade_samples] *= ramp
audio[-fade_samples:] *= ramp[::-1]
return audio
###############################################################################
# SNAC – lightweight wrapper #
###############################################################################
try:
    from snac import SNAC

    # Shared decoder used by _codes_to_audio(); honours the SNAC_MODEL override.
    _snac_model: Optional[SNAC] = SNAC.from_pretrained(SNAC_MODEL).eval()
    # Ask torch.cuda directly: the model object has no `.torch` attribute, so
    # looking it up there raised AttributeError and always left this unset.
    _snac_model = _snac_model.to("cuda" if torch.cuda.is_available() else "cpu")
except Exception as exc:  # pragma: no cover
    # Best effort: without a model, _codes_to_audio() returns empty audio.
    logging.warning("SNAC model could not be loaded – %s", exc)
    _snac_model = None
def _codes_to_audio(codes: List[int]) -> bytes:
"""Convert a *flat* list of SNAC codes to 16‑bit PCM bytes."""
if not _snac_model or not codes:
return b""
# --- redistribute into 3 snac layers (see original paper) --------------
groups = len(codes) // 7
codes = codes[: groups * 7] # trim incomplete tail
if groups == 0:
return b""
l1, l2, l3 = [], [], []
for g in range(groups):
base = g * 7
l1.append(codes[base])
l2.append(codes[base + 1] - 4096)
l3.extend([
codes[base + 2] - 2 * 4096,
codes[base + 3] - 3 * 4096,
codes[base + 5] - 5 * 4096,
codes[base + 6] - 6 * 4096,
])
l2.append(codes[base + 4] - 4 * 4096)
import torch
with torch.no_grad():
layers = [
torch.tensor(l1, device=_snac_model.device).unsqueeze(0),
torch.tensor(l2, device=_snac_model.device).unsqueeze(0),
torch.tensor(l3, device=_snac_model.device).unsqueeze(0),
]
wav = _snac_model.decode(layers).cpu().numpy().squeeze()
wav = _fade_in_out(wav)
pcm = np.clip(wav * 32767, -32768, 32767).astype(np.int16).tobytes()
return pcm
###############################################################################
# Main class #
###############################################################################
class OrpheusVoice:
    """A named Orpheus speaker with an optional gender tag ("m" or "f")."""

    def __init__(self, name: str, gender: Optional[str] = None):
        self.name = name
        self.gender = gender

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, gender={self.gender!r})"
class OrpheusEngine(BaseEngine):
    """Realtime TTS engine using the Orpheus SNAC model via vLLM.

    Text is sent to an OpenAI-compatible ``/v1/completions`` streaming
    endpoint. The audio-code tokens it returns are decoded to 16-bit PCM with
    SNAC and pushed onto ``self.queue``.
    """

    _SPEAKERS = [
        OrpheusVoice("Martin", "m"), OrpheusVoice("Emma", "f"),
        OrpheusVoice("Luca", "m"), OrpheusVoice("Anna", "f"),
        OrpheusVoice("Jakob", "m"), OrpheusVoice("Anton", "m"),
        OrpheusVoice("Julian", "m"), OrpheusVoice("Jan", "m"),
        OrpheusVoice("Alexander", "m"), OrpheusVoice("Emil", "m"),
        OrpheusVoice("Ben", "m"), OrpheusVoice("Elias", "m"),
        OrpheusVoice("Felix", "m"), OrpheusVoice("Jonas", "m"),
        OrpheusVoice("Noah", "m"), OrpheusVoice("Maximilian", "m"),
        OrpheusVoice("Sophie", "f"), OrpheusVoice("Marie", "f"),
        OrpheusVoice("Mia", "f"), OrpheusVoice("Maria", "f"),
        OrpheusVoice("Sophia", "f"), OrpheusVoice("Lina", "f"),
        OrpheusVoice("Lea", "f"),
    ]

    # Complete ``<custom_token_N>`` markers inside streamed completion text.
    _CUSTOM_TOKEN_RE = re.compile(r"<custom_token_(\d+)>")
    # NOTE(review): assumes the Orpheus tokenizer layout in which
    # ``<custom_token_N>`` has id 128256 + N (so CODE_TOKEN_OFFSET corresponds
    # to ``<custom_token_10>``). Confirm against the model's tokenizer.
    _CUSTOM_TOKEN_BASE_ID = 128256

    def _load_snac(self, model_name: str = SNAC_MODEL):
        """Load the SNAC decoder on GPU (in fp16) when available, else on CPU.

        Returns:
            The decoder in eval mode, or None if loading failed. The failure
            is logged with its traceback.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            snac = SNAC.from_pretrained(model_name).to(device)
            if device == "cuda":  # half precision only on the GPU
                snac = snac.half()
            snac.eval()
            logging.info("SNAC %s loaded on %s", snac_version, device)
            return snac
        except Exception:
            logging.exception("SNAC load failed – running with silent fallback")
            return None

    # ---------------------------------------------------------------------
    def __init__(
        self,
        api_url: str = DEFAULT_API_URL,
        model: str = DEFAULT_MODEL,
        headers: dict = DEFAULT_HEADERS,
        voice: Optional[OrpheusVoice] = None,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_tokens: int = 1200,
        repetition_penalty: float = 1.1,
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.api_url = api_url.rstrip("/")
        self.model = model
        # Copy, so mutating self.headers never alters the shared module default.
        self.headers = dict(headers)
        self.voice = voice or OrpheusVoice(DEFAULT_VOICE)
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.debug = debug
        self.queue: "Queue[bytes | None]" = Queue()
        # Reuse the decoder loaded at import time instead of keeping a second
        # copy of the weights in (GPU) memory.
        self.snac = _snac_model if _snac_model is not None else self._load_snac()
        if _snac_model is None:
            # _codes_to_audio() decodes with the module-level model only.
            logging.warning("⚠️ No SNAC – audio generation disabled.")
        self.engine_name = "orpheus"

    # ------------------------------------------------------------------ API
    def get_stream_info(self):
        """Return (pyaudio format, channels, sample rate) of the PCM chunks."""
        return pyaudio.paInt16, AUDIO_CHANNELS, SAMPLE_RATE

    def get_voices(self):
        """Return the list of supported speakers."""
        return self._SPEAKERS

    def set_voice(self, voice_name: "str | OrpheusVoice"):
        """Select a speaker by name, or by an ``OrpheusVoice`` carrying that name.

        Raises:
            ValueError: If the name is not one of ``_SPEAKERS``.
        """
        name = voice_name.name if isinstance(voice_name, OrpheusVoice) else voice_name
        for speaker in self._SPEAKERS:
            if speaker.name == name:
                # Copy the registered entry so its gender tag is kept.
                self.voice = OrpheusVoice(speaker.name, speaker.gender)
                return
        raise ValueError(f"Unknown Orpheus speaker '{name}'")

    # --------------------------------------------------------------- public
    def synthesize(self, text: str) -> bool:
        """Stream TTS audio for *text* into ``self.queue``; blocks until finished.

        Items are queued in this order:
        1. A WAV header.
        2. PCM chunks. The first holds ``_INITIAL_GROUPS`` 7-code frames; once
           one chunk has produced audio, later chunks hold ``_STEADY_GROUPS``.
        3. A final chunk with any leftover complete frames.
        4. ``None`` as the end marker.

        Returns:
            True on success, False if streaming or decoding raised.
        """
        super().synthesize(text)
        # NOTE(review): get_stream_info() advertises raw paInt16 PCM. If the
        # consumer plays queue items directly, these 44 header bytes would be
        # rendered as ~22 samples of noise. Confirm the consumer handles it.
        self.queue.put(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS))
        try:
            pending: List[int] = []  # codes received but not yet decoded
            groups_needed = _INITIAL_GROUPS
            for code_id in self._stream_snac_codes(text):
                pending.append(code_id)
                chunk_size = groups_needed * 7
                if len(pending) < chunk_size:
                    continue
                pcm = _codes_to_audio(pending[:chunk_size])
                del pending[:chunk_size]
                if pcm:
                    self.queue.put(pcm)
                    # Small first chunk for low latency, larger ones afterwards.
                    groups_needed = _STEADY_GROUPS
            # Flush the remaining complete 7-code frames.
            tail_size = len(pending) // 7 * 7
            if tail_size:
                pcm = _codes_to_audio(pending[:tail_size])
                if pcm:
                    self.queue.put(pcm)
            return True
        except Exception as exc:  # pragma: no cover
            logging.exception("OrpheusEngine: synthesis failed – %s", exc)
            return False
        finally:
            self.queue.put(None)  # close stream

    # ------------------------------------------------------------ internals
    def _format_prompt(self, prompt: str) -> str:
        """Prefix *prompt* with the current speaker name in the model's format."""
        return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"

    def _stream_snac_codes(self, prompt: str) -> Generator[int, None, None]:
        """Yield SNAC code ids (``token_id - CODE_TOKEN_OFFSET``) as they arrive.

        Three kinds of token are not yielded:
        - everything up to and including ``CODE_START_TOKEN_ID``;
        - ``CODE_REMOVE_TOKEN_ID``;
        - any id below ``CODE_TOKEN_OFFSET``.
        """
        started = False
        for tid in self._stream_token_ids(prompt):
            if not started:
                started = tid == CODE_START_TOKEN_ID
                continue
            if tid == CODE_REMOVE_TOKEN_ID or tid < CODE_TOKEN_OFFSET:
                continue
            yield tid - CODE_TOKEN_OFFSET

    def _stream_token_ids(self, prompt: str) -> Generator[int, None, None]:
        """Stream a completion for *prompt* and yield the generated token ids.

        Uses a streamed choice's ``token_id`` field when the server sends one.
        Otherwise it recovers ids from ``<custom_token_N>`` markers in the
        streamed text; markers split across events are buffered.
        """
        payload = {
            "model": self.model,
            "prompt": self._format_prompt(prompt),
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": True,
            "skip_special_tokens": False,
            # Multiplicative penalty (1.0 = off). The OpenAI-style
            # ``frequency_penalty`` is a different, additive knob (0.0 = off).
            "repetition_penalty": self.repetition_penalty,
        }
        url = f"{self.api_url}/v1/completions"  # plain completion endpoint
        partial = ""  # text tail that may be the start of a split marker
        with requests.post(url, headers=self.headers, json=payload, stream=True, timeout=600) as r:
            r.raise_for_status()
            for line in r.iter_lines():
                data = self._sse_data(line)
                if not data:
                    continue
                if data == "[DONE]":
                    break
                try:
                    choice = json.loads(data)["choices"][0]
                    tid = choice.get("token_id")
                    tid = None if tid is None else int(tid)
                    text_piece = choice.get("text") or ""
                except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                    continue  # malformed or non-completion event
                if tid is not None:
                    yield tid
                    continue
                partial += text_piece
                consumed = 0
                for match in self._CUSTOM_TOKEN_RE.finditer(partial):
                    yield self._CUSTOM_TOKEN_BASE_ID + int(match.group(1))
                    consumed = match.end()
                partial = self._unfinished_marker(partial[consumed:])

    @staticmethod
    def _sse_data(line: bytes) -> Optional[str]:
        """Return the stripped payload of an SSE ``data:`` line, or None otherwise."""
        if not line.startswith(b"data:"):
            return None
        return line[5:].decode("utf-8", errors="replace").strip()

    @staticmethod
    def _unfinished_marker(text: str) -> str:
        """Return the trailing ``<…`` fragment of *text* that has no closing ``>``."""
        start = text.rfind("<")
        if start == -1 or ">" in text[start:]:
            return ""
        return text[start:]

    # ------------------------------------------------------------------ misc
    def __del__(self):
        # Best effort: push an end-of-stream marker. Errors (e.g. the queue
        # was never created because __init__ failed) are ignored.
        try:
            self.queue.put(None)
        except Exception:
            pass