Spaces:
Paused
Paused
File size: 6,209 Bytes
2c15189 0316ec3 4189fe1 0316ec3 a09ea48 0dfc310 0316ec3 2c15189 a09ea48 2c15189 2008a3f 2c15189 1ab029d 0316ec3 2c15189 a09ea48 0316ec3 2c15189 674acbf 2c15189 0dfc310 2c15189 0dfc310 f001a32 2c15189 d408dd5 9cd424e 2c15189 f3890ef a09ea48 9cd424e b3e4aa7 0dfc310 2c15189 9cd424e a09ea48 2c15189 a09ea48 2c15189 9cd424e 2c15189 9cd424e 2c15189 9cd424e 2c15189 97006e1 4189fe1 d408dd5 2c15189 d408dd5 a8606ac 2c15189 a09ea48 4189fe1 2c15189 4189fe1 2c15189 a09ea48 2c15189 a09ea48 f3890ef 2c15189 f3890ef 2c15189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# ─── ENV & HF TOKEN ────────────────────────────────────────────────────────────
# Load .env and, if an HF token is present, authenticate so the gated/private
# model downloads below succeed.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# ─── DEVICE ────────────────────────────────────────────────────────────────────
# Prefer GPU when available; everything below (SNAC, Orpheus, prompt tensors)
# is placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ─── SNAC ──────────────────────────────────────────────────────────────────────
# Neural audio codec used to decode Orpheus token codes into 24 kHz audio.
print("Loading SNAC modelβ¦")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
# ─── ORPHEUS ───────────────────────────────────────────────────────────────────
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-download only the config + safetensors so the image stays slim.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"
    ]
)
print("Loading Orpheus modelβ¦")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16  # optional: speeds up FP16-like compute
)
model = model.to(device)
# Causal LMs often ship without a pad token; reuse EOS so generate() can pad.
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# βββ HILFSFUNKTIONEN ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def process_prompt(text: str, voice: str):
    """Build the batched ``input_ids`` tensor for ``model.generate``.

    Tokenizes the prompt ``"<voice>: <text>"`` and wraps it in Orpheus'
    special marker tokens: 128259 before, then 128009 and 128260 after.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt").to(device)
    start_marker = torch.tensor([[128259]], device=device)
    end_markers = torch.tensor([[128009, 128260]], device=device)
    return torch.cat((start_marker, encoded.input_ids, end_markers), dim=1)
def parse_output(generated_ids: torch.LongTensor):
    """Extract the audio-code tokens from a generated sequence.

    Keeps only what follows the *last* start marker (128257) in row 0,
    drops every padding token (128258), and returns a plain list of ints.
    """
    start_token, pad_token = 128257, 128258
    marker_cols = (generated_ids == start_token).nonzero(as_tuple=True)[1]
    cropped = generated_ids
    if marker_cols.numel():
        # Crop just past the final marker; everything before it is prompt echo.
        cropped = generated_ids[:, int(marker_cols[-1]) + 1:]
    return [tok for tok in cropped[0].tolist() if tok != pad_token]
def redistribute_codes(code_list: list[int], snac_model: "SNAC"):
    """Decode a flat Orpheus code list into audio via SNAC.

    Each complete 7-token frame is de-interleaved onto SNAC's three
    codebook layers (1/2/4 codes per frame, each de-offset by k*4096),
    then decoded to a float32 numpy waveform.

    Fix: iterate ``len(code_list) // 7`` complete frames. The previous
    bound ``(len(code_list) + 1) // 7`` indexed past the end of the slice
    (IndexError on ``base[6]``) whenever ``len(code_list) % 7 == 6``;
    trailing partial frames are now skipped instead.

    Args:
        code_list: flat token list as produced by ``parse_output``.
        snac_model: loaded SNAC model (string annotation avoids a hard
            import-time dependency on the ``snac`` package).

    Returns:
        1-D float32 numpy array of audio samples (empty if no full frame).
    """
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):  # complete frames only
        frame = code_list[7 * i : 7 * i + 7]
        layer1.append(frame[0])
        layer2.append(frame[1] - 4096)
        layer3.append(frame[2] - 2 * 4096)
        layer3.append(frame[3] - 3 * 4096)
        layer2.append(frame[4] - 4 * 4096)
        layer3.append(frame[5] - 5 * 4096)
        layer3.append(frame[6] - 6 * 4096)
    # Decode on whatever device the SNAC model lives on.
    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac_model.decode(codes)
    return audio.detach().squeeze().cpu().numpy()
# ─── FASTAPI ───────────────────────────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def healthcheck():
    """Liveness probe: report that the service is up."""
    payload = {"status": "ok", "msg": "Hello, Orpheus TTS up!"}
    return payload
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """WebSocket TTS endpoint.

    Repeatedly receives a JSON text message {"text": ..., "voice": ...},
    synthesizes audio with the Orpheus model + SNAC decoder, and streams
    the result back as raw little-endian PCM16 in ~0.1 s chunks.
    Loops until the client disconnects.
    """
    await ws.accept()
    try:
        while True:
            # 1) Parse the incoming JSON message.
            data = json.loads(await ws.receive_text())
            text = data.get("text", "")
            voice = data.get("voice", "Jakob")
            # 2) Prompt -> input_ids
            ids = process_prompt(text, voice)
            # 3) Token generation.
            # NOTE(review): model.generate runs synchronously and blocks the
            # event loop for the whole generation — consider run_in_executor
            # if multiple concurrent clients are expected; confirm intent.
            gen_ids = model.generate(
                input_ids=ids,
                max_new_tokens=2000,  # e.g. 20k also works, but is memory-hungry
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=model.config.eos_token_id,
            )
            # 4) Tokens -> code list -> audio
            codes = parse_output(gen_ids)
            audio_np = redistribute_codes(codes, snac)
            # 5) Stream PCM16 in 0.1 s blocks.
            # Float [-1, 1] audio scaled to int16; 2400 samples * 2 bytes per
            # chunk (presumably 0.1 s at SNAC's 24 kHz rate — verify).
            pcm16 = (audio_np * 32767).astype("int16").tobytes()
            chunk = 2400 * 2
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i : i+chunk])
                await asyncio.sleep(0.1)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        # 1011 = internal error; close may itself fail if the socket is gone.
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
# ─── START ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    # Serve on all interfaces; 7860 is the conventional HF Spaces port.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
|