File size: 4,489 Bytes
2c15189 a3af518 0316ec3 4189fe1 9bf14d0 d9ea17d a3af518 0316ec3 a3af518 d9ea17d 2c15189 9bf14d0 2008a3f a3af518 1ab029d 0316ec3 a3af518 9bf14d0 0dfc310 a3af518 9bf14d0 d9ea17d 9bf14d0 a3af518 9bf14d0 d9ea17d a3af518 7b0d42c d9ea17d 9bf14d0 d9ea17d a3af518 7b0d42c 9bf14d0 a3af518 9bf14d0 3281189 a3af518 7b0d42c d9ea17d a3af518 9bf14d0 a3af518 9bf14d0 a3af518 9bf14d0 a3af518 a8606ac 2c15189 a09ea48 4189fe1 a3af518 10be82b 9bf14d0 a3af518 d9ea17d 10be82b a3af518 10be82b a3af518 10be82b a3af518 10be82b a3af518 10be82b a3af518 10be82b a3af518 10be82b a3af518 c70d8eb fd06e70 4189fe1 a3af518 fd06e70 a3af518 a09ea48 a3af518 2c15189 a09ea48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — Hugging Face token and login —
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # Authenticate against the Hugging Face Hub only when a token is set.
    login(HF_TOKEN)

# — Pick the device —
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# — Create the FastAPI application —
app = FastAPI()
# — Hello route, so that GET / does not answer with 404 —
@app.get("/")
async def read_root():
    """Return a fixed greeting payload for ``GET /``."""
    greeting = {"message": "Hello, world!"}
    return greeting
# — Load the models at startup —
@app.on_event("startup")
async def load_models():
    """Load the SNAC decoder, the tokenizer and the language model.

    Registered as a FastAPI startup handler.  The loaded objects are
    published as the module-level globals ``snac``, ``tokenizer`` and
    ``model``.
    """
    global tokenizer, model, snac

    # Audio codec, moved to the selected device.
    codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac = codec.to(device)

    repo_id = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    # bfloat16 on CUDA; otherwise leave the dtype choice to the library.
    if device == "cuda":
        dtype = torch.bfloat16
    else:
        dtype = None

    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    # Use the EOS id as the pad-token fallback.
    model.config.pad_token_id = model.config.eos_token_id
