import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# --- Environment & Hugging Face auth ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# --- Device selection ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load models ---
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print("Downloading Orpheus weights (konfig + safetensors)…")
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*",
    ],
)
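# snapshot_download only pre-populates the local HF cache, so the from_pretrained
# calls below reuse those files instead of re-downloading; the excluded tokenizer
# files are fetched on demand by AutoTokenizer.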
print("Loading Orpheus model…")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
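# Use EOS as the pad token so generate() does not warn about a missing pad_token_id.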
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# --- Helper functions ---
def process_prompt(text: str, voice: str):
    """Build input_ids and attention_mask for a single prompt."""
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[128259]], dtype=torch.int64, device=device)        # start of human turn
    end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)  # end of text, end of human turn
    ids = torch.cat([start, input_ids, end], dim=1)
    mask = torch.ones_like(ids)
    return ids, mask
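# E.g. process_prompt("Hallo Welt", "Jakob") tokenizes "Jakob: Hallo Welt" and
# frames it as [[128259, <text token ids>, 128009, 128260]] with an all-ones mask.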
def parse_output(generated_ids: torch.LongTensor) -> list[int]:
    """Extract the raw SNAC code list after the last 128257 start-of-audio token."""
    token_to_find = 128257    # start-of-audio marker
    token_to_remove = 128258  # end-of-audio / EOS
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item() + 1 :]
    else:
        cropped = generated_ids
    row = cropped[0]
    row = row[row != token_to_remove]
    row = row[: (row.size(0) // 7) * 7]  # keep only complete 7-token frames
    return (row - 128266).tolist()       # shift custom token ids down to raw code space
def redistribute_codes(code_list: list[int]) -> bytes:
    """Distribute the codes across the three SNAC layers and decode to PCM16 bytes."""
    l1, l2, l3 = [], [], []
    for i in range(len(code_list) // 7):  # one audio frame = 7 interleaved codes
        base = code_list[7 * i : 7 * i + 7]
        l1.append(base[0])
        l2.append(base[1] - 4096)
        l3.append(base[2] - 2 * 4096)
        l3.append(base[3] - 3 * 4096)
        l2.append(base[4] - 4 * 4096)
        l3.append(base[5] - 5 * 4096)
        l3.append(base[6] - 6 * 4096)
    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(l1, device=dev).unsqueeze(0),
        torch.tensor(l2, device=dev).unsqueeze(0),
        torch.tensor(l3, device=dev).unsqueeze(0),
    ]
    with torch.inference_mode():  # no grad tracking, so .numpy() below is safe
        audio = snac.decode(codes).squeeze().cpu().numpy()  # float32 @ 24 kHz
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
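# Worked example with made-up codes: the frame [12, 4100, 8200, 12300, 16400,
# 20500, 24600] becomes l1=[12], l2=[4, 16], l3=[8, 12, 20, 24] once each slot's
# n*4096 offset is stripped, matching SNAC's 1:2:4 codes-per-frame layer rates.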
# --- FastAPI + WebSocket endpoint ---
app = FastAPI()
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            # 1) Receive a request
            msg = await ws.receive_text()
            data = json.loads(msg)
            text = data.get("text", "")
            voice = data.get("voice", "Jakob")
            # 2) Prompt -> ids/mask
            ids, mask = process_prompt(text, voice)
            # 3) Token generation (in a worker thread so the event loop stays responsive)
            gen_ids = await asyncio.to_thread(
                model.generate,
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=2000,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
            )
            # 4) Parse + SNAC decode -> PCM16 bytes
            codes = parse_output(gen_ids)
            pcm16 = redistribute_codes(codes)
            chunk_sz = 2400 * 2  # 2400 samples x 2 bytes = 0.1 s @ 24 kHz
            # 5) Stream the audio chunks, paced at real time
            for i in range(0, len(pcm16), chunk_sz):
                await ws.send_bytes(pcm16[i : i + chunk_sz])
                await asyncio.sleep(0.1)
            # 6) End-of-stream signal
            await ws.send_json({"event": "eos"})
            # (connection stays open for the next request)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        # Close only after the error has been reported
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)