# Page residue from the hosted file view (author: Tomtom84):
# commit "Update app.py" (d44e840, verified) · raw · history · blame · 6.05 kB
# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC
# 0 · Auth & Device ---------------------------------------------------
# Log in to the Hugging Face Hub only when a token is configured.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Workaround for an SDP assertion: disable the flash scaled-dot-product kernel.
torch.backends.cuda.enable_flash_sdp(False)
# 1 · Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50  # max new tokens requested per generate() call

# Special token ids.
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258

# Audio codes form one contiguous id range starting at AUDIO_BASE:
# 7 positions x 4096 codes = 28 672 ids.
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096
VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
# 2 · Logit masker -----------------------------------------------------
class DynamicMask(LogitsProcessor):
    """Logits processor that restricts sampling to the audio vocabulary.

    At every step the ids in ``audio_ids`` and ``NEW_BLOCK`` are allowed.
    ``EOS_TOKEN`` is additionally allowed once ``blocks`` has reached
    ``min_blocks``; ``blocks`` is a counter the caller increments. Every
    other logit is pushed to ``-inf``.

    Args:
        audio_ids: 1-D tensor of allowed audio token ids.
        min_blocks: Number of completed blocks required before EOS may be
            sampled.
    """

    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.blocks = 0
        self.min_blk = min_blocks
        # Both allow-lists are constant, so build them once here instead of
        # allocating and concatenating new tensors on every decoding step.
        self._allow = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allow_eos = torch.cat(
            [self._allow, torch.tensor([EOS_TOKEN], device=audio_ids.device)]
        )

    def __call__(self, inp_ids, scores):
        """Return ``scores`` with every disallowed token id set to ``-inf``."""
        allow = self._allow_eos if self.blocks >= self.min_blk else self._allow
        # Follow the logits' device so ids built elsewhere (e.g. on CPU)
        # still index GPU scores; this is a no-op when the devices match.
        allow = allow.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
# 3 · FastAPI app -----------------------------------------------------
app = FastAPI()


@app.get("/")
async def root():
    """Health check: report that the TTS service is reachable."""
    status = {"msg": "Orpheus‑TTS online"}
    return status
@app.on_event("startup")
async def load():
    """Load all models once when the application starts.

    Populates the module globals ``tok`` (tokenizer), ``snac`` (SNAC
    24 kHz decoder), ``model`` (the Orpheus causal LM) and ``masker``
    (a ``DynamicMask`` over ``VALID_AUDIO``).
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    on_gpu = device == "cuda"
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    load_kwargs = dict(
        low_cpu_mem_usage=True,
        # On GPU: place everything on device 0 in bfloat16; on CPU: defaults.
        device_map={"": 0} if on_gpu else None,
        torch_dtype=torch.bfloat16 if on_gpu else None,
    )
    model = AutoModelForCausalLM.from_pretrained(REPO, **load_kwargs)
    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id
    cfg.use_cache = True
    masker = DynamicMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")
# 4 · Helper functions ------------------------------------------------
def build_inputs(text: str, voice: str):
    """Build the prompt ids and attention mask for one TTS request.

    The string ``"<voice>: <text>"`` is tokenized and wrapped as
    ``[START_TOKEN] + tokens + [128009, 128260]``. All tensors are placed
    on ``device``.

    Returns:
        ``(input_ids, attention_mask)``, each of shape ``(1, seq_len)``;
        the mask is all ones.
    """
    text_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    tail = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((head, text_ids, tail), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(b):
    """Decode one 7-code audio block into 16-bit PCM bytes.

    ``b`` holds seven code values already shifted down by ``AUDIO_BASE``.
    Position ``i`` still carries an extra offset of ``i * 4096``, which is
    removed here. The codes are then regrouped into three SNAC levels of
    1, 2 and 4 codes and decoded with the global ``snac`` model.

    Args:
        b: Sequence of exactly 7 ints.

    Returns:
        int16 PCM samples (native byte order) as ``bytes``.

    Raises:
        ValueError: if ``b`` does not contain exactly 7 codes.
    """
    if len(b) != 7:
        raise ValueError(f"expected 7 codes per block, got {len(b)}")
    # Remove the per-position offset: position i is shifted by i * 4096.
    c = [code - 4096 * i for i, code in enumerate(b)]
    levels = ([c[0]], [c[1], c[4]], [c[2], c[3], c[5], c[6]])
    codes = [torch.tensor(lvl, device=device).unsqueeze(0) for lvl in levels]
    # Pure inference: without this the decoder output may carry autograd
    # history, and `.numpy()` refuses tensors that require grad.
    with torch.inference_mode():
        audio = snac.decode(codes)
    audio = audio.squeeze().float().cpu().numpy()
    # Clip before scaling so out-of-range samples saturate instead of
    # wrapping around when cast to int16.
    return (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
# 5 · WebSocket endpoint ----------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    The client sends one JSON text message with optional keys ``"text"``
    (default ``""``) and ``"voice"`` (default ``"Jakob"``). The server
    generates tokens in chunks of ``CHUNK_TOKENS``. For every complete
    7-code audio block it sends one binary message containing the bytes
    returned by ``decode_block``.

    The stream ends when the model emits ``EOS_TOKEN`` or a chunk yields no
    new tokens. Unexpected errors are printed and the socket is closed with
    code 1011; otherwise it is closed normally (1000).
    """
    await ws.accept()
    close_code = 1000
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        # Per-connection masker: its block counter is per-request state. The
        # shared module-level `masker` would carry the count over from earlier
        # requests and share it between concurrent clients.
        conn_masker = DynamicMask(VALID_AUDIO.to(device))
        past, buf = None, []
        seen = ids.shape[1]  # tokens of `ids` already processed by the loop
        finished = False
        while not finished:
            # generate() blocks; run it in a worker thread so the event loop
            # keeps serving other connections in the meantime.
            # NOTE(review): concurrent clients now run generate() in parallel
            # threads on the shared model — confirm GPU memory headroom.
            # The full sequence is passed back together with the cache, so
            # generate() skips the positions the cache already holds.
            out = await asyncio.to_thread(
                model.generate,
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[conn_masker],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True,
            )
            past = out.past_key_values
            ids = out.sequences
            attn = torch.ones_like(ids)
            new = ids[0, seen:].tolist()
            seen = ids.shape[1]
            if not new:
                break
            for t in new:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    conn_masker.blocks += 1
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        close_code = 1011
    finally:
        # Close exactly once, and only if neither side has closed already.
        if (ws.client_state.name != "DISCONNECTED"
                and ws.application_state.name != "DISCONNECTED"):
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                pass
# 6 · Local run -------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Serve the already-created app object rather than the "app:app" import
    # string. This works whatever the file is named, and uvicorn does not
    # import (and thereby re-execute) this module a second time.
    uvicorn.run(app, host="0.0.0.0", port=7860)