File size: 6,052 Bytes
0b5b901 87012a8 4189fe1 9bf14d0 87012a8 5031731 d9ea17d 0316ec3 d44e840 479f253 2008a3f 1ab029d d44e840 0b5b901 479f253 d44e840 0b5b901 d44e840 0b5b901 bca75ea 0b5b901 bca75ea d44e840 9bf14d0 0dfc310 9bf14d0 0b5b901 d44e840 9bf14d0 0b5b901 5031731 0b5b901 9bf14d0 5031731 bca75ea 0b5b901 d44e840 f63f843 5031731 d44e840 5031731 bca75ea d44e840 0b5b901 5031731 f92444a 0b5b901 87012a8 479f253 d44e840 0b5b901 d44e840 0b5b901 d44e840 a8606ac d44e840 a09ea48 4189fe1 d44e840 f63f843 d44e840 479f253 d44e840 5031731 0b5b901 d44e840 5031731 d44e840 0b5b901 5031731 d44e840 f92444a d44e840 5d73119 d44e840 0b5b901 87012a8 0b5b901 d44e840 9ef5e61 bca75ea 0b5b901 d44e840 bca75ea 5031731 479f253 a09ea48 5031731 479f253 5031731 0b5b901 5031731 d44e840 a4cfefc 5031731 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC
# 0 · Auth & device ---------------------------------------------------
# Log in to the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable; otherwise continue unauthenticated.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# Use the GPU when torch reports one, else fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Turn the flash scaled-dot-product kernel off (author's note: SDP assert fix).
torch.backends.cuda.enable_flash_sdp(False)
# 1 · Constants -------------------------------------------------------
# Hub repository the tokenizer and language model are loaded from.
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Upper bound on newly generated tokens per generate() call.
CHUNK_TOKENS = 50
# Special token ids used by the prompt builder, logit mask and stream loop.
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
# First audio token id and the size of the audio id range (7 * 4096 = 28 672).
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096
# 1-D tensor holding every id in [AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN).
VALID_AUDIO = AUDIO_BASE + torch.arange(AUDIO_SPAN)
# 2 · Logit masker ----------------------------------------------------
class DynamicMask(LogitsProcessor):
    """Logits processor that only lets audio and control tokens be sampled.

    Every vocabulary entry outside the allow-list gets ``-inf`` added to its
    score.  The allow-list is ``audio_ids`` plus ``NEW_BLOCK``; ``EOS_TOKEN``
    joins it once ``blocks`` has reached ``min_blocks``.  ``blocks`` is a
    plain counter that the caller increments; this class never changes it
    except through ``reset()``.

    Args:
        audio_ids: 1-D integer tensor of permitted audio token ids.
        min_blocks: Value ``blocks`` must reach before EOS becomes legal.
    """

    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.blocks = 0
        self.min_blk = min_blocks
        # Both allow-lists are fixed for the lifetime of the processor, so
        # build them once instead of concatenating on every decoding step.
        # Everything lives on audio_ids.device to avoid a mixed-device cat.
        eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
        self._allow_base = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allow_with_eos = torch.cat([self._allow_base, eos])

    def reset(self) -> None:
        """Set the block counter back to zero so EOS is gated again."""
        self.blocks = 0

    def __call__(self, inp_ids, scores):
        """Return ``scores`` with every non-allowed column pushed to ``-inf``.

        Args:
            inp_ids: Token ids generated so far (not used by this mask).
            scores: Next-token logits; indexed as ``scores[:, token_id]``.
        """
        if self.blocks >= self.min_blk:
            allow = self._allow_with_eos
        else:
            allow = self._allow_base
        # No-op when the devices already match.
        allow = allow.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
# 3 · FastAPI app -----------------------------------------------------
app = FastAPI()


@app.get("/")
async def root():
    """Answer GET / with a small JSON status message."""
    status = {"msg": "Orpheus‑TTS online"}
    return status
