Spaces:
Paused
Paused
File size: 5,244 Bytes
a09ea48 0316ec3 4189fe1 0316ec3 a09ea48 0dfc310 0316ec3 0dfc310 a09ea48 2008a3f d408dd5 1ab029d 0316ec3 0dfc310 d408dd5 a09ea48 0316ec3 674acbf 0dfc310 f001a32 d408dd5 a09ea48 b3e4aa7 0dfc310 d408dd5 a09ea48 b3e4aa7 d408dd5 b3e4aa7 d408dd5 a09ea48 ad94d02 d408dd5 b3e4aa7 d408dd5 b3e4aa7 d408dd5 97006e1 d408dd5 4189fe1 d408dd5 a8606ac a09ea48 4189fe1 d408dd5 a09ea48 0dfc310 d408dd5 b3e4aa7 4189fe1 a09ea48 4189fe1 d408dd5 a09ea48 4189fe1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# --- Environment & Hugging Face authentication ---
# Load variables from a .env file into the process environment, then log in to
# the Hugging Face Hub only when an HF_TOKEN value is set (and non-empty).
load_dotenv()
if HF_TOKEN := os.getenv("HF_TOKEN"):
    login(token=HF_TOKEN)
# --- Select compute device: use CUDA when available, otherwise the CPU ---
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# --- Load models ---
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# SNAC neural codec, used to turn generated audio codes into a waveform.
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac = snac.to(device)

# Pre-fetch only the config and safetensors weights of the Orpheus model.
# NOTE(review): the tokenizer files excluded here are presumably fetched anyway
# by AutoTokenizer.from_pretrained below.
snapshot_download(
    repo_id=model_name,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.*",
    ],
)

# Orpheus causal LM (bfloat16 weights) plus its tokenizer.
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
# Use the EOS id as the padding id.
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# --- Special token ids for prompt building and audio-token decoding ---
START_TOKEN = 128259         # prepended to every prompt by process_prompt
SOS_TOKEN = 128257           # on this id tts_ws drops codes collected so far (name suggests start-of-speech)
EOS_TOKEN = 128258           # tts_ws stops generating when this id is sampled
AUDIO_TOKEN_OFFSET = 128266  # subtracted from a generated token id to obtain its audio code
# --- Helper functions ---
def process_prompt(text: str, voice: str):
    """Build the model input for speaking ``text`` with ``voice``.

    The prompt string is ``"<voice>: <text>"``. Its token ids are wrapped as
    ``[START_TOKEN] + tokens + [128009, 128260]``, and everything is placed on
    ``device``.

    Returns:
        ``(input_ids, attention_mask)``, where the mask is all ones with the
        same shape as the ids (a single-row batch).
    """
    prefix_ids = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    text_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    suffix_ids = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    full_ids = torch.cat((prefix_ids, text_ids, suffix_ids), dim=1)
    return full_ids, torch.ones_like(full_ids, dtype=torch.int64, device=device)
def redistribute_codes(block: list[int], snac_model: SNAC):
    """Decode Orpheus audio codes into a waveform with SNAC.

    Codes arrive in frames of 7. Within a frame, slot ``k`` carries an offset of
    ``k * 4096``, which is removed here. Slots are mapped onto SNAC's three code
    layers as: slot 0 -> layer 1; slots 1, 4 -> layer 2; slots 2, 3, 5, 6 ->
    layer 3. Trailing codes that do not fill a complete frame are ignored.

    Args:
        block: Audio codes (token id minus the audio-token offset). Only the
            first ``len(block) // 7 * 7`` entries are used.
        snac_model: SNAC decoder. The code tensors are created on the device of
            its parameters.

    Returns:
        The decoded waveform, squeezed and converted to a NumPy array.
        (Presumably float32 samples at 24 kHz for the snac_24khz model.)

    Raises:
        ValueError: if ``block`` holds fewer than 7 codes, or a code falls
            outside its slot's ``[k * 4096, (k + 1) * 4096)`` range.
    """
    n_frames = len(block) // 7
    if n_frames == 0:
        raise ValueError(
            f"need at least 7 codes to form one SNAC frame, got {len(block)}"
        )

    layer1, layer2, layer3 = [], [], []
    for f in range(n_frames):
        c = block[7 * f : 7 * f + 7]
        layer1.append(c[0])
        layer2.append(c[1] - 4096)
        layer3.append(c[2] - 2 * 4096)
        layer3.append(c[3] - 3 * 4096)
        layer2.append(c[4] - 4 * 4096)
        layer3.append(c[5] - 5 * 4096)
        layer3.append(c[6] - 6 * 4096)

    # An out-of-range index would address past the 4096-entry codebooks. On CUDA
    # that surfaces as a device-side assert, which breaks the whole CUDA context,
    # so reject it up front with a clear error.
    for layer in (layer1, layer2, layer3):
        if any(not 0 <= code < 4096 for code in layer):
            raise ValueError("audio code outside its SNAC codebook range [0, 4096)")

    dev = next(snac_model.parameters()).device
    # Explicit integer dtype: the codes are codebook indices.
    codes = [
        torch.tensor(layer, dtype=torch.long, device=dev).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]
    # no_grad: SNAC's parameters may require grad, and .numpy() refuses tensors
    # that are attached to an autograd graph.
    with torch.no_grad():
        audio = snac_model.decode(codes)
    return audio.squeeze().cpu().numpy()
