# --- Hugging Face Space export header (page chrome, not code) --------
# Tomtom84's picture
# Update app.py
# 4c833ce verified
# raw / history / blame / 6.09 kB
# app.py -------------------------------------------------------------
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC
# ── 0. Auth & Device ────────────────────────────────────────────────
# Log in to the Hugging Face Hub only when a token is provided.
if (hf_token := os.getenv("HF_TOKEN")):
    login(hf_token)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable flash scaled-dot-product attention (PyTorch 2.2 workaround).
torch.backends.cuda.enable_flash_sdp(False)
# ── 1. Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50           # <= 50 new tokens per step keeps latency under ~1 s
START_TOKEN = 128259        # start-of-prompt marker
NEW_BLOCK_TOKEN = 128257    # begins a fresh 7-token audio block
EOS_TOKEN = 128258          # end of speech
AUDIO_BASE = 128266         # first audio-code token id
# All 4096 audio-code token ids, contiguous from AUDIO_BASE.
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
# ── 2. Logit mask (audio and control tokens only) ───────────────────
class AudioMask(LogitsProcessor):
    """Forces sampling onto a fixed set of allowed token ids.

    Adds -inf to every logit outside `allowed`, leaving allowed logits
    untouched, so only audio/control tokens can ever be sampled.
    """

    def __init__(self, allowed: torch.Tensor):
        # `allowed` must already live on the same device as the scores.
        self.allowed = allowed

    def __call__(self, _ids, scores):
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, self.allowed] = 0.0
        return scores + penalty
# Every id the sampler may emit: the 4096 audio codes plus the two
# control tokens, moved to the generation device once at import time.
_control = torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])
ALLOWED_IDS = torch.cat([VALID_AUDIO_IDS, _control]).to(device)
MASKER = AudioMask(ALLOWED_IDS)
# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Liveness probe for the Space."""
    return {"msg": "Orpheus‑TTS ready"}


# Global model handles; filled in by the startup hook below.
tok = model = snac = None
@app.on_event("startup")
async def load_models():
    """Load tokenizer, SNAC vocoder and the Orpheus LM once at boot."""
    global tok, model, snac

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # On GPU: pin the whole model to device 0 in bfloat16; on CPU use defaults.
    load_kwargs = {"low_cpu_mem_usage": True}
    if device == "cuda":
        load_kwargs["device_map"] = {"": 0}
        load_kwargs["torch_dtype"] = torch.bfloat16
    else:
        load_kwargs["device_map"] = None
        load_kwargs["torch_dtype"] = None
    model = AutoModelForCausalLM.from_pretrained(REPO, **load_kwargs)

    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
# ── 4. Helpers ──────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenize "<voice>: <text>" and wrap it in Orpheus control tokens.

    Returns (input_ids, attention_mask), both shaped (1, n) on `device`.
    """
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], device=device)
    # 128009/128260 presumably terminate the human turn — TODO confirm
    # against the Orpheus prompt format.
    end = torch.tensor([[128009, 128260]], device=device)
    full = torch.cat([start, prompt_ids, end], 1)
    return full, torch.ones_like(full)
def decode_block(b7: list[int]) -> bytes:
    """Decode one 7-token audio block into raw 16-bit PCM bytes.

    The 7 tokens carry per-position offsets; undoing them distributes the
    codes onto SNAC's three codebook layers (1 + 2 + 4 codes).
    """
    layer_1 = [b7[0]]
    layer_2 = [b7[1] - 4096, b7[4] - 16384]
    layer_3 = [b7[2] - 8192, b7[3] - 12288, b7[5] - 20480, b7[6] - 24576]
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    pcm = snac.decode(codes).squeeze().cpu().numpy()
    # Scale [-1, 1] float audio to int16 little-endian samples.
    return (pcm * 32767).astype("int16").tobytes()
def new_tokens_only(full_seq, prev_len):
    """Return the tokens appended after position `prev_len` as a plain list."""
    tail = full_seq[prev_len:]
    return tail.tolist()
# ── 5. WebSocket endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol: the client sends one JSON message {"text": ..., "voice": ...};
    the server replies with binary frames of raw 16-bit PCM (one frame per
    decoded 7-token SNAC block) and closes the socket after EOS.
    """
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        prompt_len = ids.size(1)
        past, buf = None, []

        while True:
            gen = model.generate(
                input_ids=ids if past is None else None,
                attention_mask=attn if past is None else None,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,  # small chunks => low first-byte latency
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                return_dict_in_generate=True,
                use_cache=True, return_legacy_cache=True,
            )
            # return_legacy_cache=True normally yields a legacy tuple already;
            # the fallback converts a Cache object via its public API.
            # BUGFIX: was `.to_legacy()`, which does not exist on
            # transformers' Cache — the method is `to_legacy_cache()`.
            past = (
                gen.past_key_values
                if not isinstance(gen.past_key_values, Cache)
                else gen.past_key_values.to_legacy_cache()
            )

            seq = gen.sequences[0].tolist()
            new_tok = seq[prompt_len:]
            prompt_len = len(seq)
            if not new_tok:
                continue  # rare, but possible

            for t in new_tok:
                if t == EOS_TOKEN:
                    # A single explicit close frame is enough.
                    await ws.close()
                    return
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:  # one complete SNAC block
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()

            # From here on only the KV cache drives generation.
            # NOTE(review): generate() is then called with input_ids=None —
            # confirm the installed transformers version supports this
            # cache-only continuation pattern.
            ids = attn = None
    except WebSocketDisconnect:
        pass  # client hung up on its own
    except Exception as e:
        print("WS‑Error:", e)
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)  # report internal error
# ── 6. Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys
    import uvicorn

    # Optional first CLI argument overrides the default Space port.
    args = sys.argv[1:]
    port = int(args[0]) if args else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port)