import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — HF token & login —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Choose device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Model parameters —
MODEL_NAME         = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_MARKER       = 128259        # <|startoftranscript|>
RESTART_MARKER     = 128257        # <|startoftranscript_again|>
EOS_TOKEN          = 128258        # <|endoftranscript|>
AUDIO_TOKEN_OFFSET = 128266        # offset for mapping token IDs back to audio codes
BLOCK_TOKENS       = 7             # SNAC expects 7 audio tokens per block
CHUNK_TOKENS       = 50            # number of new tokens per generate round

# — Instantiate FastAPI —
app = FastAPI()

# — So that GET / does not return a 404 —
@app.get("/")
async def read_root():
    return {"message": "Orpheus TTS server is live 🎙️"}

# — Load models at startup —
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac
    # Load SNAC (the neural audio codec that turns audio tokens into waveform)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # TTS language model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device=="cuda" else None,
        low_cpu_mem_usage=True
    )
    model.config.pad_token_id = EOS_TOKEN

# — Prepare the model input —
def prepare_inputs(text: str, voice: str):
    prompt     = f"{voice}: {text}"
    input_ids  = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start      = torch.tensor([[START_MARKER]], device=device)
    end        = torch.tensor([[128009, EOS_TOKEN]], device=device)
    ids        = torch.cat([start, input_ids, end], dim=1)
    attn_mask  = torch.ones_like(ids)
    return ids, attn_mask
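
# For illustration: prepare_inputs("Hallo", "Jakob") builds a sequence of the
# form [START_MARKER] + tokenizer("Jakob: Hallo") + [128009, EOS_TOKEN],
# i.e. start marker, prompt tokens, then the closing token pair.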

# — Decode a block of 7 audio tokens into one PCM chunk —
def decode_block(block: list[int]) -> bytes:
    # SNAC takes three codebook layers per frame: 1 code for layer 1,
    # 2 codes for layer 2 and 4 codes for layer 3. Each of the 7 token
    # positions additionally carries an offset of position * 4096 that
    # has to be subtracted before decoding.
    l1, l2, l3 = [], [], []
    b = block
    l1.append(b[0])
    l2.append(b[1] -   4096)
    l3.append(b[2] - 2*4096)
    l3.append(b[3] - 3*4096)
    l2.append(b[4] - 4*4096)
    l3.append(b[5] - 5*4096)
    l3.append(b[6] - 6*4096)
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
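
# The stream produced from these blocks is raw PCM: 16-bit signed samples,
# mono, at SNAC's 24 kHz output rate. Assuming a little-endian host, a
# captured stream can be played back e.g. with:
#   ffplay -f s16le -ar 24000 -ac 1 out.pcm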

# — Generator: produce tokens in small chunks and decode them block by block —
async def generate_and_stream(ws: WebSocket, ids, attn_mask):
    buffer: list[int] = []
    past_kvs = None

    while True:
        # Call model.generate in small chunks, reusing the KV cache from
        # the previous round so only the new tokens are computed.
        # NOTE: generate() is synchronous and blocks the event loop while
        # it runs; for concurrent clients it could be off-loaded to a thread.
        outputs = model.generate(
            input_ids      = ids,
            attention_mask = attn_mask,
            past_key_values= past_kvs,
            use_cache      = True,
            max_new_tokens = CHUNK_TOKENS,
            do_sample      = True,
            temperature    = 0.7,
            top_p          = 0.95,
            repetition_penalty = 1.1,
            eos_token_id   = EOS_TOKEN,
            pad_token_id   = EOS_TOKEN,
            return_dict_in_generate = True,
            output_scores           = False,
        )

        # Keep the cache for the next round
        past_kvs = outputs.past_key_values

        # Only the tokens beyond the previous sequence length are new
        seq        = outputs.sequences[0]
        new_tokens = seq[ids.shape[-1]:].tolist()

        # Feed the full sequence back in the next round; generate() skips
        # the prefix that is already covered by past_key_values
        ids       = outputs.sequences
        attn_mask = torch.ones_like(ids)

        if not new_tokens:
            # Nothing new was generated; stop instead of looping forever
            return

        for tok in new_tokens:
            # A repeated START marker resets the audio token buffer
            if tok == RESTART_MARKER:
                buffer = []
                continue
            # End of generation; any partial block (< 7 tokens) left in
            # the buffer is dropped, truncating at most one SNAC frame
            if tok == EOS_TOKEN:
                return
            # Map the token ID back to an audio code
            buffer.append(tok - AUDIO_TOKEN_OFFSET)
            # Once 7 audio codes are buffered, decode and stream them
            if len(buffer) >= BLOCK_TOKENS:
                block  = buffer[:BLOCK_TOKENS]
                buffer = buffer[BLOCK_TOKENS:]
                pcm    = decode_block(block)
                await ws.send_bytes(pcm)

# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        data = await ws.receive_text()
        req  = json.loads(data)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn_mask = prepare_inputs(text, voice)
        await generate_and_stream(ws, ids, attn_mask)

        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
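
# A minimal example client (a sketch, not part of the server): it assumes the
# server is reachable at ws://localhost:7860 and that the third-party
# "websockets" package is installed; it writes the streamed PCM to out.pcm.
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo, wie geht es dir?",
#                                     "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               while True:
#                   try:
#                       msg = await ws.recv()
#                   except websockets.ConnectionClosed:
#                       break
#                   if isinstance(msg, bytes):
#                       f.write(msg)
#
#   asyncio.run(demo())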