# --- FastAPI application ---
app = FastAPI()


# 1) Hello-world endpoint
@app.get("/")
async def root():
    """Return a static hello-world JSON payload."""
    greeting = {"message": "Hallo Welt"}
    return greeting
# 2) WebSocket endpoint: token-by-token TTS

# Upper bound on sampled tokens per request.
_MAX_NEW_TOKENS = 2000
# One SNAC frame is 7 codes; slot k of a frame is offset by k * 4096.
_CODES_PER_FRAME = 7
_CODEBOOK_SIZE = 4096


def _sample_next_token(input_ids, attention_mask, past_kv):
    """Run one decoder step and sample the next token from the full softmax.

    Blocking; intended to be run through ``asyncio.to_thread``. ``torch.no_grad``
    is thread-local, so it is entered here, inside the worker thread.

    Returns:
        ``(next_id, token, past_kv)``: the sampled id as a tensor that can be fed
        back as ``input_ids``, the same id as a Python int, and the updated KV
        cache.
    """
    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_kv,
            use_cache=True,
        )
        # Softmax in float32: the model weights are bfloat16.
        probs = torch.softmax(out.logits[:, -1, :].float(), dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        token = next_id.item()  # GPU sync happens here, off the event loop
    return next_id, token, out.past_key_values


def _frame_to_pcm16(frame):
    """Decode one 7-code frame to 16-bit PCM bytes, or return None if invalid.

    A frame is rejected when any code lies outside its slot's range
    ``[k * 4096, (k + 1) * 4096)``. Such a code would index outside SNAC's
    codebooks after redistribute_codes removes the slot offset.
    Blocking; intended to be run through ``asyncio.to_thread``.
    """
    if any(
        not k * _CODEBOOK_SIZE <= code < (k + 1) * _CODEBOOK_SIZE
        for k, code in enumerate(frame)
    ):
        return None
    with torch.no_grad():
        audio_np = redistribute_codes(frame, snac)
    # Clip before scaling: samples outside [-1, 1] would overflow int16.
    return (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech over a WebSocket, one SNAC frame at a time.

    Protocol (repeated until the client disconnects):
      * the client sends a JSON text message ``{"text": ..., "voice": ...}``
        (``text`` defaults to "", ``voice`` defaults to "Jakob");
      * the server answers with binary messages, each holding the int16 PCM
        samples (native byte order) decoded from one 7-code frame;
      * the server ends the utterance with ``{"event": "eos"}``.

    Generation stops on EOS_TOKEN or after ``_MAX_NEW_TOKENS`` samples. On
    SOS_TOKEN, any partially collected frame is discarded. Tokens that are not
    audio codes, and frames with out-of-range codes, are skipped.
    Any other error closes the socket with code 1011.

    NOTE(review): model steps now run in the default thread pool, so
    concurrent connections can run forward passes in parallel. Add a lock if
    GPU memory requires serializing them.
    """
    await ws.accept()
    try:
        while True:
            # Receive a JSON request with text and voice
            raw = await ws.receive_text()
            req = json.loads(raw)
            text, voice = req.get("text", ""), req.get("voice", "Jakob")
            ids, mask = process_prompt(text, voice)
            past_kv = None
            collected = []
            for _ in range(_MAX_NEW_TOKENS):
                next_id, token, past_kv = await asyncio.to_thread(
                    _sample_next_token, ids, mask, past_kv
                )
                # The prompt now lives in the KV cache, so feed only the
                # freshly sampled token on the next step.
                ids, mask = next_id, None

                if token == EOS_TOKEN:
                    break
                if token == SOS_TOKEN:
                    collected = []
                    continue

                code = token - AUDIO_TOKEN_OFFSET
                if not 0 <= code < _CODES_PER_FRAME * _CODEBOOK_SIZE:
                    continue  # not an audio code
                collected.append(code)

                # As soon as a full frame is collected, decode and stream it
                if len(collected) == _CODES_PER_FRAME:
                    frame, collected = collected, []
                    pcm16 = await asyncio.to_thread(_frame_to_pcm16, frame)
                    if pcm16 is not None:
                        await ws.send_bytes(pcm16)
            # Signal end-of-stream for this request
            await ws.send_text(json.dumps({"event": "eos"}))
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except (RuntimeError, WebSocketDisconnect):
            pass  # best effort: the socket may already be closed
# Entry point for local testing.
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built `app` object rather than the "app:app" import
    # string. When this file runs as __main__, the import string makes uvicorn
    # import the module a second time under the name "app". That re-runs all
    # top-level code (downloads and model loading) and keeps two copies of
    # every model in memory. It also drops the assumption that the file is
    # named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|