@app.on_event("startup")
async def load():
    """Load tokenizer, SNAC codec and language model at application startup.

    Publishes the module-level globals ``tok``, ``model``, ``snac`` and
    ``masker`` that the helpers and the WebSocket handler read.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")

    on_gpu = device == "cuda"
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # On CUDA: place the whole model on GPU 0 in bfloat16.  Otherwise pass
    # None for both options and let the library defaults apply.
    model_options = {
        "low_cpu_mem_usage": True,
        "device_map": {"": 0} if on_gpu else None,
        "torch_dtype": torch.bfloat16 if on_gpu else None,
    }
    model = AutoModelForCausalLM.from_pretrained(REPO, **model_options)

    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id
    cfg.use_cache = True

    masker = DynamicMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")
# 4 · Helper functions ------------------------------------------------
def build_inputs(text: str, voice: str):
    """Build the prompt token ids and a matching all-ones attention mask.

    The prompt string is ``"<voice>: <text>"``.  Its token ids are wrapped
    as ``[START_TOKEN] + prompt_ids + [128009, 128260]`` along dimension 1.

    Returns:
        ``(input_ids, attention_mask)`` on ``device``; the mask is
        ``torch.ones_like(input_ids)``.
    """
    encoded = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    # NOTE(review): 128009 / 128260 are presumably end-of-text and
    # end-of-prompt markers of this model family — confirm on the model card.
    tail = torch.tensor([[128009, 128260]], device=device)
    full_ids = torch.cat((head, encoded, tail), dim=1)
    attention = torch.ones_like(full_ids)
    return full_ids, attention
def decode_block(b):
    """Decode one block of seven audio codes into 16-bit PCM bytes.

    ``b`` holds seven integers that are already relative to ``AUDIO_BASE``.
    Position ``i`` carries an additional offset of ``i * 4096``, which is
    removed here.  The codes are then spread over the three SNAC code lists:
    position 0 -> list 1, positions 1 and 4 -> list 2, positions 2, 3, 5
    and 6 -> list 3.

    Args:
        b: Sequence of exactly seven integer codes.

    Returns:
        The decoded waveform as int16 samples in native byte order.

    Raises:
        ValueError: If ``b`` does not have seven entries, or a code falls
            outside ``[0, 4096)`` once its positional offset is removed.
    """
    if len(b) != 7:
        raise ValueError(f"expected 7 codes per block, got {len(b)}")

    # Remove the per-position offset (0, 4096, ..., 24576).
    rel = [code - 4096 * pos for pos, code in enumerate(b)]
    # Reject bad codes before they reach the codebook lookup, where an
    # out-of-range index fails far less readably (especially on CUDA).
    if any(not 0 <= c < 4096 for c in rel):
        raise ValueError(f"audio code out of range in block: {list(b)}")

    layers = (
        [rel[0]],
        [rel[1], rel[4]],
        [rel[2], rel[3], rel[5], rel[6]],
    )
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in layers]

    # Decode without autograd: .numpy() refuses tensors that require grad.
    with torch.inference_mode():
        wave = snac.decode(codes).squeeze()
        # Clamp to [-1, 1] so scaling to int16 cannot overflow.
        audio = wave.clamp(-1.0, 1.0).cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
# 5 · WebSocket endpoint ----------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized audio for one request over a WebSocket.

    The client sends a single JSON text message with the optional keys
    ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``).
    The server then generates tokens in chunks of at most ``CHUNK_TOKENS``
    and sends one binary message (the bytes from ``decode_block``) for
    every completed group of seven audio tokens.  Streaming ends when the
    model emits ``EOS_TOKEN`` or a chunk yields no new tokens; the socket
    is closed afterwards.  Unexpected errors close it with code 1011.
    """
    await ws.accept()
    # Per-request mask: its block counter gates EOS, so it must start at
    # zero for every request and must not be shared between connections.
    req_masker = DynamicMask(masker.audio_ids)
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        past = None
        buf = []
        finished = False
        while not finished:
            seen = ids.shape[1]
            # Always feed the full sequence together with the cache:
            # generate() derives positions from the input length and only
            # runs the model on the part the cache does not cover yet.
            out = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[req_masker],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True,
                return_legacy_cache=True,
            )
            past = out.past_key_values
            if isinstance(past, Cache) and hasattr(past, "to_legacy_cache"):
                past = past.to_legacy_cache()

            # Everything after the tokens we passed in is new in this chunk.
            new = out.sequences[0, seen:].tolist()
            print("new tokens:", new[:25])
            if not new:
                break

            for t in new:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK:
                    # Block marker: drop any partially collected block.
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    req_masker.blocks += 1

            # The next chunk continues from the whole sequence so far.
            ids = out.sequences
            attn = torch.ones_like(ids)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                # The socket was already closed, e.g. by the error path.
                pass
# 6 · Local run -------------------------------------------------------
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **server_options)
|