# app.py βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
import os, json, asyncio, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from snac import SNAC
# ── 0. HF auth & device ──────────────────────────────────────────────
# Optional Hugging Face login: only performed when an HF_TOKEN env var is
# set (required for gated/private model repos; harmless to skip otherwise).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Work around a flash-attention bug in PyTorch 2.2.x by disabling the
# flash scaled-dot-product-attention kernel globally.
torch.backends.cuda.enable_flash_sdp(False)
# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
CHUNK_TOKENS = 50        # tokens per generate() call — the streaming chunk size
START_TOKEN = 128259     # prompt start marker
NEW_BLOCK_TOKEN = 128257 # model signals the start of a new 7-code audio block
EOS_TOKEN = 128258       # end of speech
AUDIO_BASE = 128266      # first audio-code token id; codes are offsets from here
# All 4096 audio-code token ids (AUDIO_BASE .. AUDIO_BASE+4095).
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096)
# ββ 2.Β LogitβProcessor zum Maskieren ββββββββββββββββββββββββββββββββ
class AudioLogitMask(LogitsProcessor):
    """Restrict sampling to the allowed token ids.

    Every logit whose id is not in *allowed_ids* is forced to -inf, so the
    sampler can only ever pick audio codes or the control tokens
    (NEW_BLOCK / EOS).
    """

    def __init__(self, allowed_ids: torch.Tensor):
        super().__init__()
        self.allowed = allowed_ids

    def __call__(self, input_ids, scores):
        # scores: [batch, vocab] — keep allowed positions, kill the rest.
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep[:, self.allowed] = True
        return scores.masked_fill(~keep, float("-inf"))
# Ids the sampler may produce: all audio codes plus the two control tokens.
ALLOWED_IDS = torch.cat(
    [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
).to(device)
MASKER = AudioLogitMask(ALLOWED_IDS)  # single shared logits-processor instance
# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()
@app.get("/")
async def ping():
    """Health-check endpoint."""
    return {"msg": "OrpheusβTTS OK"}
@app.on_event("startup")
async def load_models():
    """Load tokenizer, SNAC vocoder and the Orpheus LLM once at startup.

    Results are published as module-level globals (tok, model, snac) used
    by the request handlers.
    """
    global tok, model, snac

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # On CUDA pin the whole model to GPU 0 and load in bfloat16;
    # on CPU keep the library defaults.
    if device == "cuda":
        load_kwargs = {"device_map": {"": 0}, "torch_dtype": torch.bfloat16}
    else:
        load_kwargs = {"device_map": None, "torch_dtype": None}
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        **load_kwargs,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
# ── 4. Helper functions ──────────────────────────────────────────────
def build_prompt(text: str, voice: str):
    """Build model input for "voice: text".

    Wraps the tokenized prompt in START_TOKEN and the two end-of-prompt
    marker tokens (128009, 128260). Returns (input_ids, attention_mask),
    both on the target device.
    """
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], device=device)
    suffix = torch.tensor([[128009, 128260]], device=device)
    ids = torch.cat([prefix, prompt_ids, suffix], 1)
    return ids, torch.ones_like(ids)
def decode_snac(block7: list[int]) -> bytes:
    """Decode one 7-code SNAC block into 16-bit PCM bytes.

    The 7 codes are distributed over SNAC's three codebook levels:
    level 1 gets code 0; level 2 gets codes 1 and 4; level 3 gets codes
    2, 3, 5 and 6. Each position carries its own fixed additive offset
    that must be removed before decoding.
    """
    b = block7
    l1 = [b[0]]
    l2 = [b[1] - 4096, b[4] - 16384]
    l3 = [b[2] - 8192, b[3] - 12288, b[5] - 20480, b[6] - 24576]
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    # BUGFIX: clamp to [-1, 1] before the int16 conversion — samples outside
    # that range would wrap around in astype("int16") and produce loud clicks.
    return (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
# ── 5. WebSocket endpoint ────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream TTS audio over a websocket.

    Expects one JSON text message: {"text": "...", "voice": "..."}
    (voice defaults to "Jakob"). Streams raw 16-bit PCM back, one SNAC
    block (7 audio codes) at a time, until the model emits EOS or the
    client disconnects.
    """
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        consumed = ids.shape[1]  # tokens already processed (prompt + streamed)
        past = None
        buf = []
        done = False
        while not done:
            out = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True,
                return_dict_in_generate=True,
            )
            past = out.past_key_values
            seq = out.sequences
            # BUGFIX: generate() output has no `num_generated_tokens`
            # attribute — track the consumed length ourselves instead.
            new_tokens = seq[0, consumed:].tolist()
            consumed = seq.shape[1]
            if not new_tokens:
                break  # safety net: nothing generated, avoid spinning forever
            # BUGFIX: feed the full sequence back in for the next chunk; the
            # KV cache keeps it incremental. Passing input_ids=None with only
            # past_key_values is not a valid continuation call.
            ids, attn = seq, torch.ones_like(seq)

            for t in new_tokens:
                if t == EOS_TOKEN:
                    done = True  # replaces the StopIteration anti-pattern
                    break
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()  # model restarts the current 7-code block
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_snac(buf))
                    buf.clear()
    except WebSocketDisconnect:
        pass  # client went away — nothing left to send
    except Exception as e:
        print("WSβError:", e)
        await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            await ws.close()
# ── 6. Local test ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Launch a local dev server when the file is run directly.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)