import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — HF token & login —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Choose device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Instantiate FastAPI —
app = FastAPI()

# — Hello route so that GET / does not return a 404 —
@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}

# — Load models on startup —
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True
    )
    # Fall back to the EOS token for padding
    model.config.pad_token_id = model.config.eos_token_id

# — Helper functions —
START_TOKEN       = 128259
END_TOKENS        = [128009, 128260]
RESET_TOKEN       = 128257
AUDIO_OFFSET      = 128266
EOS_TOKEN         = 128258  # audio EOS; the model is only loaded in the startup
                            # handler, so model.config.eos_token_id is not
                            # available at import time
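# Note on the IDs above (an assumption based on the usual Orpheus-style token
# layout, not verified against the model card): START_TOKEN and END_TOKENS wrap
# the text prompt, RESET_TOKEN marks the start of a fresh audio segment, and
# AUDIO_OFFSET is subtracted from generated token IDs to recover raw SNAC codes.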

def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], device=device)
    end   = torch.tensor([END_TOKENS], device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

def decode_block(block: list[int]):
    # build one PCM byte chunk from exactly 7 audio codes
    l1 = [block[0]]
    l2 = [block[1] - 1 * 4096, block[4] - 4 * 4096]
    l3 = [block[2] - 2 * 4096, block[3] - 3 * 4096,
          block[5] - 5 * 4096, block[6] - 6 * 4096]
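    # The 1/2/4 split mirrors SNAC's three hierarchical codebooks (an assumption
    # based on the usual Orpheus/SNAC frame layout): one coarse, two medium and
    # four fine codes per 7-token frame, each shifted by a multiple of 4096.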
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()

# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req   = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        collected = []

        # token by token with a custom sampling loop
        while True:
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_kvs,
                use_cache=True,
            )
            logits   = out.logits[:, -1, :]
            past_kvs = out.past_key_values

            # sampling
            probs = torch.softmax(logits, dim=-1)
            nxt   = torch.multinomial(probs, num_samples=1).item()

            # from the second step on, only the freshly sampled token is fed in;
            # the KV cache already holds the rest of the context
            input_ids = torch.tensor([[nxt]], device=device)
            attention_mask = None

            # EOS → done
            if nxt == EOS_TOKEN:
                break
            # RESET → discard the codes collected so far
            if nxt == RESET_TOKEN:
                collected = []
                continue

            # remove the audio offset & collect the code
            collected.append(nxt - AUDIO_OFFSET)
            # every 7 codes → decode & stream
            if len(collected) == 7:
                pcm = decode_block(collected)
                collected = []
                await ws.send_bytes(pcm)

        # close the stream cleanly
        await ws.close()

    except WebSocketDisconnect:
        # client disconnected → nothing to do
        pass

    except Exception as e:
        # on any other error, close with code 1011 (internal error)
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
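
# — Local run & quick test — a minimal sketch: it assumes uvicorn and the
# `websockets` package are installed and that port 8000 is free.
#
#     import asyncio, json, websockets
#
#     async def demo():
#         async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#             await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#             with open("out.pcm", "wb") as f:   # raw int16 mono PCM @ 24 kHz
#                 try:
#                     while True:
#                         f.write(await ws.recv())
#                 except websockets.ConnectionClosed:
#                     pass
#
#     asyncio.run(demo())

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)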