File size: 5,244 Bytes
a09ea48
 
 
0316ec3
4189fe1
0316ec3
a09ea48
 
0dfc310
0316ec3
0dfc310
a09ea48
 
 
 
2008a3f
d408dd5
1ab029d
0316ec3
0dfc310
d408dd5
a09ea48
0316ec3
674acbf
0dfc310
 
 
 
 
 
 
 
 
f001a32
d408dd5
 
 
 
a09ea48
b3e4aa7
0dfc310
d408dd5
 
 
 
 
 
 
a09ea48
 
b3e4aa7
d408dd5
b3e4aa7
d408dd5
 
a09ea48
ad94d02
d408dd5
 
b3e4aa7
d408dd5
 
 
 
 
 
 
 
 
 
b3e4aa7
 
 
 
 
d408dd5
 
97006e1
d408dd5
4189fe1
 
d408dd5
 
 
 
 
 
a8606ac
a09ea48
 
4189fe1
 
d408dd5
 
 
 
a09ea48
0dfc310
d408dd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3e4aa7
4189fe1
a09ea48
 
 
 
4189fe1
d408dd5
a09ea48
4189fe1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download

# NOTE(review): `asyncio` is imported but not used anywhere in this file.

# — Environment & Hugging Face auth —
# Load .env so HF_TOKEN can be supplied without exporting it; login is
# optional and skipped when no token is present.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# — Select compute device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load models (runs once at import time; may download from the HF hub) —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-fetch only the model weights/config into the local HF cache.
# NOTE(review): tokenizer files are listed in ignore_patterns here, yet
# AutoTokenizer.from_pretrained below will still fetch them on demand —
# presumably intentional (weights prefetched, tokenizer lazy); confirm.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
    ]
)

print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
# Reuse EOS as PAD so generation code does not warn about a missing pad id.
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Special-token constants for the Orpheus vocabulary —
# AUDIO_TOKEN_OFFSET: first vocabulary id of the audio-code range; subtracting
# it maps a sampled token id into SNAC code space.
AUDIO_TOKEN_OFFSET = 128266
START_TOKEN  = 128259
SOS_TOKEN    = 128257
EOS_TOKEN    = 128258
# — Helper functions —
def process_prompt(text: str, voice: str):
    """Build input ids and attention mask for a "<voice>: <text>" prompt.

    The tokenized prompt is wrapped with the START control token in front
    and the two closing control tokens (128009, 128260) behind it.
    Returns (ids, mask), both int64 tensors of shape (1, L) on `device`.
    """
    token_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    full_ids = torch.cat([prefix, token_ids, suffix], dim=1)
    attention = torch.ones_like(full_ids, dtype=torch.int64, device=device)
    return full_ids, attention

def redistribute_codes(block: list[int], snac_model: SNAC):
    """Demultiplex a flat block of 7*k audio codes onto SNAC's three
    codebook layers and decode to a NumPy float32 waveform at 24 kHz
    (exactly the same mapping as before: 7 codes -> 3 layers)."""
    layer_1: list[int] = []
    layer_2: list[int] = []
    layer_3: list[int] = []
    # Walk the block frame by frame; each frame carries 7 codes whose
    # per-position offset (n * 4096) must be removed before decoding.
    for start in range(0, len(block) - 6, 7):
        frame = block[start:start + 7]
        layer_1.append(frame[0])
        layer_2.append(frame[1] - 4096)
        layer_3.append(frame[2] - 2 * 4096)
        layer_3.append(frame[3] - 3 * 4096)
        layer_2.append(frame[4] - 4 * 4096)
        layer_3.append(frame[5] - 5 * 4096)
        layer_3.append(frame[6] - 6 * 4096)
    target = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer, device=target).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    decoded = snac_model.decode(codes)  # -> Tensor[1, T]
    return decoded.squeeze().cpu().numpy()

# — FastAPI setup —
app = FastAPI()

# 1) Hello-world endpoint
@app.get("/")
async def root():
    """Trivial liveness endpoint returning a static greeting."""
    return dict(message="Hallo Welt")

# 2) WebSocket token-by-token TTS
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    Protocol: the client sends a JSON text frame {"text": ..., "voice": ...};
    the server replies with binary frames of raw 16-bit PCM (decoded by SNAC
    at 24 kHz, 7 codes per chunk) and finally a JSON text frame
    {"event": "eos"} when generation for that request is done.
    """
    await ws.accept()
    try:
        while True:
            # Receive one synthesis request as JSON (text + voice).
            raw = await ws.receive_text()
            req = json.loads(raw)
            text, voice = req.get("text", ""), req.get("voice", "Jakob")
            ids, mask = process_prompt(text, voice)

            past_kv = None
            collected = []

            # Sample token by token; the hard cap bounds runaway generation.
            # (The old comment claimed "max 200 Tokens"; the bound is 2000.)
            with torch.no_grad():
                for _ in range(2000):
                    out = model(
                        input_ids=ids,
                        attention_mask=mask,
                        past_key_values=past_kv,
                        use_cache=True,
                    )
                    logits = out.logits[:, -1, :]
                    next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
                    past_kv = out.past_key_values

                    # BUGFIX: with a KV cache the model still needs the newly
                    # sampled token as input on the next step. The previous
                    # code set ids/mask to None after the first pass, so the
                    # second forward call had neither input_ids nor
                    # inputs_embeds and could not continue generating.
                    ids = next_id
                    mask = None  # full mask only needed for the initial prompt

                    token = next_id.item()
                    # End of generation.
                    if token == EOS_TOKEN:
                        break
                    # Start-of-speech: discard anything collected so far.
                    if token == SOS_TOKEN:
                        collected = []
                        continue

                    # Map vocabulary id into SNAC audio-code space.
                    collected.append(token - AUDIO_TOKEN_OFFSET)

                    # As soon as 7 codes are available: decode & stream.
                    if len(collected) >= 7:
                        block = collected[:7]
                        collected = collected[7:]
                        audio_np = redistribute_codes(block, snac)
                        pcm16 = (audio_np * 32767).astype("int16").tobytes()
                        await ws.send_bytes(pcm16)

            # Signal end-of-stream for this request.
            await ws.send_text(json.dumps({"event": "eos"}))

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)

# Entry point for running a local development server.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)