# app.py ──────────────────────────────────────────────────────────────
import asyncio
import json
import logging
import os

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from transformers.generation.utils import Cache  # noqa: F401  (currently unused)

# ── 0. Auth & device ────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # work around a flash-SDP bug
logging.getLogger("transformers.generation.utils").setLevel("ERROR")

# ── 1. Constants ────────────────────────────────────────────────────
MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50             # tokens sampled per generate() call
START_TOKEN = 128259          # <s>, prepended to every prompt
NEW_BLOCK_TOKEN = 128257      # audio start marker
EOS_TOKEN = 128258            # end of speech
PROMPT_END = [128009, 128260]  # appended after the text prompt
AUDIO_BASE = 128266           # id of audio code 0 in frame slot 0

CODEBOOK_SIZE = 4096  # codes per SNAC codebook entry
FRAME_TOKENS = 7      # audio tokens per SNAC frame
# Slot i of a 7-token frame is encoded as AUDIO_BASE + i * CODEBOOK_SIZE + code
# (decode_block strips exactly these i * 4096 offsets). The audio vocabulary
# therefore spans FRAME_TOKENS * CODEBOOK_SIZE ids. Whitelisting only the first
# CODEBOOK_SIZE ids would forbid slots 1-6 and make every frame undecodable.
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + FRAME_TOKENS * CODEBOOK_SIZE)


# ── 2. Logit masker ─────────────────────────────────────────────────
class AudioMask(LogitsProcessor):
    """Restrict sampling to a fixed whitelist of token ids.

    Every logit whose vocabulary index is not in ``allowed`` is pushed to
    -inf, so generation can only emit whitelisted tokens.
    """

    def __init__(self, allowed: torch.Tensor):
        super().__init__()
        self.allowed = allowed  # 1-D tensor of permitted token ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # No-op when already on the right device; guards against a mask
        # built on a different device than the logits.
        allowed = self.allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask


ALLOWED_IDS = torch.cat([
    VALID_AUDIO_IDS,
    torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN]),
]).to(device)
MASKER = AudioMask(ALLOWED_IDS)
# ── 3. FastAPI scaffold ─────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def ping():
    """Health check: report that the server is up."""
    return {"message": "Orpheus‑TTS ready"}


@app.on_event("startup")
async def load_models():
    """Load tokenizer, language model and SNAC decoder into module globals.

    On CUDA the LM is placed on GPU 0 in bfloat16. On CPU, device placement
    and dtype are left to ``from_pretrained``'s defaults.
    """
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    # Reuse EOS as the pad id (presumably so generate() has a pad token set).
    model.config.pad_token_id = model.config.eos_token_id
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()


