# Hugging Face Spaces page residue (Space status: Paused)
# app.py ────────────────────────────────────────────────────────────── | |
import os, json, torch, asyncio | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from huggingface_hub import login | |
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor | |
from transformers.generation.utils import Cache | |
from snac import SNAC | |
# ── 0 · Login & device ───────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # Authenticate against the Hugging Face Hub only when a token is configured.
    login(HF_TOKEN)

_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"

# Disable the flash scaled-dot-product-attention backend (original note: "CUDA-Assert-Fix").
# NOTE(review): the device-side asserts may instead come from out-of-range SNAC codes
# (see VALID_AUDIO / decode_block) — confirm whether this workaround is still needed.
torch.backends.cuda.enable_flash_sdp(False)
# ── 1 · Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50      # max new tokens per model.generate() call in tts()
START_TOKEN = 128259   # prepended to every prompt by build_inputs()
NEW_BLOCK = 128257     # control token: tts() discards any partial frame when it sees it
EOS_TOKEN = 128258     # end of audio: tts() stops streaming when it is sampled
AUDIO_BASE = 128266    # id of the first audio-code token
CODEBOOK_SIZE = 4096   # codes per SNAC codebook (offset step used by decode_block)
CODES_PER_FRAME = 7    # audio tokens per decoded frame (decode_block takes 7)
# Frame slot i uses its own id range AUDIO_BASE + i*4096 … AUDIO_BASE + (i+1)*4096 - 1;
# decode_block strips exactly these i*4096 offsets. The allow-list must therefore cover
# all 7 ranges — a single 4096-id range would force every slot into slot 0's range and
# make decode_block produce negative codes for slots 1-6.
VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE + CODES_PER_FRAME * CODEBOOK_SIZE)
# ── 2 · Logit masker ─────────────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Logits processor that only lets audio-code and control tokens be sampled.

    Allowed ids are every id in ``audio_ids`` plus ``NEW_BLOCK``. ``EOS_TOKEN``
    becomes allowed only once ``blocks >= min_blocks``. The caller is responsible
    for incrementing ``blocks`` (tts() does so once per decoded audio frame).
    All other logits are set to ``-inf``.
    """

    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.min_blocks = min_blocks
        self.blocks = 0
        # Both allow-lists are invariant, so build them once instead of on every
        # decoding step. They are built on one device (audio_ids') and moved to
        # the scores' device in __call__, so mixed devices cannot break torch.cat.
        self._allow_no_eos = torch.cat([audio_ids, self.ctrl_ids])
        self._allow_with_eos = torch.cat(
            [self._allow_no_eos,
             torch.tensor([EOS_TOKEN], device=audio_ids.device)])

    def __call__(self, inp, scores):
        """Return ``scores`` with every disallowed token id set to ``-inf``."""
        allow = (self._allow_with_eos if self.blocks >= self.min_blocks
                 else self._allow_no_eos).to(scores.device)
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep[:, allow] = True
        # masked_fill avoids the NaN that `scores + (-inf)` would give for +inf logits.
        return scores.masked_fill(~keep, float("-inf"))
# ── 3 · FastAPI app ──────────────────────────────────────────────────
app = FastAPI()


# Registered as GET "/": without the decorator this handler was never reachable.
@app.get("/")
async def root():
    """Liveness endpoint; reports that the process is up (does not check model state)."""
    return {"msg": "Orpheus‑TTS alive"}
# Registered as a startup hook: nothing else calls load(), so without it the
# globals below are never defined and every tts() call fails with NameError.
@app.on_event("startup")
async def load():
    """Load tokenizer, SNAC codec and the causal LM once, when the app starts.

    Populates the module globals ``tok``, ``model``, ``snac`` and ``masker``
    used by the request handlers.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    tok = AutoTokenizer.from_pretrained(REPO)
    # .eval(): make sure the codec runs in inference mode (harmless if already set).
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    on_gpu = device == "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if on_gpu else None,
        torch_dtype=torch.bfloat16 if on_gpu else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
    # Module-level masker kept for backward compatibility. Its `blocks` counter is
    # never reset here, so per-request code should create its own instance.
    masker = DynamicAudioMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")
