Tomtom84's picture
Update app.py
0b5b901 verified
raw
history blame
6.62 kB
# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC
# ── 0 · Login & device ───────────────────────────────────────────────
# Authenticate against the Hugging Face Hub only when a token is configured.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Keep the flash scaled-dot-product-attention kernel disabled
# (the original note labels this a CUDA-assert workaround).
torch.backends.cuda.enable_flash_sdp(False)
# ── 1 · Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50      # tokens requested per model.generate() call
START_TOKEN = 128259   # prepended to every prompt (see build_inputs)
NEW_BLOCK = 128257     # control token: discard any partially collected frame
EOS_TOKEN = 128258     # control token: stop streaming audio
AUDIO_BASE = 128266    # id of the first audio token

CODEBOOK_SIZE = 4096   # ids per frame slot
FRAME_TOKENS = 7       # audio tokens per frame; slot k is offset by k * CODEBOOK_SIZE

# Each of the 7 slots of a frame has its own 4096-id range (decode_block
# subtracts offsets up to 6 * 4096), so the logit mask must allow all
# 7 * 4096 audio ids, not only the first range.
VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE + FRAME_TOKENS * CODEBOOK_SIZE)
# ── 2 · Logit mask ───────────────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Restrict sampling to audio ids, NEW_BLOCK and (eventually) EOS_TOKEN.

    Every logit outside the allow-list is replaced by ``-inf``. EOS_TOKEN is
    only admitted once ``blocks`` (incremented by the caller) has reached
    ``min_blocks``.

    Args:
        audio_ids: 1-D tensor of token ids that may be sampled as audio.
        min_blocks: completed blocks required before EOS_TOKEN is allowed.
    """

    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.min_blocks = min_blocks
        self.blocks = 0
        # Both allow-lists are invariant across decoding steps: build them once
        # (on one device) instead of re-concatenating on every generated token.
        self._allow = torch.cat([audio_ids, self.ctrl_ids])
        self._allow_eos = torch.cat(
            [self._allow, torch.tensor([EOS_TOKEN], device=audio_ids.device)])

    def __call__(self, inp, scores):
        """Return ``scores`` with every non-allowed id set to ``-inf``."""
        allow = self._allow_eos if self.blocks >= self.min_blocks else self._allow
        if allow.device != scores.device:
            allow = allow.to(scores.device)
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep[:, allow] = True
        return scores.masked_fill(~keep, float("-inf"))
# ── 3 · FastAPI app ──────────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Liveness endpoint: answers with a fixed JSON message."""
    status = {"msg": "Orpheus‑TTS alive"}
    return status
@app.on_event("startup")
async def load():
    """Load tokenizer, SNAC codec, the causal LM and the logit mask.

    Results are stored in the module globals ``tok``, ``snac``, ``model`` and
    ``masker``, which the request handlers read later.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    on_gpu = device == "cuda"

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # On GPU: place the whole model on device 0 in bfloat16; on CPU: defaults.
    load_kwargs = dict(
        low_cpu_mem_usage=True,
        device_map={"": 0} if on_gpu else None,
        torch_dtype=torch.bfloat16 if on_gpu else None,
    )
    model = AutoModelForCausalLM.from_pretrained(REPO, **load_kwargs)

    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id
    cfg.use_cache = True

    masker = DynamicAudioMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")
# ── 4 · Helpers ──────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Build the prompt tensors for ``"<voice>: <text>"``.

    The tokenized prompt is wrapped as ``[START_TOKEN] + ids + [128009, 128260]``.

    Returns:
        ``(input_ids, attention_mask)`` on ``device``; the mask is all ones with
        the same shape as ``input_ids``.
    """
    body = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    # NOTE(review): 128009 / 128260 look like end-of-text / end-of-prompt
    # markers of the Orpheus prompt format — confirm against the model card.
    tail = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((head, body, tail), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block):
    """Decode one 7-code audio frame into little int16 PCM bytes via SNAC.

    ``block`` holds seven codes already shifted down by ``AUDIO_BASE``; code k
    still carries an extra ``k * 4096`` offset, which is removed here. The codes
    are distributed over SNAC's three levels as
    level 1: [c0], level 2: [c1, c4], level 3: [c2, c3, c5, c6].

    Returns:
        Raw int16 sample bytes (``numpy.ndarray.tobytes``) of the decoded audio.

    Raises:
        ValueError: if ``block`` does not contain exactly seven codes.
    """
    if len(block) != 7:
        raise ValueError(f"expected 7 audio codes per frame, got {len(block)}")
    c = [code - k * 4096 for k, code in enumerate(block)]
    layers = ([c[0]], [c[1], c[4]], [c[2], c[3], c[5], c[6]])
    codes = [torch.tensor(layer, device=device).unsqueeze(0) for layer in layers]
    # Decode without autograd: nothing in this file disables gradients for
    # SNAC's parameters, and .numpy() refuses tensors that are part of a graph.
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze()
        # Clamp before scaling so out-of-range samples don't wrap around in int16.
        audio = audio.clamp(-1.0, 1.0).float().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