# ── 4. Helpers ──────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Build the model prompt for ``text`` spoken by ``voice``.

    The text prompt is ``"<voice>: <text>"``. When ``voice`` is empty or the
    sentinel ``"in_prompt"``, ``text`` is used verbatim. The tokenized prompt
    is wrapped as ``[START_TOKEN] + tokens + PROMPT_END``.

    Returns:
        ``(ids, mask)``: token ids and an all-ones attention mask, both of
        shape (1, L) and on ``device``.
    """
    prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device),
        ids,
        torch.tensor([PROMPT_END], device=device),
    ], 1)
    mask = torch.ones_like(ids)
    return ids, mask  # shape (1, L)


def decode_block(block7: list[int]) -> bytes:
    """Decode one 7-token frame into int16 PCM bytes.

    ``block7`` holds token ids relative to AUDIO_BASE. Entry ``i`` still
    carries its slot offset ``i * 4096``, which is removed here. The 7 codes
    are distributed over SNAC's three levels: 1, 2 and 4 codes per frame.

    Raises:
        ValueError: if the frame is not exactly 7 codes, or a code lies
            outside its slot's ``[0, 4096)`` range. Such codes would otherwise
            reach SNAC's embedding lookup as invalid indices.
    """
    if len(block7) != 7:
        raise ValueError(f"expected 7 codes per frame, got {len(block7)}")
    codes_flat = [c - i * 4096 for i, c in enumerate(block7)]
    if not all(0 <= c < 4096 for c in codes_flat):
        raise ValueError(f"frame codes out of slot range: {block7}")

    c0, c1, c2, c3, c4, c5, c6 = codes_flat
    l1, l2, l3 = [c0], [c1, c4], [c2, c3, c5, c6]
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]

    # inference_mode: SNAC's weights are ordinary nn.Module parameters. Without
    # it, the decoded tensor carries autograd state and .numpy() raises.
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # Clip before scaling so out-of-range samples saturate instead of
    # wrapping around when cast to int16.
    return (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
# ── 5. WebSocket endpoint ───────────────────────────────────────────
import threading

# Hard cap on generated tokens per request. Without it, the loop below only
# stops on EOS_TOKEN, so a model that never emits it would run (and grow its
# KV cache) indefinitely. 8192 tokens is about 1170 seven-token frames.
# The last chunk may overshoot the cap by up to CHUNK_TOKENS.
MAX_AUDIO_TOKENS = 8192

# generate() is blocking and runs in a worker thread. This lock keeps model
# work serialized across connections, as it was when generate() blocked the
# event loop directly.
_GENERATE_LOCK = threading.Lock()
_log = logging.getLogger(__name__)


def _generate_chunk(ids, attn, past):
    """Sample up to CHUNK_TOKENS tokens that continue ``ids``.

    ``ids``/``attn`` cover the full sequence so far. ``past`` is the cache
    returned by the previous call, or None on the first call.
    """
    with _GENERATE_LOCK:
        return model.generate(
            input_ids=ids,
            attention_mask=attn,
            past_key_values=past,
            max_new_tokens=CHUNK_TOKENS,
            logits_processor=[MASKER],
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            return_dict_in_generate=True,
            use_cache=True,
            return_legacy_cache=True,  # silences the cache-format warning
        )


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol: the client sends one JSON text message with ``text`` and an
    optional ``voice`` (default ``"Jakob"``). The server answers with binary
    messages, each holding the int16 PCM bytes of one decoded 7-token frame.

    The connection is closed:
    - normally, on EOS_TOKEN or once MAX_AUDIO_TOKENS have been generated;
    - with code 1011, on an internal error.
    """
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        prompt_len = seen = ids.size(1)
        past = None
        frame: list[int] = []

        while seen - prompt_len < MAX_AUDIO_TOKENS:
            # Keep the event loop free (keep-alives, other clients) while
            # the model samples.
            out = await asyncio.to_thread(_generate_chunk, ids, attn, past)
            # Feed the whole sequence back together with the cache. generate()
            # must receive the real input_ids to extend them; out.sequences
            # then always holds prompt + all generated tokens, so slicing at
            # `seen` yields exactly this chunk's new tokens.
            ids = out.sequences
            attn = torch.ones_like(ids)
            past = out.past_key_values
            new = ids[0, seen:].tolist()
            seen = ids.size(1)
            if not new:  # no progress: stop instead of spinning forever
                break

            for t in new:
                if t == EOS_TOKEN:
                    await ws.close()
                    return
                if t == NEW_BLOCK_TOKEN:
                    frame.clear()
                    continue
                if t < AUDIO_BASE:  # non-audio token (e.g. START_TOKEN)
                    continue
                code = t - AUDIO_BASE
                # Slot i of a frame lives in [i*4096, (i+1)*4096). A token for
                # any other slot means the frame is misaligned. Drop the
                # partial frame, keeping this token only if it opens a new one.
                if code // 4096 != len(frame):
                    frame.clear()
                    if code >= 4096:
                        continue
                frame.append(code)
                if len(frame) == 7:
                    await ws.send_bytes(decode_block(frame))
                    frame.clear()

        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception:
        _log.exception("WS‑Error")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)


# ── 6. Local start ──────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    import uvicorn

    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port)