File size: 5,170 Bytes
2c15189 a3af518 0316ec3 4189fe1 9bf14d0 d9ea17d a3af518 0316ec3 a3af518 d9ea17d 2c15189 9bf14d0 2008a3f 986d4cd 1ab029d 0316ec3 986d4cd a3af518 9bf14d0 0dfc310 986d4cd 9bf14d0 986d4cd 9bf14d0 a3af518 9bf14d0 d9ea17d a3af518 986d4cd a3af518 986d4cd 9bf14d0 986d4cd a3af518 986d4cd 9bf14d0 a3af518 986d4cd a3af518 986d4cd a3af518 986d4cd a3af518 9bf14d0 a3af518 9bf14d0 a3af518 986d4cd 9bf14d0 a3af518 a8606ac 2c15189 a09ea48 4189fe1 986d4cd 9bf14d0 a3af518 d9ea17d 986d4cd a3af518 c70d8eb 4189fe1 fd06e70 a09ea48 2c15189 a09ea48 986d4cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Hugging Face token & login ---
# Authenticate against the Hub only when a token is supplied via the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Device selection ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Model parameters ---
MODEL_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# The marker labels below are carried over from the original author's comments;
# they have not been checked against the tokenizer — TODO confirm.
START_MARKER = 128259  # labelled <|startoftranscript|>
RESTART_MARKER = 128257  # labelled <|startoftranscript_again|>
EOS_TOKEN = 128258  # labelled <|endoftranscript|>
AUDIO_TOKEN_OFFSET = 128266  # subtracted from a generated token to get the audio code
BLOCK_TOKENS = 7  # SNAC expects 7 audio tokens per block
CHUNK_TOKENS = 50  # number of new tokens per generate round

# --- FastAPI application instance ---
app = FastAPI()
# --- Root route, so that GET / does not answer with 404 ---
@app.get("/")
async def read_root():
    """Return a small JSON status payload for the root URL."""
    status = {"message": "Orpheus TTS Server ist live 🎙️"}
    return status
