# Hugging Face Spaces page header (Space status: Paused)
import os | |
import json | |
import asyncio | |
import torch | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from huggingface_hub import login | |
from snac import SNAC | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# — Hugging Face token & login —
# Authenticate against the Hub only when a token is supplied via the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Device selection —
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# — Model parameters —
MODEL_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Special token ids, described by how this file uses them.
# NOTE(review): the symbolic names of these ids are not visible here — confirm
# against the model's tokenizer before relying on them.
START_MARKER = 128259        # prepended to every prompt (see prepare_inputs)
RESTART_MARKER = 128257      # clears the pending audio frame when generated
EOS_TOKEN = 128258           # stops generation; also used as the pad token id
AUDIO_TOKEN_OFFSET = 128266  # subtracted from generated ids to get audio codes
BLOCK_TOKENS = 7             # audio codes consumed per decoded SNAC frame
CHUNK_TOKENS = 50            # max new tokens requested per generate() round

# — Create the FastAPI application —
app = FastAPI()
# — Keep GET / from returning 404 —
@app.get("/")
async def read_root():
    """Return a small JSON status payload for the root URL.

    The handler is registered on ``app`` for ``GET /`` so that requests to
    the root path get a response instead of a 404.
    """
    return {"message": "Orpheus TTS Server ist live 🎙️"}
# — Load the models at startup —
@app.on_event("startup")
async def load_models():
    """Load the SNAC decoder, the tokenizer and the TTS language model.

    Registered as a startup handler so the module-level globals
    ``tokenizer``, ``model`` and ``snac`` exist before the first request is
    served; the other functions in this file read them as globals.
    """
    global tokenizer, model, snac

    # Audio codec used to turn generated codes back into a waveform.
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # Tokenizer matching the TTS language model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # TTS language model; bfloat16 on GPU, library default dtype on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # generate() pads with the end-of-stream token.
    model.config.pad_token_id = EOS_TOKEN
# — Prepare the model input —
def prepare_inputs(text: str, voice: str):
    """Build the prompt token ids and attention mask for one TTS request.

    The prompt text is ``"{voice}: {text}"``. Its token ids are wrapped as
    ``[START_MARKER] + prompt ids + [end-of-text, end-of-human]``.

    Args:
        text: The text to be spoken.
        voice: Speaker name placed in front of the text.

    Returns:
        A tuple ``(ids, attn_mask)``: the wrapped token ids on ``device`` and
        an all-ones attention mask of the same shape.
    """
    # Closing ids of the prompt, following the published Orpheus prompt
    # format (end of text, then end of the human turn). The original code
    # appended EOS_TOKEN here, which is the id that *terminates* generation
    # and therefore must not be the last thing the model reads.
    # NOTE(review): ids taken from the Orpheus reference inference example —
    # confirm against this checkpoint's model card.
    end_of_text = 128009
    end_of_human = 128260

    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_MARKER]], device=device)
    end = torch.tensor([[end_of_text, end_of_human]], device=device)
    ids = torch.cat([start, input_ids, end], dim=1)
    attn_mask = torch.ones_like(ids)
    return ids, attn_mask
# — Turn 7 audio codes into one PCM block —
def decode_block(block: list[int]) -> bytes:
    """Decode one frame of 7 audio codes into 16-bit PCM bytes.

    Entry ``i`` of ``block`` carries its code shifted by ``i * 4096``. After
    removing that shift the codes are distributed over the three code lists
    passed to ``snac.decode``: slot 0 -> first list, slots 1 and 4 -> second
    list, slots 2, 3, 5 and 6 -> third list.

    Args:
        block: At least 7 shifted audio codes; only the first 7 are used.

    Returns:
        The decoded audio as raw little-endian-native int16 sample bytes.

    Raises:
        ValueError: If fewer than 7 codes are given, or a code falls outside
            ``[0, 4096)`` once its slot shift is removed. Rejecting these
            up front keeps invalid indices out of the decoder's lookup.
    """
    codebook_size = 4096
    frame_len = 7

    if len(block) < frame_len:
        raise ValueError(
            f"expected {frame_len} audio codes, got {len(block)}"
        )

    # Remove the per-slot shift.
    codes_flat = [
        value - slot * codebook_size
        for slot, value in enumerate(block[:frame_len])
    ]
    if any(not 0 <= code < codebook_size for code in codes_flat):
        raise ValueError(f"audio codes out of range for their slots: {block!r}")

    c = codes_flat
    layers = ([c[0]], [c[1], c[4]], [c[2], c[3], c[5], c[6]])

    # Decode without autograd: otherwise the result tracks gradients and
    # cannot be converted with .numpy().
    with torch.inference_mode():
        codes = [
            torch.tensor(layer, device=device).unsqueeze(0) for layer in layers
        ]
        # Clamp so that samples outside [-1, 1] saturate instead of wrapping
        # around in the int16 conversion below.
        audio = snac.decode(codes).squeeze().clamp(-1.0, 1.0)

    pcm16 = (audio.cpu().numpy() * 32767).astype("int16").tobytes()
    return pcm16
