File size: 5,200 Bytes
2c15189
 
 
0316ec3
4189fe1
9bf14d0
a09ea48
 
0316ec3
9bf14d0
2c15189
 
9bf14d0
2008a3f
c70d8eb
1ab029d
0316ec3
9bf14d0
 
0dfc310
c70d8eb
9bf14d0
 
 
 
c70d8eb
9bf14d0
 
 
 
 
 
 
 
 
 
c70d8eb
 
9bf14d0
c70d8eb
9bf14d0
3281189
c70d8eb
9bf14d0
 
 
 
 
 
c70d8eb
9bf14d0
 
c70d8eb
 
 
 
 
 
 
 
 
 
 
9bf14d0
c70d8eb
 
 
9bf14d0
 
 
 
c70d8eb
a8606ac
2c15189
a09ea48
4189fe1
c70d8eb
9bf14d0
 
 
 
 
c70d8eb
9bf14d0
 
c70d8eb
 
 
 
 
 
 
 
9bf14d0
2c15189
c70d8eb
 
9bf14d0
c70d8eb
 
 
 
 
 
d4630a2
c70d8eb
 
 
2c15189
c70d8eb
9bf14d0
c70d8eb
 
9bf14d0
c70d8eb
 
 
9bf14d0
c70d8eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf14d0
2c15189
c70d8eb
 
 
 
 
4189fe1
c70d8eb
a09ea48
2c15189
a09ea48
c70d8eb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — Hugging Face token & login —
# Authenticate only when a token is provided; public models load without one.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Select compute device (GPU when available, otherwise CPU) —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Instantiate the FastAPI application —
app = FastAPI()

# — Hello‑Route, damit GET / nicht 404 gibt —
# — Root route so that GET / does not return a 404 —
@app.get("/")
async def read_root():
    """Simple liveness endpoint for the service root."""
    payload = {"message": "Hello, world!"}
    return payload

# — Modelle beim Startup laden —
# — Load models at startup —
@app.on_event("startup")
async def load_models():
    """Load the SNAC vocoder and the Orpheus TTS model into module globals.

    Populates ``tokenizer``, ``model`` and ``snac`` used by the other
    handlers in this file.
    """
    global tokenizer, model, snac
    # SNAC vocoder (24 kHz codec used by decode_block)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # TTS model
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # BUG FIX: with device_map="auto" (CUDA path) accelerate has already
    # dispatched the weights across devices; calling .to(device) on such a
    # model raises/warns. Only move the model manually on the CPU path,
    # where no device_map was used.
    if device != "cuda":
        model = model.to(device)
    # Use EOS as pad so generation with sampling does not warn about padding.
    model.config.pad_token_id = model.config.eos_token_id

# — Input‑Vorbereitung —
# — Input preparation —
def prepare_inputs(text: str, voice: str):
    """Build model inputs for the prompt ``"<voice>: <text>"``.

    Wraps the tokenized prompt in the Orpheus start (128259) and
    end (128009, 128260) marker tokens and returns
    ``(input_ids, attention_mask)`` on the module-level device.
    """
    token_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    start_marker = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end_markers = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    full_ids = torch.cat([start_marker, token_ids, end_markers], dim=1)
    attention = torch.ones_like(full_ids, device=device)
    return full_ids, attention

# — SNAC‑Dekodierung eines 7‑Token‑Blocks →
# — SNAC decoding of one 7-token block —
def decode_block(tokens: list[int]) -> bytes:
    """Decode one 7-token SNAC frame into raw 16-bit PCM bytes.

    Each of the 7 codes carries a position-dependent offset of k*4096
    (k = index within the frame); subtracting it recovers the raw
    codebook id. Codes are distributed over the three SNAC layers as
    1 / 2 / 4 entries and decoded via the module-level ``snac`` model.
    """
    layer_1 = [tokens[0]]
    layer_2 = [tokens[1] - 4096, tokens[4] - 4 * 4096]
    layer_3 = [
        tokens[2] - 2 * 4096,
        tokens[3] - 3 * 4096,
        tokens[5] - 5 * 4096,
        tokens[6] - 6 * 4096,
    ]
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    waveform = snac.decode(codes).squeeze().cpu().numpy()
    # Scale [-1, 1] float audio to int16 PCM.
    return (waveform * 32767).astype("int16").tobytes()

# — WebSocket‑Endpoint mit Chunked‑Generate (max_new_tokens=50) —
# — WebSocket endpoint with chunked generate (max_new_tokens=50) —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream TTS audio over a WebSocket.

    Expects one JSON text message ``{"text": ..., "voice": ...}``, then
    generates audio tokens in chunks of 50 and streams each decoded
    7-token SNAC frame back as raw int16 PCM bytes.
    """
    await ws.accept()
    try:
        # 1) Read the request
        msg = await ws.receive_text()
        req = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # 2) Build inputs
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer_codes: list[int] = []

        # 3) Chunked-generate loop
        chunk_size = 50
        eos_id     = model.config.eos_token_id

        # Track the length generated so far to isolate what is new,
        # and whether EOS has been produced.
        prev_len = 0
        eos_seen = False

        while not eos_seen:
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=chunk_size,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=eos_id,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=False,
                past_key_values=past_kvs
            )
            # Update cache and sequences
            past_kvs = out.past_key_values
            seqs     = out.sequences  # (1, total_length)
            total_len = seqs.shape[1]

            # 4) Extract the newly generated tokens
            new_tokens = seqs[0, prev_len:total_len].tolist()
            prev_len = total_len

            # 5) Process each new token
            for tok in new_tokens:
                if tok == eos_id:
                    # BUG FIX: the original rebound new_tokens = [] here and
                    # later tested `eos_id in new_tokens`, which could never
                    # be true -> infinite loop. Use an explicit flag instead.
                    eos_seen = True
                    break
                if tok == 128257:
                    # presumably an audio-start/reset marker — TODO confirm
                    buffer_codes.clear()
                    continue
                # Remove the global audio-token offset and buffer the code.
                buffer_codes.append(tok - 128266)
                # Once 7 codes are collected, decode and stream the frame.
                if len(buffer_codes) >= 7:
                    block = buffer_codes[:7]
                    buffer_codes = buffer_codes[7:]
                    pcm = decode_block(block)
                    await ws.send_bytes(pcm)

            # BUG FIX: the original passed input_ids=None together with
            # past_key_values on later iterations, which generate() cannot
            # resume from. Feed the full sequence back in; generate() crops
            # it internally against the cache length.
            input_ids = seqs
            attention_mask = torch.ones_like(seqs, device=seqs.device)

        # 6) Close cleanly when done
        await ws.close()
    except WebSocketDisconnect:
        return
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)

# — Main für lokalen Test —
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)