import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# --- Environment & Hugging Face authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # Authenticate only when a token is configured (needed for gated repos).
    login(token=HF_TOKEN)

# --- Device selection: prefer GPU when available ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model loading ---
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-fetch only the config + safetensors weights; optimizer/tokenizer
# artifacts are excluded here (AutoTokenizer below fetches its own files).
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
    ]
)
print("Loading Orpheus model…")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
# Use EOS as PAD so generation does not warn / mispad.
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# --- Audio-token constants ---
# Offset mapping generated audio tokens back to raw SNAC code IDs
# (must match the value used during training; here 128266).
AUDIO_TOKEN_OFFSET = 128266
# — Hilfsfunktionen —
def process_prompt(text: str, voice: str):
    """Tokenize a ``voice: text`` prompt for the Orpheus model.

    The tokenized prompt is framed with the special start token (128259)
    and the end-token pair (128009, 128260) per the model's spec.

    Returns:
        (input_ids, attention_mask) — both on the module-level ``device``.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids
    start_tok = torch.tensor([[128259]], dtype=torch.int64)
    end_toks = torch.tensor([[128009, 128260]], dtype=torch.int64)
    ids = torch.cat([start_tok, encoded, end_toks], dim=1).to(device)
    return ids, torch.ones_like(ids).to(device)
def parse_output(generated_ids: torch.LongTensor):
    """Extract flat SNAC code IDs from a raw generation result.

    Crops everything up to and including the *last* audio-start token
    (128257), drops padding/EOS tokens (``model.config.eos_token_id``,
    i.e. 128258), then subtracts ``AUDIO_TOKEN_OFFSET`` so the remaining
    values are plain SNAC code IDs.
    """
    start_token = 128257
    pad_token = model.config.eos_token_id  # 128258
    # Positions (column indices) of every audio-start token in row 0.
    start_positions = (generated_ids == start_token).nonzero(as_tuple=True)[1]
    if start_positions.numel() > 0:
        cropped = generated_ids[:, start_positions[-1].item() + 1:]
    else:
        # No start token found — fall back to the full sequence.
        cropped = generated_ids
    row = cropped[0]
    kept = row[row != pad_token]
    return [int(tok) - AUDIO_TOKEN_OFFSET for tok in kept]
def redistribute_codes(code_list: list[int], snac_model: SNAC):
    """Decode a flat 7-codes-per-frame list into a waveform via SNAC.

    Each 7-code frame is redistributed across SNAC's three code layers
    (1 code -> layer 1, 2 -> layer 2, 4 -> layer 3), removing the
    per-position offset ``k * 4096``. Any incomplete trailing frame
    is discarded.

    Returns:
        A numpy array with the decoded mono audio (24 kHz).
    """
    l1, l2, l3 = [], [], []
    n_frames = len(code_list) // 7
    for start in range(0, 7 * n_frames, 7):
        frame = code_list[start:start + 7]
        l1.append(frame[0])
        l2.extend((frame[1] - 4096, frame[4] - 4 * 4096))
        l3.extend((
            frame[2] - 2 * 4096,
            frame[3] - 3 * 4096,
            frame[5] - 5 * 4096,
            frame[6] - 6 * 4096,
        ))
    # Decode on whatever device the SNAC model lives on.
    target = next(snac_model.parameters()).device
    layers = [torch.tensor(vals, device=target).unsqueeze(0) for vals in (l1, l2, l3)]
    waveform = snac_model.decode(layers)
    return waveform.detach().squeeze().cpu().numpy()
# --- FastAPI application (WebSocket endpoint registered below) ---
app = FastAPI()
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """WebSocket TTS endpoint.

    Expects JSON messages like ``{"text": "...", "voice": "..."}`` and
    streams back raw PCM16 mono audio (24 kHz) in 0.1 s chunks.
    Closes with code 1011 on unexpected server-side errors.
    """
    await ws.accept()
    try:
        while True:
            msg = await ws.receive_text()
            data = json.loads(msg)
            text = data.get("text", "")
            voice = data.get("voice", "Jakob")
            # 1) Prompt -> token tensors
            ids, mask = process_prompt(text, voice)
            # 2) Generation
            gen_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=200,  # kept small for debugging
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=model.config.eos_token_id,
            )
            # 3) Tokens -> code list -> audio (float32 @ 24 kHz)
            code_list = parse_output(gen_ids)
            audio_np = redistribute_codes(code_list, snac)
            # 4) Stream as PCM16 in 0.1 s chunks (2400 samples).
            # Clip to [-1, 1] before casting: out-of-range floats would
            # wrap around in int16 and produce loud audio artifacts.
            pcm16 = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
            chunk = 2400 * 2  # 2400 samples * 2 bytes/sample
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i : i + chunk])
                await asyncio.sleep(0.1)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            # The socket may already be closed (e.g. the error occurred
            # mid-send); closing twice raises, so swallow that case.
            await ws.close(code=1011)
        except Exception:
            pass
# Run the ASGI app directly when executed as a script (e.g. `python app.py`).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)