# — Helper constants and functions —
# Special token ids used to frame the prompt and to parse the sampled stream.
START_TOKEN = 128259  # prepended to the tokenized prompt in prepare_inputs
END_TOKENS = [128009, 128260]  # appended to the tokenized prompt
RESET_TOKEN = 128257  # when sampled, the partially collected frame is dropped
AUDIO_OFFSET = 128266  # subtracted from a sampled id to obtain an audio code
# Sampling this id ends generation.  The previous expression
# ``model.config.eos_token_id if 'model' in globals() else 128258`` was
# evaluated at import time, before the startup hook defines ``model``, so it
# always produced 128258; the dead branch is removed to make that explicit.
EOS_TOKEN = 128258
def prepare_inputs(text: str, voice: str):
    """Build the model inputs for ``voice`` speaking ``text``.

    The prompt ``"{voice}: {text}"`` is tokenized, then wrapped with
    ``START_TOKEN`` in front and ``END_TOKENS`` behind.

    Args:
        text: The text to synthesize.
        voice: The speaker name placed in front of the text.

    Returns:
        A tuple ``(input_ids, attention_mask)``; the mask is all ones and
        has the same shape as ``input_ids``.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")
    text_ids = encoded.input_ids.to(device)

    prefix = torch.tensor([[START_TOKEN]], device=device)
    suffix = torch.tensor([END_TOKENS], device=device)

    input_ids = torch.cat((prefix, text_ids, suffix), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block: list[int]):
    """Decode one frame of seven audio codes into 16-bit PCM bytes.

    Entry ``i`` of the frame carries an offset of ``i * 4096`` which is
    removed here.  The codes are then distributed over the three code
    levels passed to ``snac.decode``: position 0 goes to level 1,
    positions 1 and 4 to level 2, and positions 2, 3, 5 and 6 to level 3.

    Args:
        block: Exactly seven integers (sampled token id minus
            ``AUDIO_OFFSET``), in the order they were generated.

    Returns:
        The decoded samples as native-endian signed 16-bit PCM bytes.

    Raises:
        ValueError: If ``block`` does not hold exactly seven entries or an
            entry is outside the range valid for its position.
    """
    frame_len = 7
    codebook_size = 4096

    if len(block) != frame_len:
        raise ValueError(
            f"expected {frame_len} audio codes, got {len(block)}"
        )

    # Strip the per-position offset; every result must be a codebook index.
    codes = [code - pos * codebook_size for pos, code in enumerate(block)]
    if any(not 0 <= code < codebook_size for code in codes):
        # Reject early: an out-of-range index would otherwise only fail
        # inside the decoder's codebook lookup.
        raise ValueError(
            f"audio code out of range for its frame position: {list(block)}"
        )

    levels = (
        [codes[0]],
        [codes[1], codes[4]],
        [codes[2], codes[3], codes[5], codes[6]],
    )

    # Decode without autograd: ``.numpy()`` refuses tensors that require
    # grad, and no gradients are needed for synthesis.
    with torch.inference_mode():
        layers = [
            torch.tensor(level, device=device).unsqueeze(0)
            for level in levels
        ]
        audio = snac.decode(layers).squeeze()
        # Clamp before scaling so out-of-range samples saturate instead of
        # wrapping around in the int16 conversion.
        samples = audio.clamp(-1.0, 1.0).cpu().numpy()

    return (samples * 32767).astype("int16").tobytes()
# — WebSocket endpoint for streaming TTS —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    Protocol as implemented here: after the connection is accepted the
    client sends one JSON text message with the optional keys ``text``
    (default ``""``) and ``voice`` (default ``"Jakob"``).  The server then
    samples tokens from the language model one at a time and, for every
    seven audio codes collected, sends one binary message containing the
    PCM bytes returned by ``decode_block``.  The socket is closed normally
    once ``EOS_TOKEN`` is sampled or the token cap is reached, and with
    code 1011 if an unexpected error occurs.

    Args:
        ws: The WebSocket connection supplied by FastAPI.
    """
    max_new_tokens = 4096  # safety cap so a missing EOS cannot loop forever
    codes_per_frame = 7
    audio_vocab = codes_per_frame * 4096  # width of the audio-code id range

    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        collected = []

        # Token-by-token generation with a hand-written sampling loop.
        # NOTE: the forward pass runs synchronously on the event loop; other
        # connections are only served at the awaits below.
        for _ in range(max_new_tokens):
            # Grad mode is per-thread state, so the context deliberately
            # does not span an ``await``.
            with torch.inference_mode():
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_kvs,
                    use_cache=True,
                )
                past_kvs = out.past_key_values

                # Sample from the full distribution of the last position
                # (in float32, independent of the model dtype).
                logits = out.logits[:, -1, :].float()
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)

                # The next step consumes only the token just sampled; the
                # cache covers everything before it.  Previously no token
                # was fed back at all (input_ids=None), so the model could
                # not continue past the first step.
                input_ids = next_id
                attention_mask = torch.cat(
                    [attention_mask, torch.ones_like(next_id)], dim=1
                )

            nxt = next_id.item()

            # EOS -> done.
            if nxt == EOS_TOKEN:
                break

            # RESET -> discard what has been collected so far.
            if nxt == RESET_TOKEN:
                collected = []
                continue

            # Remove the audio offset; ignore ids outside the audio range
            # instead of pushing invalid codes into the decoder.
            code = nxt - AUDIO_OFFSET
            if not 0 <= code < audio_vocab:
                continue
            collected.append(code)

            # Every seven codes -> decode and stream.
            if len(collected) == codes_per_frame:
                # decode_block converts the result to NumPy, which requires
                # the decode to run without autograd.
                with torch.inference_mode():
                    pcm = decode_block(collected)
                collected = []
                await ws.send_bytes(pcm)

        # Finish the stream cleanly.
        await ws.close()
    except WebSocketDisconnect:
        # The client disconnected -> nothing to do.
        pass
    except Exception as e:
        # Report the failure and close with 1011.
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except (RuntimeError, WebSocketDisconnect):
            # The socket is already closed; there is nobody left to notify.
            pass
|