# -*- coding: utf-8 -*- """OrpheusEngine ~~~~~~~~~~~~~~~~ A drop‑in replacement for the original ``orpheus_engine.py`` that fixes all outstanding token‑streaming issues and eliminates audible clicks by * streaming **token‑IDs** instead of partial text * dynamically sending a *tiny* first audio chunk (3×7 codes) followed by steady blocks (30×7) * mapping vLLM/OpenAI token‑IDs → SNAC codes without fragile ``"` tags start playback immediately. The API (``get_voices()``, ``set_voice()``, …) is unchanged, so you can keep using it from RealTimeTTS. """ from __future__ import annotations from snac import SNAC, __version__ as snac_version ############################################################################### # Standard library & 3rd‑party imports # ############################################################################### import json import logging import struct import time import os import torch from queue import Queue from typing import Generator, Iterable, List, Optional import numpy as np import pyaudio # provided by RealTimeTTS[system] import requests from RealtimeTTS.engines import BaseEngine ############################################################################### # Constants # ############################################################################### DEFAULT_API_URL = "http://127.0.0.1:1234" DEFAULT_MODEL = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1" DEFAULT_HEADERS = {"Content-Type": "application/json"} DEFAULT_VOICE = "Martin" # Audio SAMPLE_RATE = 24_000 BITS_PER_SAMPLE = 16 AUDIO_CHANNELS = 1 # Token‑ID magic numbers (defined in the model card) CODE_START_TOKEN_ID = 128257 # <|audio|> CODE_REMOVE_TOKEN_ID = 128258 CODE_TOKEN_OFFSET = 128266 # – first usable code id # Chunking strategy _INITIAL_GROUPS = 3 # 3×7 = 21 codes ≈ 90 ms @24 kHz _STEADY_GROUPS = 30 # 30×7 = 210 codes ≈ 900 ms SNAC_MODEL = os.getenv("SNAC_MODEL", "hubertsiuzdak/snac_24khz") 
###############################################################################
# Helper functions                                                            #
###############################################################################

import re  # only needed by the <custom_token_N> text fallback below

# Length of the linear fade applied to both ends of every decoded chunk.
_DEFAULT_FADE_MS = 5.0

# A SNAC frame is 7 consecutive codes; code k of a frame carries an extra
# k * 4096 offset that has to be removed before decoding.
_CODES_PER_FRAME = 7
_SNAC_CODEBOOK_SIZE = 4096

# Text fallback for servers that do not report token IDs in the stream.
# Assumes ``<custom_token_N>`` has token id 128256 + N (Orpheus tokenizer
# layout), which makes ``<custom_token_10>`` == CODE_TOKEN_OFFSET == code 0.
# NOTE(review): verify against the served model's tokenizer.
_CUSTOM_TOKEN_RE = re.compile(r"<custom_token_(\d+)>")
_CUSTOM_TOKEN_BASE_ID = 128256
# Longest incomplete tag worth carrying over to the next streamed delta.
_MAX_PARTIAL_TAG = 32


def _create_wav_header(sample_rate: int, bits_per_sample: int, channels: int) -> bytes:
    """Return a 44-byte WAV/PCM header with unknown data size (0xFFFFFFFF).

    Both the RIFF chunk size and the ``data`` chunk size are set to
    0xFFFFFFFF, the usual "length unknown" marker for streamed WAV.
    """
    riff_size = 0xFFFFFFFF
    data_size = 0xFFFFFFFF
    block_align = channels * bits_per_sample // 8
    byte_rate = sample_rate * block_align
    header = b"RIFF" + struct.pack("<I", riff_size) + b"WAVE"
    header += b"fmt " + struct.pack(
        "<IHHIIHH",
        16,  # size of the PCM fmt chunk
        1,  # audio format 1 = uncompressed PCM
        channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
    )
    header += b"data" + struct.pack("<I", data_size)
    return header


def _fade_in_out(audio: np.ndarray, fade_ms: float = _DEFAULT_FADE_MS) -> np.ndarray:
    """Apply a linear fade-in/out to *audio* (in place) to avoid clicks.

    Returns *audio* unchanged when *fade_ms* is not positive or the signal is
    shorter than two full fades.
    """
    if fade_ms <= 0:
        return audio
    fade_samples = int(SAMPLE_RATE * fade_ms / 1000)
    fade_samples -= fade_samples % 2  # keep it even
    if fade_samples == 0 or audio.size < 2 * fade_samples:
        return audio
    ramp = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
    audio[:fade_samples] *= ramp
    audio[-fade_samples:] *= ramp[::-1]
    return audio


