import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — ENV & AUTH —
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(HF_TOKEN)
# — DEVICE SETUP —
device = "cuda" if torch.cuda.is_available() else "cpu"
# — FASTAPI INSTANCE —
app = FastAPI()
# — HEALTHCHECK / ROOT —
@app.get("/")
async def read_root():
    return {"message": "TTS WebSocket up and running!"}
# — LOAD MODELS ON STARTUP —
@app.on_event("startup")
async def startup_event():
    global tokenizer, model, snac
    # 1) SNAC vocoder
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) TTS model & tokenizer
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # make pad == eos
    model.config.pad_token_id = model.config.eos_token_id
# — HELPERS —
START_TOKEN = 128259
END_TOKENS = [128009, 128260]
RESET_MARKER = 128257
EOS_TOKEN = 128258
AUDIO_TOKEN_OFFSET = 128266 # to subtract from token→audio code
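# The IDs above sit beyond the tokenizer's text vocabulary: START_TOKEN/END_TOKENS
# frame the prompt, RESET_MARKER clears the audio-code buffer mid-stream, EOS_TOKEN
# ends generation, and AUDIO_TOKEN_OFFSET maps generated token IDs back into SNAC
# code space (7 codes per decoded audio frame, see decode_seven below).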
def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    in_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    ids = torch.cat([start, in_ids, end], dim=1)
    mask = torch.ones_like(ids)
    return ids, mask
def decode_seven(tokens: list[int]) -> bytes:
    """Take exactly 7 audio codes, build the SNAC input and decode to PCM16 bytes."""
    b = tokens
    l1 = [b[0]]
    l2 = [b[1] - 1 * 4096, b[4] - 4 * 4096]
    l3 = [b[2] - 2 * 4096, b[3] - 3 * 4096, b[5] - 5 * 4096, b[6] - 6 * 4096]
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
# — WEBSOCKET ENDPOINT —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # 1) receive JSON request
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        # 2) prepare prompt
        input_ids, attention_mask = prepare_inputs(text, voice)
        prompt_len = input_ids.size(1)
        # 3) chunked generation setup
        past_kvs = None
        buffer: list[int] = []
        generated_offset = 0
        while True:
            # 4) generate up to 50 new tokens at a time
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=EOS_TOKEN,
                pad_token_id=EOS_TOKEN,
                use_cache=True,
                return_dict_in_generate=True,
                return_legacy_cache=True,
                past_key_values=past_kvs,
            )
            # with return_dict_in_generate=True we get the full sequence and the updated cache
            gen_ids, past_kvs = out.sequences, out.past_key_values
            # 5) extract only newly generated tokens
            seq = gen_ids[0]
            new_seq = seq[prompt_len + generated_offset :]
            generated_offset += new_seq.size(0)
            # feed the full sequence (plus cache) back in on the next iteration
            input_ids = gen_ids
            attention_mask = torch.ones_like(gen_ids)
            # 6) process each new token
            stop = False
            for t in new_seq.tolist():
                if t == EOS_TOKEN:
                    stop = True
                    break
                if t == RESET_MARKER:
                    buffer.clear()
                    continue
                # convert to audio-code
                buffer.append(t - AUDIO_TOKEN_OFFSET)
                # once we have 7 codes, decode & stream
                if len(buffer) >= 7:
                    block = buffer[:7]
                    buffer = buffer[7:]
                    pcm_bytes = decode_seven(block)
                    await ws.send_bytes(pcm_bytes)
            if stop:
                break
        # 7) clean close
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
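# --- Example client (illustrative sketch, not part of the app above) ---
# Assumes the server is running locally (e.g. `uvicorn app:app --port 8000`; module
# name and port are assumptions) and uses the third-party `websockets` package to
# send one request and write the streamed PCM16 frames into a 24 kHz mono WAV file.
#
#   import asyncio, json, wave
#   import websockets
#
#   async def fetch_tts(text: str, voice: str = "Jakob") -> None:
#       async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#           await ws.send(json.dumps({"text": text, "voice": voice}))
#           with wave.open("out.wav", "wb") as wav:
#               wav.setnchannels(1)        # mono
#               wav.setsampwidth(2)        # 16-bit PCM
#               wav.setframerate(24000)    # SNAC 24 kHz vocoder
#               try:
#                   while True:
#                       frame = await ws.recv()   # raw PCM16 bytes per 7-code block
#                       if isinstance(frame, bytes):
#                           wav.writeframes(frame)
#               except websockets.ConnectionClosed:
#                   pass                          # server closes after the last frame
#
#   asyncio.run(fetch_tts("Hallo, wie geht es dir?"))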