# app.py -------------------------------------------------------------
import json
import os

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from transformers.generation.utils import Cache

# ── 0. Auth & device ────────────────────────────────────────────────
# Named hf_token (not tok) so it cannot be confused with the tokenizer
# handle that is also called `tok` further down in this file.
if hf_token := os.getenv("HF_TOKEN"):
    login(hf_token)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # PyTorch-2.2 fix

# ── 1. Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50  # ≤ 50 → < 1 s latency
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266

# Each audio frame consists of 7 tokens. decode_block() subtracts a
# per-slot offset of 0, 4096, ..., 24576 from (token - AUDIO_BASE), so the
# audio ids must span 7 * 4096 values. Restricting the mask to the first
# 4096 ids would make every code in slots 1-6 negative after the offset
# is removed.
CODEBOOK_SIZE = 4096
CODES_PER_FRAME = 7
VALID_AUDIO_IDS = torch.arange(
    AUDIO_BASE, AUDIO_BASE + CODES_PER_FRAME * CODEBOOK_SIZE
)


# ── 2. Logit mask (audio and control tokens only) ───────────────────
class AudioMask(LogitsProcessor):
    """Logits processor that forbids every token id not in ``allowed``.

    Disallowed positions receive ``-inf`` so they can never be sampled;
    allowed positions keep their original score.
    """

    def __init__(self, allowed: torch.Tensor):
        # `allowed` must live on the same device as the scores it will index.
        self.allowed = allowed

    def __call__(self, _ids, scores):
        """Return ``scores`` with all non-allowed columns set to ``-inf``."""
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0.0
        return scores + mask


ALLOWED_IDS = torch.cat(
    [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
).to(device)
MASKER = AudioMask(ALLOWED_IDS)
# ── 3. FastAPI scaffolding ──────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Return a static readiness message."""
    return {"msg": "Orpheus‑TTS ready"}


# Global handles; assigned once by load_models() at application startup.
tok = model = snac = None


@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the SNAC codec and the language model.

    The results are stored in the module-level ``tok``, ``snac`` and
    ``model`` handles. On CUDA the language model is placed on device 0 in
    bfloat16; otherwise the library defaults are used.
    """
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True


# ── 4. Helpers ──────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Build the prompt tensors for one synthesis request.

    The prompt is ``"{voice}: {text}"``, tokenized and wrapped as
    ``[START_TOKEN] + ids + [128009, 128260]``.

    Returns:
        A tuple ``(input_ids, attention_mask)``; both have shape
        ``(1, seq_len)`` on ``device`` and the mask is all ones.
    """
    prompt = f"{voice}: {text}"
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),
            ids,
            # NOTE(review): presumably end-of-text / end-of-turn markers
            # expected by the model's prompt format — confirm against the
            # model card.
            torch.tensor([[128009, 128260]], device=device),
        ],
        1,
    )
    return ids, torch.ones_like(ids)


def decode_block(b7: list[int]) -> bytes:
    """Decode one 7-token audio frame into 16-bit PCM bytes.

    Args:
        b7: Seven token values relative to ``AUDIO_BASE``. Slot ``i`` still
            carries its ``i * 4096`` offset, which is removed here before
            the codes are distributed over the three SNAC layers
            (1 / 2 / 4 codes per frame).

    Returns:
        The decoded audio as little-endian int16 bytes.

    Raises:
        ValueError: If a code falls outside ``[0, 4096)`` after its slot
            offset has been removed.
    """
    l1 = [b7[0]]
    l2 = [b7[1] - 4096, b7[4] - 16384]
    l3 = [b7[2] - 8192, b7[3] - 12288, b7[5] - 20480, b7[6] - 24576]

    # Reject bad codes before they reach the codec: an out-of-range
    # embedding index on CUDA triggers a device-side assert.
    for layer in (l1, l2, l3):
        for code in layer:
            if not 0 <= code < 4096:
                raise ValueError(f"audio code out of range: {code}")

    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    # The codec's parameters require grad by default, so decoding outside a
    # no-grad context yields a tensor on which .numpy() raises.
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # Clip before scaling so samples outside [-1, 1] saturate instead of
    # wrapping around in the int16 cast.
    return (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()


def new_tokens_only(full_seq, prev_len):
    """Return the list of tokens that were newly appended after ``prev_len``."""
    return full_seq[prev_len:].tolist()
# ── 5. WebSocket endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol: the client sends a single JSON text message with optional
    keys ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``).
    The server answers with binary messages, one per decoded 7-token audio
    frame, and closes the socket when the model emits ``EOS_TOKEN``. On an
    unexpected error the socket is closed with code 1011.

    NOTE: ``model.generate`` is a blocking call and runs on the event loop.
    """
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        seen = ids.size(1)  # number of tokens already handled
        past, buf = None, []

        while True:
            # Always pass the full running sequence together with the cache.
            # generate() then only has to process the tokens the cache does
            # not cover yet. Calling it with input_ids=None would restart
            # from a synthetic BOS token and never continue the sequence.
            gen = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                eos_token_id=EOS_TOKEN,  # stop the chunk as soon as EOS appears
                return_dict_in_generate=True,
                use_cache=True,
                return_legacy_cache=True,
            )
            # Hand the cache back in whatever format generate() returned it.
            past = gen.past_key_values
            ids = gen.sequences
            attn = torch.ones_like(ids)

            new_tok = ids[0, seen:].tolist()
            seen = ids.size(1)
            if not new_tok:
                # No progress is possible (e.g. a length limit was hit);
                # end the stream instead of looping forever.
                await ws.close()
                return

            for t in new_tok:
                if t == EOS_TOKEN:
                    # A single close frame is enough; this and the
                    # no-progress branch above are the only normal closes.
                    await ws.close()
                    return
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
    except WebSocketDisconnect:
        pass  # client disconnected on its own
    except Exception as e:
        print("WS‑Error:", e)
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)  # report the error to the client


# ── 6. Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    import uvicorn

    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port)