# ── 4 · Helpers ──────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenise ``"<voice>: <text>"`` and wrap it in the model's prompt frame.

    The sequence is ``[START_TOKEN] + tokens + [128009, 128260]`` (the trailer ids
    are presumably Orpheus' end-of-text / end-of-human markers — confirm).

    Returns:
        ``(input_ids, attention_mask)``, both shaped ``(1, seq_len)`` on ``device``;
        the mask is all ones.
    """
    body = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    tail = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((head, body, tail), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block):
    """Decode one 7-token audio frame into raw int16 PCM bytes with SNAC.

    ``block`` holds seven token values already shifted by ``AUDIO_BASE``. Slot ``i``
    carries an extra offset of ``i * 4096``, which is removed here. The codes are
    spread over SNAC's three codebook layers as 1 / 2 / 4 codes:
    layer 1 <- slot 0, layer 2 <- slots 1, 4, layer 3 <- slots 2, 3, 5, 6.

    Returns:
        The decoded samples, clipped to [-1, 1] and scaled to int16, as bytes in
        the machine's native byte order.

    Raises:
        ValueError: if ``block`` does not have 7 entries or any code falls outside
            ``[0, 4096)``. Such codes would otherwise index outside SNAC's codebooks
            and trigger a CUDA device-side assert.
    """
    size = 4096
    if len(block) != 7:
        raise ValueError(f"expected 7 codes per frame, got {len(block)}")
    codes = [value - slot * size for slot, value in enumerate(block)]
    bad = [c for c in codes if not 0 <= c < size]
    if bad:
        raise ValueError(f"audio code(s) out of range [0, {size}): {bad}")
    layers = ([codes[0]],
              [codes[1], codes[4]],
              [codes[2], codes[3], codes[5], codes[6]])
    tensors = [torch.tensor(layer, device=device).unsqueeze(0) for layer in layers]
    # Without inference_mode the output requires grad and .numpy() raises.
    with torch.inference_mode():
        audio = snac.decode(tensors)
    pcm = audio.squeeze().detach().float().cpu().numpy()
    # Clip before scaling so out-of-range samples saturate instead of wrapping.
    return (pcm.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
# ── 5 · WebSocket TTS ────────────────────────────────────────────────
# NOTE(review): the pasted source had no route decorator on this handler, so it was
# never registered; "/ws/tts" is an assumed path — confirm it against the client.
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol as implemented:
      1. Accept the connection and read ONE text message with JSON containing
         optional ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``).
      2. Generate in chunks of ``CHUNK_TOKENS`` tokens, continuing from the KV cache.
         Every completed 7-token audio frame is decoded by ``decode_block`` and
         sent as one binary message.
      3. Stop when ``EOS_TOKEN`` is sampled or a chunk yields no new tokens.
    On an unexpected error the socket is closed with code 1011.
    """
    import functools

    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_inputs(text, voice)
        # Per-request masker: its `blocks` counter gates EOS, so it must start at 0
        # for every stream and must not be shared between concurrent connections.
        audio_mask = DynamicAudioMask(VALID_AUDIO.to(device))
        loop = asyncio.get_running_loop()
        past = None
        buf = []
        finished = False
        while not finished:
            prompt_len = ids.shape[1]
            # To continue from a KV cache, generate() needs the FULL sequence so far
            # together with the cache; it only feeds the tokens the cache lacks.
            generate = functools.partial(
                model.generate,
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[audio_mask],
                eos_token_id=EOS_TOKEN,  # end the chunk as soon as EOS is sampled
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True)
            # generate() blocks; run it in a worker thread so the event loop keeps
            # serving other requests and WebSocket control frames meanwhile.
            out = await loop.run_in_executor(None, generate)
            past = out.past_key_values
            ids = out.sequences
            attn = torch.ones_like(ids)
            new = ids[0, prompt_len:].tolist()
            if not new:  # nothing generated
                break
            for t in new:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    audio_mask.blocks += 1
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass
# ── 6 · Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string. With the string, uvicorn
    # re-imports this script as module "app", so all module-level setup (Hub login,
    # torch backend flags, constants) would run a second time.
    uvicorn.run(app, host="0.0.0.0", port=7860)