File size: 5,266 Bytes
2c15189
 
 
0316ec3
4189fe1
d4630a2
0316ec3
a09ea48
 
d4630a2
0316ec3
67c3132
a09ea48
2c15189
 
d4630a2
 
2008a3f
d4630a2
 
 
 
 
 
 
 
1ab029d
0316ec3
d4630a2
 
a09ea48
0316ec3
d4630a2
674acbf
d4630a2
 
 
 
 
0dfc310
d4630a2
 
9cd424e
d4630a2
 
 
0dfc310
d4630a2
 
 
3281189
67c3132
d4630a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c15189
67c3132
d4630a2
2c15189
d4630a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8606ac
2c15189
a09ea48
4189fe1
2c15189
d4630a2
 
 
 
2c15189
d4630a2
 
2c15189
d4630a2
 
2c15189
67c3132
d4630a2
2c15189
 
 
 
d4630a2
 
 
2c15189
 
d4630a2
 
 
2c15189
d4630a2
 
 
2c15189
4189fe1
2c15189
a09ea48
2c15189
a09ea48
f3890ef
d4630a2
2c15189
f3890ef
67c3132
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# — ENV & HF auth —
# Load variables from a local .env file, then propagate the Hugging Face
# token into the environment so huggingface_hub picks it up automatically
# (equivalent to being logged in via huggingface-cli).
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN

# — FastAPI application instance —
app = FastAPI()

@app.get("/")
async def hello():
    """Health-check endpoint: respond with a plain-text greeting."""
    greeting = PlainTextResponse("Hallo Welt!")
    return greeting

# — Select compute device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load SNAC (neural audio codec that turns audio tokens into waveforms) —
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# — Load the Orpheus/Kartoffel-3B language model via PEFT —
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print(f"Loading base LM + PEFT from {model_name}…")
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # let accelerate place layers across available devices
    torch_dtype=torch.bfloat16,
)
# NOTE(review): the adapter is loaded from the same repo id as the base
# model — confirm that repo actually ships PEFT adapter weights.
model = PeftModel.from_pretrained(
    base,
    model_name,
    device_map="auto",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure pad_token_id is set so generate() can pad without warnings/errors.
model.config.pad_token_id = model.config.eos_token_id

# — Hilfsfunktionen —
def prepare_prompt(text: str, voice: str):
    """Wrap the prompt text with the model's start/end marker tokens.

    Returns a ``(input_ids, attention_mask)`` pair on the target device.
    """
    prompt_text = f"{voice}: {text}" if voice else text
    start_marker = torch.tensor([[128259]], dtype=torch.int64)  # audio BOS marker
    end_marker = torch.tensor([[128009, 128260]], dtype=torch.int64)  # speaker id + end marker
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    input_ids = torch.cat([start_marker, token_ids, end_marker], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)
    return input_ids, attention_mask

def extract_audio_tokens(generated: torch.LongTensor):
    """Extract rebased audio codes from a generated token sequence.

    Keeps everything after the *last* audio-start token (128257), removes
    every EOS token (128258), truncates to a multiple of 7 (one SNAC frame
    is 7 codes), and subtracts the audio-token id offset 128266.
    """
    AUDIO_BOS = 128257
    AUDIO_EOS = 128258
    AUDIO_OFFSET = 128266

    # Locate the last start marker and keep only what follows it.
    start_positions = (generated == AUDIO_BOS).nonzero(as_tuple=True)[1]
    if start_positions.numel() > 0:
        audio_part = generated[:, start_positions[-1].item() + 1:]
    else:
        audio_part = generated

    # Drop EOS markers from the (single-row) sequence.
    codes = audio_part[0]
    codes = codes[codes != AUDIO_EOS]

    # Only whole 7-code frames are decodable.
    usable = codes.size(0) - codes.size(0) % 7
    return [int(c) - AUDIO_OFFSET for c in codes[:usable]]

async def decode_and_stream(tokens: list[int], ws: WebSocket):
    """Decode 7-code frames into PCM16 audio via SNAC and stream it.

    Args:
        tokens: flat list of rebased audio codes (length a multiple of 7).
        ws: an accepted WebSocket; raw PCM16 bytes are sent in 0.1 s chunks.

    BUG FIX: this was a plain ``def`` that called the coroutines
    ``ws.send_bytes`` and ``asyncio.sleep`` without ``await`` — so no audio
    was ever sent and no pause ever happened — while the caller awaited its
    ``None`` return (a TypeError). It is now a proper coroutine.
    """
    pcm16 = bytearray()
    offset = 0
    while offset + 7 <= len(tokens):
        block = tokens[offset:offset + 7]
        offset += 7

        # Redistribute the 7 interleaved codes onto SNAC's three codebook
        # layers; each slot carries a fixed multiple-of-4096 offset.
        l1 = [block[0]]
        l2 = [block[1] - 4096, block[4] - 4 * 4096]
        l3 = [
            block[2] - 2 * 4096,
            block[3] - 3 * 4096,
            block[5] - 5 * 4096,
            block[6] - 6 * 4096,
        ]

        t1 = torch.tensor(l1, device=device).unsqueeze(0)
        t2 = torch.tensor(l2, device=device).unsqueeze(0)
        t3 = torch.tensor(l3, device=device).unsqueeze(0)
        audio = snac.decode([t1, t2, t3]).squeeze().cpu().numpy()

        # Convert float samples to PCM16 @ 24 kHz.
        pcm16.extend((audio * 32767).astype("int16").tobytes())

    # Stream in 0.1 s chunks (2400 samples × 2 bytes at 24 kHz mono).
    chunk_size = 2400 * 2
    for i in range(0, len(pcm16), chunk_size):
        await ws.send_bytes(pcm16[i : i + chunk_size])
        # Throttle so the WebSocket is not flooded.
        await asyncio.sleep(0.1)

# — WebSocket TTS Endpoint —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Single-shot TTS over WebSocket.

    Expects one JSON text message ``{"text": ..., "voice": ...}``, generates
    audio tokens with the LM, streams the decoded PCM16 audio back, then
    closes the connection with code 1000.
    """
    await ws.accept()
    try:
        while True:
            raw = await ws.receive_text()
            req = json.loads(raw)
            text  = req.get("text", "")
            voice = req.get("voice", "")

            # Build input_ids/attention_mask with the audio start/end markers.
            ids, mask = prepare_prompt(text, voice)

            # Generate audio tokens (sampling; stops at the audio EOS token 128258).
            gen = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=4000,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
                forced_bos_token_id=128259,
                use_cache=True,
            )

            codes = extract_audio_tokens(gen)
            # NOTE(review): decode_and_stream is defined as a plain (sync)
            # function above; awaiting its None return will raise TypeError —
            # it should be declared async. Confirm and fix.
            await decode_and_stream(codes, ws)

            # Close cleanly after serving the single request.
            await ws.close(code=1000)
            break

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        # NOTE(review): if an error occurs after the ws.close() above, this
        # second close() may itself raise — consider guarding it.
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)

# — Run locally —
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): "app:app" assumes this file is named app.py — confirm.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)