def _extract_custom_token_ids(text: str) -> tuple[list[int], str]:
    """Parse complete ``<custom_token_N>`` tags out of *text*.

    Returns ``(token_ids, remainder)`` where *remainder* is a possibly
    incomplete trailing tag that should be prefixed to the next delta.
    """
    token_ids: list[int] = []
    consumed = 0
    for match in _CUSTOM_TOKEN_RE.finditer(text):
        token_ids.append(_CUSTOM_TOKEN_BASE_ID + int(match.group(1)))
        consumed = match.end()
    rest = text[consumed:]
    # Only text from the last "<" onwards can still become a tag.
    cut = rest.rfind("<")
    rest = rest[cut:] if cut != -1 else ""
    if len(rest) > _MAX_PARTIAL_TAG:  # cannot be a partial tag any more
        rest = ""
    return token_ids, rest


###############################################################################
# SNAC – lightweight wrapper                                                  #
###############################################################################

# Fallback decoder used by _codes_to_audio() when no model is passed in.
# Deliberately not loaded at import time: every OrpheusEngine loads its own
# decoder (see OrpheusEngine._load_snac) and passes it explicitly, which
# avoids loading the SNAC weights twice.
_snac_model: Optional[SNAC] = None


def _codes_to_audio(
    codes: List[int],
    model: Optional[SNAC] = None,
    fade_ms: float = _DEFAULT_FADE_MS,
) -> bytes:
    """Convert a *flat* list of SNAC codes to 16-bit PCM bytes.

    Args:
        codes: SNAC codes (token id minus CODE_TOKEN_OFFSET), 7 per frame.
            An incomplete trailing frame is ignored.
        model: Decoder to use; falls back to the module-level ``_snac_model``.
        fade_ms: Length of the fade applied to both ends of the chunk.

    Returns:
        Little-endian int16 PCM, or ``b""`` when no decoder is available, no
        complete frame is present, or a code lies outside the codebook.
    """
    snac = model if model is not None else _snac_model
    if snac is None or not codes:
        return b""

    frames = len(codes) // _CODES_PER_FRAME
    if frames == 0:
        return b""

    # Redistribute each 7-code frame into the 3 SNAC layers (1 + 2 + 4 codes)
    # and remove the per-position 4096 offsets.
    l1: List[int] = []
    l2: List[int] = []
    l3: List[int] = []
    for frame in range(frames):
        c = codes[frame * _CODES_PER_FRAME : (frame + 1) * _CODES_PER_FRAME]
        l1.append(c[0])
        l2.append(c[1] - 1 * _SNAC_CODEBOOK_SIZE)
        l3.append(c[2] - 2 * _SNAC_CODEBOOK_SIZE)
        l3.append(c[3] - 3 * _SNAC_CODEBOOK_SIZE)
        l2.append(c[4] - 4 * _SNAC_CODEBOOK_SIZE)
        l3.append(c[5] - 5 * _SNAC_CODEBOOK_SIZE)
        l3.append(c[6] - 6 * _SNAC_CODEBOOK_SIZE)

    # A misaligned or corrupt stream yields invalid indices, which would crash
    # the embedding lookup (and can poison a CUDA context); drop the chunk.
    if any(not 0 <= v < _SNAC_CODEBOOK_SIZE for layer in (l1, l2, l3) for v in layer):
        logging.warning("Dropping %d SNAC frame(s) with out-of-range codes", frames)
        return b""

    # nn.Module has no ``.device`` attribute; derive it from the weights.
    device = next(snac.parameters()).device
    with torch.no_grad():
        layers = [
            torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
            for layer in (l1, l2, l3)
        ]
        wav = snac.decode(layers)

    # Models moved to half precision return float16; cast to float32 before
    # scaling so the int16 conversion does not lose resolution.
    audio = wav.float().cpu().numpy().squeeze()
    audio = _fade_in_out(audio, fade_ms)
    return np.clip(audio * 32767, -32768, 32767).astype(np.int16).tobytes()


