import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download

# ─── ENV & HF TOKEN ────────────────────────────────────────────────────────────
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
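# (load_dotenv reads the local .env file; the token is only needed when the
# model repo requires authentication)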

# ─── DEVICE ─────────────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"

# ─── SNAC ───────────────────────────────────────────────────────────────────────
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
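# SNAC is the neural audio codec that turns the hierarchical code streams
# produced below back into a waveform (this checkpoint decodes at 24 kHz).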

# ─── ORPHEUS ────────────────────────────────────────────────────────────────────
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# pre-download only the config + safetensors so the image stays slim
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt"
    ]
)

print("Loading Orpheus model…")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16  # optional: speeds up FP16-like computation
)
model = model.to(device)
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)

# ─── HELPER FUNCTIONS ─────────────────────────────────────────────────────────

def process_prompt(text: str, voice: str):
    """
    Baut aus Text+Voice ein batch‑Tensor input_ids fΓΌr `model.generate`.
    """
    prompt = f"{voice}: {text}"
    tok = tokenizer(prompt, return_tensors="pt").to(device)
    start = torch.tensor([[128259]], device=device)
    end   = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, tok.input_ids, end], dim=1)
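
# Resulting layout, as a sketch (128259 / 128009 / 128260 are Orpheus' framing
# tokens, commonly described as start-of-human / end-of-text / end-of-human):
#
#   [128259, <tokens of "Jakob: Hallo Welt">, 128009, 128260]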

def parse_output(generated_ids: torch.LongTensor):
    """
    Schneidet bis zum letzten 128257 und entfernt 128258, gibt reine Token‑Liste zurΓΌck.
    """
    START, PAD = 128257, 128258
    idxs = (generated_ids == START).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item()+1:]
    else:
        cropped = generated_ids
    row = cropped[0][cropped[0] != PAD]
    return row.tolist()
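
# Example with hypothetical IDs: a row [..., 128257, c0, c1, ..., 128258, 128258]
# yields [c0, c1, ...], i.e. only the audio codes after the last start marker,
# with the padding token stripped.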

def redistribute_codes(code_list: list[int], snac_model: SNAC):
    """
    Verteilt 7er‑BlΓΆcke auf die drei SNAC‑Layer und dekodiert zu Audio (numpy float32).
    """
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):  # only complete 7-token frames; (len+1)//7 could index past the end
        base = code_list[7*i : 7*i+7]
        layer1.append(base[0])
        layer2.append(base[1] -   4096)
        layer3.append(base[2] - 2*4096)
        layer3.append(base[3] - 3*4096)
        layer2.append(base[4] - 4*4096)
        layer3.append(base[5] - 5*4096)
        layer3.append(base[6] - 6*4096)
    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac_model.decode(codes)
    return audio.detach().squeeze().cpu().numpy()
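
# Frame layout implemented above: within each 7-token block, positions 0..6 map
# to SNAC layers [1, 2, 3, 3, 2, 3, 3], and position k carries an offset of
# k*4096 so each position occupies its own ID range. For example, the frame
#   [a, b+4096, c+8192, d+12288, e+16384, f+20480, g+24576]
# decodes to layer1=[a], layer2=[b, e], layer3=[c, d, f, g].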

# ─── FASTAPI ────────────────────────────────────────────────────────────────────

app = FastAPI()

@app.get("/")
async def healthcheck():
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}

@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            # 1) Parse the incoming JSON message
            data = json.loads(await ws.receive_text())
            text  = data.get("text", "")
            voice = data.get("voice", "Jakob")

            # 2) Prompt → input_ids
            ids = process_prompt(text, voice)

            # 3) Token generation
            gen_ids = model.generate(
                input_ids=ids,
                max_new_tokens=2000,      # e.g. 20k works too, but gets memory-intensive
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=model.config.eos_token_id,
            )

            # 4) Tokens → code list → audio
            codes    = parse_output(gen_ids)
            audio_np = redistribute_codes(codes, snac)

            # 5) Stream PCM16 in 0.1 s chunks
            pcm16 = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()  # clamp to avoid int16 wrap-around
            chunk = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s, 2 bytes per sample
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i : i+chunk])
                await asyncio.sleep(0.1)

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
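
# A minimal client sketch (assumes the third-party `websockets` package; the
# handler keeps the socket open, so this example simply stops once no chunk
# arrives for a second):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:  # raw PCM16, mono, 24 kHz
#               while True:
#                   try:
#                       f.write(await asyncio.wait_for(ws.recv(), timeout=1.0))
#                   except asyncio.TimeoutError:
#                       break
#
#   asyncio.run(demo())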

# ─── START ──────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")