import os
import json
import asyncio
import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# --- Env & HF auth ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# --- Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load SNAC ---
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
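# SNAC acts as the vocoder here: it decodes three hierarchical codebook
# streams back into a mono waveform (this variant outputs 24 kHz audio).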

# --- Prepare the Orpheus model ---
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Config + weights only (allows a slimmer container); the tokenizer files
# are fetched separately by AutoTokenizer below.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*",
    ],
)
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
).to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# --- Helpers ---
def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Wrap the prompt in Orpheus control tokens:
    # 128259 = start of human turn, 128009 = end of text, 128260 = end of human turn
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat([start, inputs.input_ids, end], dim=1)
    return input_ids
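# E.g. process_prompt("Hallo Welt!", "Jakob") tokenizes "Jakob: Hallo Welt!"
# and frames it as [128259, <prompt ids>, 128009, 128260].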

def parse_output(generated_ids: torch.LongTensor):
    token_to_find = 128257    # start-of-audio marker
    token_to_remove = 128258  # end-of-speech token
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item() + 1 :]
    else:
        cropped = generated_ids
    row = cropped[0][cropped[0] != token_to_remove]
    # Orpheus audio tokens are offset by 128266; shift them back into SNAC
    # codebook range and trim to a whole number of 7-token frames.
    row = row - 128266
    row = row[: (row.size(0) // 7) * 7]
    return row.tolist()
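# E.g. a generation ending in [..., 128257, t0, ..., t13, 128258] yields
# [t0 - 128266, ..., t13 - 128266]: two complete 7-code frames.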

def redistribute_codes(code_list: list[int], snac_model: SNAC):
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):  # one frame = 7 codes; ignore any trailing partial frame
        base = code_list[7 * i : 7 * i + 7]
        layer1.append(base[0])
        layer2.append(base[1] - 4096)
        layer3.append(base[2] - 2 * 4096)
        layer3.append(base[3] - 3 * 4096)
        layer2.append(base[4] - 4 * 4096)
        layer3.append(base[5] - 5 * 4096)
        layer3.append(base[6] - 6 * 4096)
    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac_model.decode(codes)
    return audio.detach().squeeze().cpu().numpy()
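# Frame layout: within each 7-code frame, index 0 feeds layer 1 (coarsest),
# indices 1 and 4 feed layer 2, and indices 2, 3, 5, 6 feed layer 3 (finest).
# Position k is stored with a +k*4096 vocabulary offset, hence the subtractions.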

# --- FastAPI app ---
app = FastAPI()

@app.get("/")
async def hello():
    return {"message": "Hello, Orpheus TTS is up and running!"}

@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # Exactly ONE request per connection
        raw = await ws.receive_text()
        data = json.loads(raw)
        text = data.get("text", "")
        voice = data.get("voice", "Jakob")
        # 1) Text -> input_ids
        input_ids = process_prompt(text, voice)
        # 2) Generation (in a worker thread so the event loop stays responsive)
        gen_ids = await asyncio.to_thread(
            model.generate,
            input_ids=input_ids,
            max_new_tokens=2000,  # raise this for longer audio
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        # 3) Tokens -> audio
        codes = parse_output(gen_ids)
        audio_np = redistribute_codes(codes, snac)
        # 4) Stream PCM16 bytes in ~0.1 s chunks (clip first to avoid int16 overflow)
        pcm16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
        chunk_size = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s, 2 bytes per sample
        for i in range(0, len(pcm16), chunk_size):
            await ws.send_bytes(pcm16[i : i + chunk_size])
            await asyncio.sleep(0.1)
        # Close cleanly; the client receives ConnectionClosedOK
        await ws.close()
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        # Log, then close with an error code
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
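
# --- Example client (sketch) ---
# A minimal sketch for exercising the endpoint, assuming the server runs on
# localhost:8000 and the third-party `websockets` package is installed:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()
#           except websockets.exceptions.ConnectionClosedOK:
#               pass  # server closed cleanly after streaming all chunks
#           with open("out.pcm", "wb") as f:  # raw mono 16-bit PCM @ 24 kHz
#               f.write(pcm)
#
#   asyncio.run(main())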