File size: 5,200 Bytes
2c15189 0316ec3 4189fe1 9bf14d0 a09ea48 0316ec3 9bf14d0 2c15189 9bf14d0 2008a3f c70d8eb 1ab029d 0316ec3 9bf14d0 0dfc310 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 3281189 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb a8606ac 2c15189 a09ea48 4189fe1 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 2c15189 c70d8eb 9bf14d0 c70d8eb d4630a2 c70d8eb 2c15189 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 c70d8eb 9bf14d0 2c15189 c70d8eb 4189fe1 c70d8eb a09ea48 2c15189 a09ea48 c70d8eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Hugging Face token & login ---
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Pick the compute device ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Instantiate the FastAPI application ---
app = FastAPI()
# — Hello‑Route, damit GET / nicht 404 gibt —
@app.get("/")
async def read_root():
    """Hello route so that GET / does not return a 404."""
    return {"message": "Hello, world!"}
# — Modelle beim Startup laden —
@app.on_event("startup")
async def load_models():
    """Load the SNAC vocoder and the Orpheus TTS model once at startup.

    Populates the module-level globals ``tokenizer``, ``model`` and
    ``snac`` that the websocket endpoint uses.
    """
    global tokenizer, model, snac

    # SNAC 24 kHz vocoder: turns audio codes back into waveform samples.
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # German Orpheus-3B TTS model.
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # BUGFIX: with device_map="auto" the model is already dispatched onto
    # the GPU by accelerate, and calling .to(device) on a dispatched model
    # raises in recent transformers versions. Only move it explicitly on
    # the CPU path.
    if device != "cuda":
        model = model.to(device)
    # Use EOS as the pad token so generate() does not warn/fail on padding.
    model.config.pad_token_id = model.config.eos_token_id
# — Input‑Vorbereitung —
def prepare_inputs(text: str, voice: str):
    """Build input ids and attention mask in the Orpheus prompt format.

    The tokenized "voice: text" prompt is wrapped between the special
    start marker token (128259) and the end marker tokens
    (128009, 128260) that the model expects.
    """
    token_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    start_marker = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end_markers = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    full_ids = torch.cat([start_marker, token_ids, end_markers], dim=1)
    attention = torch.ones_like(full_ids, device=device)
    return full_ids, attention
# — SNAC‑Dekodierung eines 7‑Token‑Blocks →
def decode_block(tokens: list[int]) -> bytes:
    """Decode one 7-token SNAC frame into 16-bit PCM bytes.

    Each of the 7 codes carries a per-position offset of k*4096 which is
    removed here; the positions map onto SNAC's three codebook levels as
    level 1: [0], level 2: [1, 4], level 3: [2, 3, 5, 6].
    """
    level1 = [tokens[0]]
    level2 = [tokens[1] - 4096, tokens[4] - 4 * 4096]
    level3 = [
        tokens[2] - 2 * 4096,
        tokens[3] - 3 * 4096,
        tokens[5] - 5 * 4096,
        tokens[6] - 6 * 4096,
    ]
    codes = [
        torch.tensor(level1, device=device).unsqueeze(0),
        torch.tensor(level2, device=device).unsqueeze(0),
        torch.tensor(level3, device=device).unsqueeze(0),
    ]
    # Decode to a float waveform, then scale to signed 16-bit PCM.
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
# — WebSocket‑Endpoint mit Chunked‑Generate (max_new_tokens=50) —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech over a websocket.

    Protocol: the client sends a single JSON text message of the form
    ``{"text": ..., "voice": ...}``; the server answers with a stream of
    raw 16-bit PCM binary frames and then closes the socket.
    """
    await ws.accept()
    try:
        # 1) Read the request.
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # 2) Build the prompt inputs.
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer_codes: list[int] = []

        # 3) Chunked generation loop (max_new_tokens=50 per step).
        chunk_size = 50
        eos_id = model.config.eos_token_id
        # BUGFIX: start counting "new" tokens AFTER the prompt. The old
        # code initialized prev_len = 0, so the first chunk streamed the
        # prompt tokens themselves as if they were audio codes.
        prev_len = input_ids.shape[1]
        finished = False
        while not finished:
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=chunk_size,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=eos_id,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=False,
                past_key_values=past_kvs,
            )
            past_kvs = out.past_key_values
            seqs = out.sequences  # shape (1, total_length)
            total_len = seqs.shape[1]

            # 4) Extract only the tokens produced in this chunk.
            new_tokens = seqs[0, prev_len:total_len].tolist()
            if not new_tokens:
                # Safety net: nothing generated, avoid spinning forever.
                break
            prev_len = total_len

            # 5) Process each newly generated token.
            for tok in new_tokens:
                if tok == eos_id:
                    # BUGFIX: signal termination via a flag. The old code
                    # emptied new_tokens and then tested
                    # `eos_id in new_tokens`, which was always false, so
                    # the outer loop never stopped on EOS.
                    finished = True
                    break
                if tok == 128257:
                    # Start-of-audio marker: reset the code buffer.
                    buffer_codes.clear()
                    continue
                # Remove the global audio-token offset and buffer the code.
                buffer_codes.append(tok - 128266)
                # Every 7 buffered codes form one SNAC frame:
                # decode it and stream the PCM bytes immediately.
                if len(buffer_codes) >= 7:
                    block = buffer_codes[:7]
                    buffer_codes = buffer_codes[7:]
                    await ws.send_bytes(decode_block(block))

            # 6) Feed the full sequence back for the next chunk.
            # BUGFIX: generate() cannot resume from past_key_values with
            # input_ids=None; it expects the full sequence and crops to the
            # uncached suffix internally — TODO confirm against the
            # installed transformers version.
            input_ids = seqs
            attention_mask = torch.ones_like(seqs)

        # 7) Close the socket cleanly.
        await ws.close()
    except WebSocketDisconnect:
        return
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
# — Main für lokalen Test —
# --- Entry point for a local test run ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|