# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'

# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("πŸ”‘ Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment to work around a PyTorch 2.2 flash-SDP bug if needed

# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50  # Not needed with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 codes
CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
# AUDIO_IDS is moved to the target device later, in load_models_startup
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
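# Token layout (as implied by the modulo extraction in _decode_block): each
# audio token encodes AUDIO_BASE + slot * CODEBOOK_SIZE + code, where slot is
# the position (0-6) within a 7-token frame and code is the SNAC codebook entry.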

# 2) Logits mask ------------------------------------------------------
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow = torch.cat([
            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
            audio_ids
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
        self.sent_blocks = 0 # State: Number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
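        # EOS is only allowed once at least one audio block has been streamed,
        # so generation cannot stop before producing any audio.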
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
        mask = torch.full_like(scores, float("-inf"))
        mask[:, current_allow] = 0
        return scores + mask

    def reset(self):
        self.sent_blocks = 0

# 3) Stopping criteria for EOS ---------------------------------------
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the last generated token is EOS (batch size is 1 here).
        return input_ids.shape[1] > 0 and input_ids[0, -1].item() == self.eos_token_id

# 4) Custom AudioStreamer ---------------------------------------------
class AudioStreamer(BaseStreamer):
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask
        self.loop = loop
        self.device = target_device
        self.buf: list[int] = []
        self.tasks = set()

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
        NOTE: Extracts base code value (0-4095) using modulo, assuming
              input values represent (slot_offset + code_value).
              Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""

        try:
            # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
            code_val_0 = block7[0] % CODEBOOK_SIZE
            code_val_1 = block7[1] % CODEBOOK_SIZE
            code_val_2 = block7[2] % CODEBOOK_SIZE
            code_val_3 = block7[3] % CODEBOOK_SIZE
            code_val_4 = block7[4] % CODEBOOK_SIZE
            code_val_5 = block7[5] % CODEBOOK_SIZE
            code_val_6 = block7[6] % CODEBOOK_SIZE

            # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
            # Interleaving assumed for this model family: slot 0 -> coarse (l1),
            # slots 1/4 -> medium (l2), slots 2/3/5/6 -> fine (l3).
            l1 = [code_val_0]
            l2 = [code_val_1, code_val_4]
            l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
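            # i.e. each 7-token frame carries 1 coarse (l1), 2 medium (l2)
            # and 4 fine (l3) codes.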

        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""
        except Exception as e_map: # Catch potential issues with modulo/mapping
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""

        # --- Convert lists to tensors on the correct device ---
        try:
            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
            codes = [codes_l1, codes_l2, codes_l3]
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""

        # --- Decode using SNAC ---
        try:
            with torch.no_grad():
                # self.snac should already be on self.device from load_models_startup
                audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
        except Exception as e_decode:
            # Add more detailed logging here if it fails again
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            print(f"Input codes shapes: {[c.shape for c in codes]}")
            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
            print(f"Input codes devices: {[c.device for c in codes]}")
            # Avoid printing potentially huge lists, maybe just check min/max?
            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
            return b""

        # --- Post-processing ---
        try:
            audio_np = audio.squeeze().detach().cpu().numpy()
            # Clip to [-1, 1] before the int16 conversion to avoid integer wrap-around.
            audio_bytes = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
            return audio_bytes
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over WebSocket."""
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (Tensor) from generate().
        Processes tokens, decodes full blocks, and schedules sending.
        """
        if value.numel() == 0:
            return
        new_token_ids = value.squeeze().tolist()
        if isinstance(new_token_ids, int):
            new_token_ids = [new_token_ids]

        for t in new_token_ids:
            if t == EOS_TOKEN:
                break
            if t == NEW_BLOCK:
                self.buf.clear()
                continue
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE) # Store value relative to base
            # else: # Optionally log ignored tokens
                # print(f"Streamer Warning: Ignoring unexpected token {t}")

            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()

                if audio_bytes:
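                    # generate() runs in a worker thread (asyncio.to_thread in the
                    # endpoint), so schedule the send on the main loop thread-safely.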
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    future.add_done_callback(self.tasks.discard)

                    if self.masker.sent_blocks == 0:
                        self.masker.sent_blocks = 1

    def end(self):
        """Called by generate() when generation finishes."""
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()
        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")

# 5) FastAPI App ------------------------------------------------------
app = FastAPI()

@app.on_event("startup")
async def load_models_startup():
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU

    print(f"πŸš€ Starting up on device: {device}")
    print("⏳ Lade Modelle …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.") # Use the global device variable

    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = model.config.eos_token_id
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()

    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("βœ… Modelle geladen und bereit!", flush=True)

@app.get("/")
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Prompt-building helper ---------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)

    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
        prompt_ids,
        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
    ], dim=1)

    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

# 7) WebSocket endpoint (simplified via the streamer) ------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    print("πŸ”Œ Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()

    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")

        if not text:
            print("⚠️ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
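        # Note: masker (and its sent_blocks state) is a module-level singleton,
        # so concurrent connections share it; fine for a single client at a time.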
        masker.reset()
        streamer = AudioStreamer(ws, snac, masker, main_loop, device)

        print("Starting generation in background thread...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=1500,
            logits_processor=LogitsProcessorList([masker]),
            stopping_criteria=stopping_criteria,
            do_sample=False, # Using greedy decoding
            use_cache=True,
            streamer=streamer
        )
        print("Generation thread finished.")

    except WebSocketDisconnect:
        print("πŸ”Œ Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                 await ws.send_text(error_payload)
        except Exception:
            pass
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)
    finally:
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                 print(f"Error during streamer.end(): {e_end}")

        print("Closing connection.")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                 print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                 print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
             print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Dev start --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")