import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# — ENV & HF auth —
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # export the token so huggingface_hub picks it up (same effect as `huggingface-cli login`)
    os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
# — FastAPI app —
app = FastAPI()

@app.get("/")
async def hello():
    return PlainTextResponse("Hallo Welt!")
# — Configure device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load SNAC (neural audio codec for 24 kHz speech) —
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
# — Load Orpheus/Kartoffel-3B via PEFT —
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print(f"Loading base LM + PEFT from {model_name}…")
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(
    base,
    model_name,
    device_map="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# make sure pad_token_id is set, otherwise generate() warns and may misbehave
model.config.pad_token_id = model.config.eos_token_id
# — Helper functions —
def prepare_prompt(text: str, voice: str):
    """Wrap the actual prompt with start and end markers."""
    if voice:
        full = f"{voice}: {text}"
    else:
        full = text
    start = torch.tensor([[128259]], dtype=torch.int64)        # start-of-human marker
    end = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end-of-text + end marker
    enc = tokenizer(full, return_tensors="pt").input_ids
    seq = torch.cat([start, enc, end], dim=1).to(device)
    mask = torch.ones_like(seq).to(device)
    return seq, mask
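# Resulting prompt layout (read off the tensors above):
#   [128259] <tokens of "voice: text"> [128009] [128260]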
def extract_audio_tokens(generated: torch.LongTensor):
    """Crop everything before the actual audio start, strip EOS, and keep full groups of 7."""
    bos_tok = 128257
    eos_tok = 128258
    # find the last start-of-audio token and continue from there
    idxs = (generated == bos_tok).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cut = idxs[-1].item() + 1
        cropped = generated[:, cut:]
    else:
        cropped = generated
    # remove EOS markers
    flat = cropped[0][cropped[0] != eos_tok]
    # keep only a multiple of 7 tokens
    length = (flat.size(0) // 7) * 7
    flat = flat[:length]
    # the audio tokens start at vocabulary offset 128266
    return [(t.item() - 128266) for t in flat]
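# Note: each group of 7 tokens encodes one SNAC frame: 1 code for the coarse
# layer, 2 for the middle layer, and 4 for the fine layer (see the offsets in
# decode_and_stream below), which is why only multiples of 7 are kept.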
async def decode_and_stream(tokens: list[int], ws: WebSocket):
    """Turn groups of 7 codes into waveform samples and stream them in 0.1 s chunks."""
    # group into blocks of 7 and decode each block
    pcm16 = bytearray()
    offset = 0
    while offset + 7 <= len(tokens):
        block = tokens[offset:offset + 7]
        offset += 7
        # prepare the SNAC input:
        # layer 1 directly, layers 2/3 with per-position offsets
        l1, l2, l3 = [], [], []
        l1.append(block[0])
        l2.append(block[1] - 4096)
        l3.append(block[2] - 2 * 4096)
        l3.append(block[3] - 3 * 4096)
        l2.append(block[4] - 4 * 4096)
        l3.append(block[5] - 5 * 4096)
        l3.append(block[6] - 6 * 4096)
        t1 = torch.tensor(l1, device=device).unsqueeze(0)
        t2 = torch.tensor(l2, device=device).unsqueeze(0)
        t3 = torch.tensor(l3, device=device).unsqueeze(0)
        with torch.no_grad():
            audio = snac.decode([t1, t2, t3]).squeeze().cpu().numpy()
        # to PCM16 @ 24 kHz
        pcm = (audio * 32767).astype("int16").tobytes()
        pcm16.extend(pcm)
    # send in 0.1 s chunks (2400 samples × 2 bytes)
    chunk_size = 2400 * 2
    for i in range(0, len(pcm16), chunk_size):
        await ws.send_bytes(pcm16[i : i + chunk_size])
        # without a brief pause the WebSocket can be overwhelmed
        await asyncio.sleep(0.1)
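# Sanity check on the chunking: at 24 kHz, 16-bit mono, 0.1 s of audio is
# 2400 samples = 4800 bytes, which matches chunk_size above.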
# — WebSocket TTS endpoint —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            raw = await ws.receive_text()
            req = json.loads(raw)
            text = req.get("text", "")
            voice = req.get("voice", "")
            # build the prompt
            ids, mask = prepare_prompt(text, voice)
            # generate audio tokens
            gen = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=4000,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
                forced_bos_token_id=128259,
                use_cache=True,
            )
            codes = extract_audio_tokens(gen)
            # decode and stream the audio
            await decode_and_stream(codes, ws)
            # close cleanly
            await ws.close(code=1000)
            break
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
# — Run locally —
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
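# A minimal client sketch (hypothetical, not part of this app; assumes the
# third-party `websockets` package). It sends one request and writes the raw
# PCM16 @ 24 kHz stream to a file:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": ""}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       f.write(await ws.recv())
#               except websockets.ConnectionClosed:
#                   pass
#
#   asyncio.run(main())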