# — Generator: produce tokens in small chunks and decode them frame by frame —
async def generate_and_stream(ws: WebSocket, ids, attn_mask, max_total_tokens: int = 8192):
    """Generate audio tokens in chunks and stream the decoded PCM over ``ws``.

    ``model.generate`` is called repeatedly for at most ``CHUNK_TOKENS`` new
    tokens per round. Every round is given the full sequence produced so far
    together with the key/value cache returned by the previous round, so the
    model continues where it stopped without re-reading the prefix. If the
    installed transformers version returns no cache, the round still works
    and simply recomputes the prefix.

    Generated ids are mapped to audio codes by subtracting
    ``AUDIO_TOKEN_OFFSET``. A code is accepted only if it belongs to the
    frame slot that is expected next; anything else (special tokens, codes
    for a different slot) discards the partial frame so that decoding
    re-synchronises on the next slot-0 code. Each complete frame of
    ``BLOCK_TOKENS`` codes is decoded with ``decode_block`` and sent as one
    binary WebSocket message. A trailing incomplete frame is dropped.

    Args:
        ws: Accepted WebSocket to send PCM bytes to.
        ids: Prompt token ids as returned by ``prepare_inputs``.
        attn_mask: Attention mask belonging to ``ids``.
        max_total_tokens: Upper bound on generated tokens, so a model that
            never emits ``EOS_TOKEN`` cannot keep the loop running forever.

    NOTE(review): ``model.generate`` and ``decode_block`` are synchronous and
    block the event loop while they run.
    """
    codebook_size = 4096  # per-slot shift, the same value decode_block removes
    buffer: list[int] = []
    sequence = ids
    mask = attn_mask
    past_kvs = None
    produced = 0

    while produced < max_total_tokens:
        prefix_len = sequence.shape[-1]
        # Only pass the cache once there is one.
        cache_kwargs = {} if past_kvs is None else {"past_key_values": past_kvs}
        outputs = model.generate(
            input_ids=sequence,
            attention_mask=mask,
            use_cache=True,
            max_new_tokens=min(CHUNK_TOKENS, max_total_tokens - produced),
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=EOS_TOKEN,
            pad_token_id=EOS_TOKEN,
            return_dict_in_generate=True,
            output_scores=False,
            **cache_kwargs,
        )
        past_kvs = getattr(outputs, "past_key_values", None)
        sequence = outputs.sequences

        # Exactly the tokens added in this round, however many there are.
        new_tokens = sequence[0, prefix_len:].tolist()
        if not new_tokens:
            # No progress is possible (e.g. length limit reached).
            return
        produced += len(new_tokens)
        # Extend the caller's mask by one attended position per new token.
        mask = torch.cat(
            [mask, mask.new_ones((mask.shape[0], len(new_tokens)))], dim=-1
        )

        for tok in new_tokens:
            # End of the audio stream.
            if tok == EOS_TOKEN:
                return
            # Restart marker: forget the partial frame.
            if tok == RESTART_MARKER:
                buffer = []
                continue

            code = tok - AUDIO_TOKEN_OFFSET
            slot = code // codebook_size
            if slot != len(buffer):
                # Not the code the current frame expects: drop the partial
                # frame and only continue with this token if it can open a
                # new frame.
                buffer = []
                if slot != 0:
                    continue
            buffer.append(code)

            # A full frame is available: decode and stream it.
            if len(buffer) == BLOCK_TOKENS:
                pcm = decode_block(buffer)
                buffer = []
                await ws.send_bytes(pcm)
# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve one TTS request over a WebSocket connection.

    Protocol: the client sends a single JSON text message with the keys
    ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``). The
    server answers with binary PCM messages produced by
    ``generate_and_stream`` and then closes the connection.

    A client disconnect ends the handler silently. Any other error is
    printed and the socket is closed with code 1011.
    """
    await ws.accept()
    try:
        data = await ws.receive_text()
        req = json.loads(data)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        ids, attn_mask = prepare_inputs(text, voice)
        await generate_and_stream(ws, ids, attn_mask)
        await ws.close()
    except WebSocketDisconnect:
        # The client went away; nothing left to send or close.
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The connection is already closed; closing again is not
            # possible and must not mask the error reported above.
            pass
if __name__ == "__main__":
    # Script entry point: serve the application object ``app`` from module
    # ``app`` on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **server_options)