# --- Load all models once at application startup ---
@app.on_event("startup")
async def load_models():
    """Load the SNAC decoder, the tokenizer and the TTS language model.

    The three objects are published as the module globals ``snac``,
    ``tokenizer`` and ``model`` so the request handlers can use them.
    """
    global tokenizer, model, snac

    # SNAC audio codec, moved to the selected device.
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac = snac_model.to(device)

    # Tokenizer belonging to the TTS language model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # TTS language model: bfloat16 on CUDA, library default dtype otherwise.
    if device == "cuda":
        dtype = torch.bfloat16
    else:
        dtype = None
    load_kwargs = {
        "device_map": "auto",
        "torch_dtype": dtype,
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

    # Padding reuses the end-of-sequence token id.
    model.config.pad_token_id = EOS_TOKEN
# --- Build the model input for one request ---
def prepare_inputs(text: str, voice: str):
    """Turn a voice name and a text into prompt token ids plus attention mask.

    The prompt ``"<voice>: <text>"`` is tokenized, prefixed with
    ``START_MARKER`` and suffixed with the ids ``128009`` and ``EOS_TOKEN``.

    Returns:
        A tuple ``(ids, attn_mask)``; both are tensors of the same shape on
        ``device`` and the mask consists of ones only.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")
    prompt_ids = encoded.input_ids.to(device)

    prefix = torch.tensor([[START_MARKER]], device=device)
    # NOTE(review): 128009 is presumably the Llama-3 <|eot_id|> token. The
    # upstream Orpheus prompt format appears to close the prompt with 128260
    # rather than EOS_TOKEN — confirm against the model card before changing.
    suffix = torch.tensor([[128009, EOS_TOKEN]], device=device)

    ids = torch.cat((prefix, prompt_ids, suffix), dim=1)
    return ids, torch.ones_like(ids)
# --- Decode one frame of 7 audio tokens into a PCM block ---
def decode_block(block: list[int]) -> bytes:
    """Decode one frame of seven offset audio codes into 16-bit PCM bytes.

    Slot ``i`` of the frame carries its code shifted by ``i * 4096``. After
    removing that shift the codes are distributed over the three SNAC layers
    as ``[0]``, ``[1, 4]`` and ``[2, 3, 5, 6]``.

    Args:
        block: At least seven integers (token id minus the audio token
            offset). Only the first seven are used.

    Returns:
        The decoded audio as signed 16-bit PCM bytes in native byte order.

    Raises:
        ValueError: If fewer than seven values are given, or a code falls
            outside ``[0, 4096)`` once its slot offset has been removed.
    """
    frame_size = 7
    codebook_size = 4096

    if len(block) < frame_size:
        raise ValueError(
            f"decode_block needs {frame_size} audio codes, got {len(block)}"
        )

    # Validate before touching the decoder: an out-of-range index would
    # otherwise surface as an opaque embedding lookup failure.
    codes = [block[slot] - slot * codebook_size for slot in range(frame_size)]
    for slot, code in enumerate(codes):
        if not 0 <= code < codebook_size:
            raise ValueError(
                f"audio code {code} at slot {slot} is outside "
                f"[0, {codebook_size})"
            )

    layers = (
        [codes[0]],
        [codes[1], codes[4]],
        [codes[2], codes[3], codes[5], codes[6]],
    )

    # Decode without autograd: the result must not track gradients, otherwise
    # the conversion to numpy below is rejected by torch.
    with torch.inference_mode():
        layer_tensors = [
            torch.tensor(layer, device=device).unsqueeze(0) for layer in layers
        ]
        audio = snac.decode(layer_tensors).squeeze()
        # Clamp before scaling so out-of-range samples saturate instead of
        # wrapping around in the int16 conversion.
        pcm16 = (audio.clamp(-1.0, 1.0) * 32767).to(torch.int16)

    return pcm16.cpu().numpy().tobytes()
# --- Generate tokens in small chunks and stream decoded audio frame by frame ---
async def generate_and_stream(ws: WebSocket, ids, attn_mask):
    """Generate audio tokens chunk by chunk and send decoded PCM over ``ws``.

    Each round asks the model for up to ``CHUNK_TOKENS`` new tokens. The full
    sequence produced so far is passed back in together with the returned
    key/value cache, so the next round continues the same sequence. Every
    ``BLOCK_TOKENS`` collected audio codes are decoded with ``decode_block``
    and sent as one binary WebSocket message.

    Args:
        ws: Accepted WebSocket that receives the PCM byte messages.
        ids: Prompt token ids as returned by ``prepare_inputs``.
        attn_mask: Attention mask matching ``ids``.

    Returns:
        ``None``. The coroutine ends when ``EOS_TOKEN`` is generated or a
        round yields no new tokens. Codes left over from an incomplete frame
        are dropped.
    """
    pending: list[int] = []
    sequence = ids
    mask = attn_mask
    past_kvs = None

    while True:
        prompt_len = sequence.shape[-1]

        # model.generate blocks; run it in a worker thread so the event loop
        # keeps serving other connections in the meantime.
        outputs = await asyncio.to_thread(
            model.generate,
            input_ids=sequence,
            attention_mask=mask,
            past_key_values=past_kvs,
            use_cache=True,
            max_new_tokens=CHUNK_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=EOS_TOKEN,
            pad_token_id=EOS_TOKEN,
            return_dict_in_generate=True,
            output_scores=False,
        )

        # Carry cache, full sequence and a matching mask into the next round.
        past_kvs = outputs.past_key_values
        sequence = outputs.sequences
        mask = torch.ones_like(sequence)

        # Everything after the tokens fed into this round is new. Slicing by
        # length stays correct when fewer than CHUNK_TOKENS were produced.
        new_tokens = sequence[0, prompt_len:].tolist()
        if not new_tokens:
            # No progress: stop instead of spinning forever.
            return

        for tok in new_tokens:
            # A restart marker discards the partially collected frame.
            if tok == RESTART_MARKER:
                pending.clear()
                continue
            # End of generation.
            if tok == EOS_TOKEN:
                return
            # Tokens below the audio range carry no audio code; skipping them
            # avoids feeding negative codes to the decoder.
            if tok < AUDIO_TOKEN_OFFSET:
                continue

            pending.append(tok - AUDIO_TOKEN_OFFSET)

            # As soon as a full frame is available, decode and stream it.
            if len(pending) >= BLOCK_TOKENS:
                block = pending[:BLOCK_TOKENS]
                del pending[:BLOCK_TOKENS]
                pcm = await asyncio.to_thread(decode_block, block)
                await ws.send_bytes(pcm)
# --- WebSocket endpoint for streaming TTS ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve one TTS request per WebSocket connection.

    The client sends a single JSON text message with the keys ``text``
    (required to produce audio) and ``voice`` (optional, default
    ``"Jakob"``). The server answers with binary PCM messages and then
    closes the connection.

    Close codes used:
        * default close — request finished, or ``text`` was empty.
        * 1007 — the message was not a JSON object or ``text`` is not a string.
        * 1011 — an unexpected error occurred while synthesizing.
    """
    await ws.accept()
    try:
        data = await ws.receive_text()

        # Malformed client input is a client error, not an internal one.
        try:
            req = json.loads(data)
        except ValueError:
            await ws.close(code=1007)
            return
        if not isinstance(req, dict):
            await ws.close(code=1007)
            return

        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        if not isinstance(text, str):
            await ws.close(code=1007)
            return
        if not text.strip():
            # Nothing to synthesize: finish cleanly without invoking the model.
            await ws.close()
            return

        ids, attn_mask = prepare_inputs(text, voice)
        await generate_and_stream(ws, ids, attn_mask)
        await ws.close()
    except WebSocketDisconnect:
        # The client went away; nothing left to do.
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        # The socket may already be closed or broken; a failing close must not
        # escape the handler and mask the error reported above.
        try:
            await ws.close(code=1011)
        except (RuntimeError, WebSocketDisconnect):
            pass
if __name__ == "__main__":
    # Start the ASGI server only when this file is executed as a script.
    import uvicorn

    uvicorn.run(
        "app:app",
        port=7860,
        host="0.0.0.0",
    )
|