# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'

# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
actual_eos_token_id = None # Will be determined during startup
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("πŸ”‘ Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment to work around the PyTorch 2.2 flash-SDP bug if needed

# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
# EOS_TOKEN = 128258 # REMOVED - Will be determined from model/tokenizer config
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 codes (7 slots * 4096)
CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
# Create AUDIO_IDS on the correct device later in load_models
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
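# Assumed token layout (not taken from the model card, but consistent with the
# modulo extraction in AudioStreamer._decode_block below): each 7-token frame is
# one SNAC step, with token_id = AUDIO_BASE + slot * CODEBOOK_SIZE + code,
# where slot is 0..6 and code is 0..4095.
# Example: slot 2, code 17 -> 128266 + 2*4096 + 17 = 136475.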

# 2) Logits mask ------------------------------------------------------
# Uses the dynamically determined EOS token ID
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Build Long index tensors on the same device as the audio token ids
        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)

        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
        self.eos = eos_tensor # Store EOS token ID as tensor
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
        self.sent_blocks = 0 # State: Number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Determine which tokens are allowed based on whether blocks have been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow

        # Create a mask initialized to negative infinity
        mask = torch.full_like(scores, float("-inf"))
        # Set allowed token scores to 0 (effectively allowing them)
        mask[:, current_allow] = 0
        # Apply the mask to the scores
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS -----------------------------------------
# Uses the dynamically determined EOS token ID
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id
        if self.eos_token_id is None:
             print("⚠️ EosStoppingCriteria initialized with eos_token_id=None!")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.eos_token_id is None:
             return False # Cannot stop if EOS ID is unknown
        # Check if the *last* generated token is the EOS token
        if input_ids.shape[1] > 0 and bool((input_ids[:, -1] == self.eos_token_id).any()):
            # print("StoppingCriteria: EOS detected.")
            return True
        return False

# 4) Custom AudioStreamer ---------------------------------------------
class AudioStreamer(BaseStreamer):
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask
        self.loop = loop
        self.device = target_device
        self.eos_token_id = eos_token_id # Store EOS ID for potential use in put (optional)
        self.buf: list[int] = []
        self.tasks = set()

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (with AUDIO_BASE already subtracted) into raw audio bytes.
        NOTE: The base code value (0-4095) is extracted with a modulo, assuming each input
              value encodes (slot_offset + code_value). The extracted values are then mapped
              onto the three SNAC codebooks in the layout assumed for Kartoffel_Orpheus.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""

        try:
            # --- Extract the base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
            code_vals = [v % CODEBOOK_SIZE for v in block7]

            # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
            l1 = [code_vals[0]]
            l2 = [code_vals[1], code_vals[4]]
            l3 = [code_vals[2], code_vals[3], code_vals[5], code_vals[6]]

        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""
        except Exception as e_map: # Catch potential issues with modulo/mapping
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""

        # --- Convert lists to tensors on the correct device ---
        try:
            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
            codes = [codes_l1, codes_l2, codes_l3]
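            # For a single 7-token frame, the expected shapes are [1, 1], [1, 2] and [1, 4];
            # snac.decode consumes this list of hierarchical codebook tensors.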
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""

        # --- Decode using SNAC ---
        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e_decode:
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            print(f"Input codes shapes: {[c.shape for c in codes]}")
            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
            print(f"Input codes devices: {[c.device for c in codes]}")
            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
            return b""

        # --- Post-processing ---
        try:
            audio_np = audio.squeeze().detach().cpu().numpy()
            audio_bytes = (audio_np * 32767).astype("int16").tobytes()
            return audio_bytes
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over WebSocket."""
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            # Handle cases where sending fails after connection closed
            if "Cannot call \"send\" once a close message has been sent" in str(e):
                 # This is expected if client disconnects during generation, suppress repetitive logs
                 pass
            else:
                 print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (Tensor) from generate().
        Processes tokens, decodes full blocks, and schedules sending.
        """
        if value.numel() == 0:
            return
        # Ensure value is on CPU and flatten to a list of ints
        new_token_ids = value.squeeze().cpu().tolist()
        if isinstance(new_token_ids, int):
            new_token_ids = [new_token_ids]

        for t in new_token_ids:
            # No need to check for EOS here, StoppingCriteria handles it
            if t == NEW_BLOCK:
                self.buf.clear()
                continue

            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE) # Store value relative to base
            # else: # Optionally log ignored tokens outside audio range
                # if t != self.eos_token_id: # Don't warn about the EOS token itself
                #      print(f"Streamer Warning: Ignoring unexpected token {t}")

            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()

                if audio_bytes:
                    # Schedule the async send function to run on the main event loop
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    future.add_done_callback(self.tasks.discard)

                    # Allow EOS only after the first full block has been processed
                    if self.masker.sent_blocks == 0:
                        self.masker.sent_blocks = 1

    def end(self):
        """Called by generate() when generation finishes."""
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()

# 5) FastAPI App ------------------------------------------------------
app = FastAPI()

@app.on_event("startup")
async def load_models_startup():
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, actual_eos_token_id

    print(f"πŸš€ Starting up on device: {device}")
    print("⏳ Lade Modelle …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")

    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()

    # --- Determine and set the correct EOS token ID ---
    conf_eos = model.config.eos_token_id
    tok_eos = tok.eos_token_id
    print(f"Model Config EOS ID: {conf_eos}")
    print(f"Tokenizer EOS ID: {tok_eos}")

    if conf_eos is not None:
        actual_eos_token_id = conf_eos
    elif tok_eos is not None:
        actual_eos_token_id = tok_eos
        print(f"⚠️ Model config EOS ID is None, using Tokenizer EOS ID: {actual_eos_token_id}")
    else:
        raise ValueError("Could not determine EOS token ID from model config or tokenizer.")

    print(f"Using EOS Token ID: {actual_eos_token_id}")
    # Set pad_token_id to eos_token_id if not already set (common practice for generation)
    if model.config.pad_token_id is None:
         print(f"Setting model.config.pad_token_id to EOS token ID ({actual_eos_token_id})")
         model.config.pad_token_id = actual_eos_token_id
    # --- End EOS Token ID determination ---

    audio_ids_device = AUDIO_IDS_CPU.to(device)
    # Pass the determined EOS ID to the mask
    masker = AudioMask(audio_ids_device, NEW_BLOCK, actual_eos_token_id)
    print("AudioMask initialized.")

    # Pass the determined EOS ID to the stopping criteria
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(actual_eos_token_id)])
    print("StoppingCriteria initialized.")

    print("βœ… Modelle geladen und bereit!", flush=True)

@app.get("/")
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Prompt-building helper --------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)

    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
        prompt_ids,
        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
    ], dim=1)

    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

# 7) WebSocket endpoint (simplified via the streamer) ------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    global actual_eos_token_id # Ensure we can access the determined EOS ID
    await ws.accept()
    print("πŸ”Œ Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()

    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")

        if not text:
            print("⚠️ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
        masker.reset()
        # Pass the determined EOS ID to the streamer as well (optional, for logging/checks)
        streamer = AudioStreamer(ws, snac, masker, main_loop, device, actual_eos_token_id)

        print("Starting generation in background thread...")
        # Use sampling parameters to avoid repetition
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=2500, # Increased slightly, adjust as needed
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            # --- Sampling Parameters ---
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
            # --- End Sampling Parameters ---
            use_cache=True,
            streamer=streamer,
            eos_token_id=actual_eos_token_id # Explicitly pass correct EOS ID here too
        )
        print("Generation thread finished.")

    except WebSocketDisconnect:
        print("πŸ”Œ Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                 await ws.send_text(error_payload)
        except Exception:
            pass
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)
    finally:
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                 print(f"Error during streamer.end(): {e_end}")

        print("Closing connection.")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                 # Suppress "Cannot call 'send'..." error during final close if already disconnected
                 if "Cannot call \"send\"" not in str(e_close):
                      print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                 print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
             print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Dev start --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")