# ── 5 · WebSocket TTS ────────────────────────────────────────────────
# Safety cap on tokens generated per request, so a model that never emits
# EOS cannot keep a connection (and the GPU) busy forever.
MAX_TTS_TOKENS = 7 * 1000


def _generate_chunk(ids, attn, past, mask):
    """Run one blocking ``model.generate`` call for up to CHUNK_TOKENS new tokens.

    ``ids``/``attn`` are the FULL sequence so far. When ``past`` holds the
    cache returned by the previous call, transformers' generate() is expected
    to feed only the not-yet-cached tail of ``ids`` to the model.
    """
    return model.generate(
        input_ids=ids,
        attention_mask=attn,
        past_key_values=past,
        max_new_tokens=CHUNK_TOKENS,
        logits_processor=[mask],
        eos_token_id=EOS_TOKEN,
        do_sample=True, temperature=0.7, top_p=0.95,
        use_cache=True, return_dict_in_generate=True,
    )


async def _safe_close(ws: WebSocket, code: int = 1000):
    """Close ``ws`` with ``code`` unless either side has already closed it."""
    if "DISCONNECTED" in (ws.client_state.name, ws.application_state.name):
        return
    try:
        await ws.close(code=code)
    except RuntimeError:
        pass  # closed concurrently; nothing left to do


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    The client sends one JSON text message ``{"text": ..., "voice": ...}``
    (``voice`` defaults to ``"Jakob"``). The server replies with one binary
    message of int16 PCM per decoded 7-token frame, then closes the socket
    (code 1011 if an unexpected error occurred, 1000 otherwise).
    """
    await ws.accept()
    close_code = 1000
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        ids, attn = build_inputs(text, voice)

        # Per-connection mask: its block counter must start at 0 for every
        # request and must not be shared with concurrent connections.
        mask = DynamicAudioMask(VALID_AUDIO.to(device))
        past = None
        frame = []
        generated = 0
        done = False

        while not done and generated < MAX_TTS_TOKENS:
            # Generation is blocking GPU work: keep it off the event loop.
            out = await asyncio.to_thread(_generate_chunk, ids, attn, past, mask)
            new = out.sequences[0, ids.shape[1]:].tolist()
            if not new:  # nothing generated
                break
            generated += len(new)

            # Continue from the full sequence plus the returned KV cache.
            ids = out.sequences
            attn = torch.ones_like(ids)
            past = out.past_key_values

            for t in new:
                if t == EOS_TOKEN:
                    done = True
                    break
                if t == NEW_BLOCK:
                    frame.clear()
                    continue
                frame.append(t - AUDIO_BASE)
                if len(frame) == 7:
                    await ws.send_bytes(decode_block(frame))
                    frame.clear()
                    mask.blocks += 1
    except WebSocketDisconnect:
        pass  # client went away; nothing to report
    except Exception as e:
        print("❌ WS‑Error:", e)
        close_code = 1011
    finally:
        await _safe_close(ws, close_code)
# ── 6 · Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Serve this module's `app` directly when executed as a script.
    from uvicorn import run

    run("app:app", host="0.0.0.0", port=7860)