import os
import json
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
# — HF token & login (if set) —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
# — Select device —
device = "cuda" if torch.cuda.is_available() else "cpu"
app = FastAPI()
@app.get("/")
async def read_root():
return {"message": "Hello, world!"}
# — Global models —
model = None
tokenizer = None
snac_model = None
@app.on_event("startup")
async def load_models():
    global model, tokenizer, snac_model
    # 1) Load SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) Orpheus TTS (public “natural” variant)
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device != "cuda":
        # device_map="auto" already places the model on GPU; only move it on CPU
        model = model.to(device)
    model.config.pad_token_id = model.config.eos_token_id
# — Markers and offsets —
START_TOKEN = 128259          # prepended before the text prompt
END_TOKENS = [128009, 128260]
AUDIO_START_TOKEN = 128257    # marks the start of the audio codes
AUDIO_EOS_TOKEN = 128258      # end of the audio stream (also used as padding)
AUDIO_OFFSET = 128266
def process_single_prompt(prompt: str, voice: str) -> list[int]:
    # Assemble the prompt
    text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
    # Tokenize and add markers
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    # Generate
    gen = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=AUDIO_EOS_TOKEN,
        use_cache=True,
    )
    # Crop everything up to and including the last audio-start marker
    idxs = (gen == AUDIO_START_TOKEN).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = gen[:, idxs[-1] + 1 :]
    else:
        cropped = gen
    # Strip EOS/padding tokens and trim the length to a multiple of 7
    row = cropped[0][cropped[0] != AUDIO_EOS_TOKEN]
    new_len = (row.size(0) // 7) * 7
    trimmed = row[:new_len].tolist()
    # Subtract the audio token offset
    return [t - AUDIO_OFFSET for t in trimmed]
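# The WebSocket route below calls prepare_inputs(), which was referenced but
# not defined in this file. A minimal sketch, mirroring the marker logic from
# process_single_prompt() above:
def prepare_inputs(prompt: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    # Assemble "<voice>: <text>" unless the voice is embedded in the prompt
    text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    return input_ids, torch.ones_like(input_ids)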
def redistribute_codes(code_list: list[int]) -> np.ndarray:
    # Distribute blocks of 7 codes across the 3 SNAC layers
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):
        b = code_list[7 * i : 7 * i + 7]
        layer1.append(b[0])
        layer2.append(b[1] - 4096)
        layer3.append(b[2] - 2 * 4096)
        layer3.append(b[3] - 3 * 4096)
        layer2.append(b[4] - 4 * 4096)
        layer3.append(b[5] - 5 * 4096)
        layer3.append(b[6] - 6 * 4096)
    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    audio = snac_model.decode(codes).squeeze().cpu().numpy()
    return audio  # float32 @ 24 kHz
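# decode_block() is also referenced but undefined below. A minimal sketch: it
# decodes one 7-code frame via redistribute_codes() and returns raw 16-bit PCM
# bytes (little-endian, mono, 24 kHz) for ws.send_bytes(). The int16 scaling is
# an assumption; send the float32 bytes instead if your client expects them.
def decode_block(codes: list[int]) -> bytes:
    audio = redistribute_codes(codes)  # float32 in [-1, 1]
    pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return pcm16.tobytes()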
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "")
        # 1) Prepare the prompt
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer = []
        finished = False
        # 2) Generate token by token (in small blocks)
        while not finished:
            # Only max_new_tokens=50 per call; past_key_values carries the
            # cache across calls, so each call only processes the new tokens
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_kvs,
                use_cache=True,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                max_new_tokens=50,
                eos_token_id=AUDIO_EOS_TOKEN,
                return_dict_in_generate=True,  # past_key_values is on the output when use_cache=True
                return_legacy_cache=True,  # in case you still need the legacy past_key_values format
            )
            # Extract only the newly generated tokens
            new_ids = out.sequences[0, input_ids.shape[-1]:].tolist()
            past_kvs = out.past_key_values
            if not new_ids:
                break  # nothing was generated; avoid looping forever
            for tok in new_ids:
                if tok == AUDIO_EOS_TOKEN:
                    # End of stream
                    finished = True
                    break
                if tok == AUDIO_START_TOKEN:
                    # A new start marker resets the frame buffer
                    buffer = []
                    continue
                buffer.append(tok - AUDIO_OFFSET)
                # Once 7 audio codes are collected, decode and send them
                if len(buffer) == 7:
                    pcm = decode_block(buffer)
                    buffer = []
                    await ws.send_bytes(pcm)
            # Continue with the next 50 tokens: feed the full sequence back in;
            # generate() resumes from the cached prefix
            input_ids = out.sequences
            attention_mask = torch.ones_like(input_ids)
        # 3) Close the WebSocket cleanly at the end
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
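# Example client (hypothetical, for illustration): connect to /ws/tts, send a
# JSON request, and collect the streamed PCM chunks.
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "in_prompt"}))
#           pcm = b""
#           async for chunk in ws:  # iteration ends when the server closes
#               pcm += chunk
#       # pcm now holds 16-bit mono audio at 24 kHz
#
#   asyncio.run(main())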