###############################################################################
# Main class                                                                  #
###############################################################################
class OrpheusVoice:
    """A named Orpheus speaker with an optional gender tag ("m" / "f")."""

    def __init__(self, name: str, gender: str | None = None):
        self.name = name
        self.gender = gender

    def __repr__(self) -> str:
        return f"OrpheusVoice(name={self.name!r}, gender={self.gender!r})"


class OrpheusEngine(BaseEngine):
    """Realtime TTS engine using the Orpheus SNAC model via vLLM."""

    _SPEAKERS = [
        OrpheusVoice("Martin", "m"),
        OrpheusVoice("Emma", "f"),
        OrpheusVoice("Luca", "m"),
        OrpheusVoice("Anna", "f"),
        OrpheusVoice("Jakob", "m"),
        OrpheusVoice("Anton", "m"),
        OrpheusVoice("Julian", "m"),
        OrpheusVoice("Jan", "m"),
        OrpheusVoice("Alexander", "m"),
        OrpheusVoice("Emil", "m"),
        OrpheusVoice("Ben", "m"),
        OrpheusVoice("Elias", "m"),
        OrpheusVoice("Felix", "m"),
        OrpheusVoice("Jonas", "m"),
        OrpheusVoice("Noah", "m"),
        OrpheusVoice("Maximilian", "m"),
        OrpheusVoice("Sophie", "f"),
        OrpheusVoice("Marie", "f"),
        OrpheusVoice("Mia", "f"),
        OrpheusVoice("Maria", "f"),
        OrpheusVoice("Sophia", "f"),
        OrpheusVoice("Lina", "f"),
        OrpheusVoice("Lea", "f"),
    ]

    def _load_snac(self, model_name: str = SNAC_MODEL):
        """Load the SNAC decoder on the GPU if available, otherwise the CPU.

        Returns the model in eval mode (half precision on GPU), or ``None``
        after logging the traceback if loading fails for any reason.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            snac = SNAC.from_pretrained(model_name).to(device)
            if device == "cuda":
                # Half precision only on GPU; _codes_to_audio casts the
                # decoded waveform back to float32.
                snac = snac.half()
            snac.eval()
        except Exception:
            logging.exception("SNAC load failed – running with silent fallback")
            return None
        logging.info("SNAC %s loaded on %s", snac_version, device)
        return snac

    # ---------------------------------------------------------------------
    def __init__(
        self,
        api_url: str = DEFAULT_API_URL,
        model: str = DEFAULT_MODEL,
        headers: dict = DEFAULT_HEADERS,
        voice: Optional[OrpheusVoice] = None,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_tokens: int = 1200,
        repetition_penalty: float = 1.1,
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.api_url = api_url.rstrip("/")
        self.model = model
        # Copy so per-instance edits never mutate the shared DEFAULT_HEADERS.
        self.headers = dict(headers)
        self.voice = voice or self._find_voice(DEFAULT_VOICE) or OrpheusVoice(DEFAULT_VOICE)
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.debug = debug
        self.queue: "Queue[bytes | None]" = Queue()
        self.snac = self._load_snac()  # load the decoder
        if self.snac is None:  # fallback notice
            logging.warning("⚠️ No SNAC – audio generation disabled.")
        self.engine_name = "orpheus"

    # ------------------------------------------------------------------ API
    def get_stream_info(self):
        """Return the PyAudio (format, channels, rate) of the produced PCM."""
        return pyaudio.paInt16, AUDIO_CHANNELS, SAMPLE_RATE

    def get_voices(self):
        """Return the list of available Orpheus speakers."""
        return self._SPEAKERS

    def set_voice(self, voice_name: str | OrpheusVoice):
        """Select a speaker by name (or by an OrpheusVoice instance).

        Raises:
            ValueError: if the speaker is not in the known speaker list.
        """
        name = voice_name.name if isinstance(voice_name, OrpheusVoice) else voice_name
        voice = self._find_voice(name)
        if voice is None:
            raise ValueError(f"Unknown Orpheus speaker '{name}'")
        # Reuse the registered instance so the gender tag is preserved.
        self.voice = voice

    # --------------------------------------------------------------- public
    def synthesize(self, text: str) -> bool:
        """Stream TTS for **text** into ``self.queue`` – blocks until finished.

        Puts a WAV header, then PCM chunks (a small first chunk of
        ``_INITIAL_GROUPS`` frames, then ``_STEADY_GROUPS`` frames each), and
        finally ``None`` to close the stream.

        Returns:
            True when the model stream was fully consumed, False on error or
            when synthesis was stopped via ``stop_synthesis_event``.
        """
        super().synthesize(text)
        self.queue.put(_create_wav_header(SAMPLE_RATE, BITS_PER_SAMPLE, AUDIO_CHANNELS))

        code_stream = self._stream_snac_codes(text)
        try:
            pending: List[int] = []
            chunk_size = _INITIAL_GROUPS * _CODES_PER_FRAME
            for code_id in code_stream:
                if self._stop_requested():
                    return False
                pending.append(code_id)
                if len(pending) >= chunk_size:
                    self._emit_codes(pending)
                    pending = []
                    chunk_size = _STEADY_GROUPS * _CODES_PER_FRAME

            # flush remaining full frames
            usable = len(pending) - len(pending) % _CODES_PER_FRAME
            if usable:
                self._emit_codes(pending[:usable])
            return True
        except Exception as exc:  # pragma: no cover
            logging.exception("OrpheusEngine: synthesis failed – %s", exc)
            return False
        finally:
            # Closing the generator also closes the HTTP response on early exit.
            code_stream.close()
            self.queue.put(None)  # close stream

    # ------------------------------------------------------------ internals
    @classmethod
    def _find_voice(cls, name: str) -> Optional[OrpheusVoice]:
        """Return the registered speaker called *name*, or None."""
        return next((v for v in cls._SPEAKERS if v.name == name), None)

    def _stop_requested(self) -> bool:
        """True if the base engine's stop event (when present) is set."""
        event = getattr(self, "stop_synthesis_event", None)
        return event is not None and event.is_set()

    def _emit_codes(self, codes: List[int]) -> None:
        """Decode *codes* with this engine's SNAC model and queue the PCM."""
        pcm = _codes_to_audio(codes, self.snac)
        if pcm:
            self.queue.put(pcm)

    def _format_prompt(self, prompt: str) -> str:
        return f"<|audio|>{self.voice.name}: {prompt}<|eot_id|>"

    def _stream_snac_codes(self, prompt: str) -> Generator[int, None, None]:
        """Yield SNAC codes (token id minus CODE_TOKEN_OFFSET) as they arrive.

        Codes are only yielded after CODE_START_TOKEN_ID has been seen; the
        CODE_REMOVE_TOKEN_ID token and ids below CODE_TOKEN_OFFSET are skipped.
        Token ids come from the chunk's ``token_id`` field when the server
        provides it, otherwise from ``<custom_token_N>`` tags in its text.
        """
        payload = {
            "model": self.model,
            "prompt": self._format_prompt(prompt),
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": True,
            "skip_special_tokens": False,
            # Multiplicative repetition penalty (1.0 = off). It must not be
            # sent as the additive OpenAI ``frequency_penalty``. vLLM reads
            # ``repetition_penalty``; LM Studio/llama.cpp read ``repeat_penalty``.
            "repetition_penalty": self.repetition_penalty,
            "repeat_penalty": self.repetition_penalty,
        }
        url = f"{self.api_url}/v1/completions"  # plain completion endpoint

        with requests.post(url, headers=self.headers, json=payload, stream=True, timeout=600) as r:
            r.raise_for_status()
            started = False
            text_tail = ""  # possibly incomplete <custom_token_…> tag
            for line in r.iter_lines():
                if not line or not line.startswith(b"data: "):
                    continue
                data = line[6:].decode("utf-8", errors="replace").strip()
                if data == "[DONE]":
                    break
                try:
                    choice = json.loads(data)["choices"][0]
                    tid = choice.get("token_id")  # vLLM ≥0.9 provides this
                    text_piece = choice.get("text") or ""
                except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                    continue  # malformed / keep-alive chunk

                if tid is not None:
                    token_ids = [tid]
                else:
                    token_ids, text_tail = _extract_custom_token_ids(text_tail + text_piece)

                for token_id in token_ids:
                    if not started:
                        if token_id == CODE_START_TOKEN_ID:
                            started = True
                        continue
                    if token_id == CODE_REMOVE_TOKEN_ID or token_id < CODE_TOKEN_OFFSET:
                        continue
                    yield token_id - CODE_TOKEN_OFFSET

    # ------------------------------------------------------------------ misc
    def __del__(self):
        try:
            self.queue.put(None)
        except Exception:
            pass