import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- ENV & HF auth ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # automatically logged in, as huggingface-cli would do
    os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
# --- FastAPI app ---
app = FastAPI()
@app.get("/")
async def hello():
    return PlainTextResponse("Hallo Welt!")
# --- Configure device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load SNAC ---
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
# --- Load Orpheus/Kartoffel-3B via PEFT ---
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print(f"Loading base LM + PEFT from {model_name}…")
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(
    base,
    model_name,
    device_map="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# make sure pad_token_id is set
model.config.pad_token_id = model.config.eos_token_id
# --- Helper functions ---
def prepare_prompt(text: str, voice: str):
"""Setzt Start‑ und End‑Marker um den eigentlichen Prompt."""
if voice:
full = f"{voice}: {text}"
else:
full = text
start = torch.tensor([[128259]], dtype=torch.int64) # BOS für Audio
end = torch.tensor([[128009, 128260]], dtype=torch.int64) # ggf. Speaker‑ID + Marker
enc = tokenizer(full, return_tensors="pt").input_ids
seq = torch.cat([start, enc, end], dim=1).to(device)
mask = torch.ones_like(seq).to(device)
return seq, mask
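# Illustration: prepare_prompt("Hallo", "Martin") produces IDs laid out as
#   [[128259, <tokens of "Martin: Hallo">, 128009, 128260]]
# ("Martin" is only a placeholder voice name here, not necessarily one the
# checkpoint supports).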
def extract_audio_tokens(generated: torch.LongTensor):
"""Croppe alles bis zum echten Audio-Start, entferne EOS und mache 7er-Batches."""
bos_tok = 128257
eos_tok = 128258
# letzten Start‑Token finden und ab da weiter
idxs = (generated == bos_tok).nonzero(as_tuple=True)[1]
if idxs.numel() > 0:
cut = idxs[-1].item() + 1
cropped = generated[:, cut:]
else:
cropped = generated
# EOS‑Marker entfernen
flat = cropped[0][cropped[0] != eos_tok]
# nur ein Vielfaches von 7 behalten
length = (flat.size(0) // 7) * 7
flat = flat[:length]
# Die Audio‑Token beginnen ab Offset 128266
return [(t.item() - 128266) for t in flat]
async def decode_and_stream(tokens: list[int], ws: WebSocket):
    """Turns groups of 7 tokens into waveform samples and streams them in 0.1 s chunks."""
    # group by 7 and decode each frame
    pcm16 = bytearray()
    offset = 0
    while offset + 7 <= len(tokens):
        block = tokens[offset:offset + 7]
        offset += 7
        # prepare the SNAC input:
        # layer 1 directly, layers 2/3 with offsets of n * 4096 per slot
        l1, l2, l3 = [], [], []
        l1.append(block[0])
        l2.append(block[1] - 4096)
        l3.append(block[2] - 2 * 4096)
        l3.append(block[3] - 3 * 4096)
        l2.append(block[4] - 4 * 4096)
        l3.append(block[5] - 5 * 4096)
        l3.append(block[6] - 6 * 4096)
        t1 = torch.tensor(l1, device=device).unsqueeze(0)
        t2 = torch.tensor(l2, device=device).unsqueeze(0)
        t3 = torch.tensor(l3, device=device).unsqueeze(0)
        audio = snac.decode([t1, t2, t3]).squeeze().cpu().numpy()
        # to PCM16 @ 24 kHz
        pcm = (audio * 32767).astype("int16").tobytes()
        pcm16.extend(pcm)
    # stream in 0.1 s chunks (2400 samples × 2 bytes)
    chunk_size = 2400 * 2
    for i in range(0, len(pcm16), chunk_size):
        await ws.send_bytes(bytes(pcm16[i : i + chunk_size]))
        # without a short pause the WebSocket can be flooded
        await asyncio.sleep(0.1)
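# Note on the frame layout implied by the offsets above (an assumption based
# on the n * 4096 spacing, not taken from official documentation): each group
# of 7 tokens appears to encode one SNAC frame as
#   [l1, l2, l3, l3, l2, l3, l3]
# i.e. 1 coarse, 2 mid and 4 fine codes for SNAC's three codebook levels,
# with slot k shifted by k * 4096 so the seven slots occupy disjoint ranges
# of the LM vocabulary.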
# --- WebSocket TTS endpoint ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            raw = await ws.receive_text()
            req = json.loads(raw)
            text = req.get("text", "")
            voice = req.get("voice", "")
            # build the prompt
            ids, mask = prepare_prompt(text, voice)
            # generate audio tokens (note: generate() blocks the event loop,
            # which is acceptable for a single-client demo)
            gen = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=4000,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
                forced_bos_token_id=128259,
                use_cache=True,
            )
            codes = extract_audio_tokens(gen)
            # stream the decoded audio
            await decode_and_stream(codes, ws)
            # close cleanly
            await ws.close(code=1000)
            break
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
# --- Run locally ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
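# A minimal client sketch for manual testing. Assumes the third-party
# `websockets` package; the voice name "Martin" is only a placeholder, so
# check the model card for the voices the checkpoint actually supports.
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Martin"}))
#           with open("out.pcm", "wb") as f:
#               async for chunk in ws:  # raw 16-bit mono PCM @ 24 kHz
#                   if isinstance(chunk, bytes):
#                       f.write(chunk)
#
#   asyncio.run(main())
#
# Play back with e.g.: ffplay -f s16le -ar 24000 -ac 1 out.pcm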