# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'

# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False) # Uncomment if needed to work around a PyTorch 2.2 bug

# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 Codes
# Create AUDIO_IDS on the correct device later in load_models
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)

# 2) Logit‑Mask -------------------------------------------------------
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow = torch.cat([
            torch.tensor([new_block_token_id], device=audio_ids.device), # Add NEW_BLOCK token ID
            audio_ids
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=audio_ids.device) # Store EOS token ID as tensor
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
        self.sent_blocks = 0 # State: Number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Determine which tokens are allowed based on whether blocks have been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow

        # Create a mask initialized to negative infinity
        mask = torch.full_like(scores, float("-inf"))
        # Set allowed token scores to 0 (effectively allowing them)
        mask[:, current_allow] = 0
        # Apply the mask to the scores
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS ----------------------------------------
# generate() needs explicit stopping criteria when using a streamer
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check if the *last* generated token is the EOS token
        if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
            # print("StoppingCriteria: EOS detected.")
            return True
        return False

# 4) Custom AudioStreamer ---------------------------------------------
class AudioStreamer(BaseStreamer):
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask # Reference to the mask to update sent_blocks
        self.loop = loop # Event loop of the main thread for run_coroutine_threadsafe
        self.device = next(snac_decoder.parameters()).device # Get device from the decoder's weights
        self.buf: list[int] = [] # Buffer for audio token values (AUDIO_BASE subtracted)
        self.tasks = set() # Keep track of pending send tasks

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
        NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
              is based on a common interleaving hypothesis. Verify if model docs specify otherwise.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b"" # Return empty bytes if block is incomplete

        # --- Hypothesis: Interleaving mapping ---
        # Assumes 7 tokens map to 3 codebooks like this:
        # Codebook 1 (l1) uses tokens at indices 0, 3, 6
        # Codebook 2 (l2) uses tokens at indices 1, 4
        # Codebook 3 (l3) uses tokens at indices 2, 5
        try:
            l1 = [block7[0], block7[3], block7[6]]
            l2 = [block7[1], block7[4]]
            l3 = [block7[2], block7[5]]
        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""

        # Convert lists to tensors on the correct device
        codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
        codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
        codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
        codes = [codes_l1, codes_l2, codes_l3] # List of tensors for SNAC

        # Decode using SNAC
        with torch.no_grad():
            audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim

        # Squeeze, move to CPU, convert to numpy
        audio_np = audio.squeeze().detach().cpu().numpy()

        # Convert to 16-bit PCM bytes
        audio_bytes = (audio_np * 32767).astype("int16").tobytes()
        return audio_bytes

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over WebSocket."""
        if not data: # Don't send empty bytes
            return
        try:
            await self.ws.send_bytes(data)
            # print(f"Streamer: Sent {len(data)} audio bytes.")
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (Tensor) from generate() (runs in worker thread).
        Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
        """
        # Flatten to a list of ints (.tolist() implicitly moves the tensor to CPU)
        if value.numel() == 0:
            return
        new_token_ids = value.squeeze().tolist()
        if isinstance(new_token_ids, int): # Handle single token case
            new_token_ids = [new_token_ids]

        for t in new_token_ids:
            if t == EOS_TOKEN:
                # print("Streamer: EOS token encountered.")
                # EOS is handled by StoppingCriteria, no action needed here except maybe logging.
                break # Stop processing this batch if EOS is found

            if t == NEW_BLOCK:
                # print("Streamer: NEW_BLOCK token encountered.")
                # NEW_BLOCK indicates the start of audio, might reset buffer if needed
                self.buf.clear()
                continue # Move to the next token

            # Check if token is within the expected audio range
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE) # Store value relative to base
            else:
                # Log unexpected tokens (like START_TOKEN or others if generation goes wrong)
                # print(f"Streamer Warning: Ignoring unexpected token {t}")
                pass # Ignore tokens outside the audio range

            # If buffer has 7 tokens, decode and send
            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear() # Clear buffer after processing

                if audio_bytes: # Only send if decoding was successful
                    # Schedule the async send function to run on the main event loop
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    # Optional: Remove completed tasks to prevent memory leak if generation is very long
                    future.add_done_callback(self.tasks.discard)


                    # Allow EOS only after the first full block has been processed and scheduled for sending
                    if self.masker.sent_blocks == 0:
                        # print("Streamer: First audio block processed, allowing EOS.")
                        self.masker.sent_blocks = 1 # Update state in the mask

        # Note: No need to explicitly wait for tasks here. put() should return quickly.

    def end(self):
        """Called by generate() when generation finishes."""
        # Handle any remaining tokens in the buffer (optional, here we discard them)
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()

        # Optional: Wait briefly for any outstanding send tasks to complete?
        # This is tricky because end() is sync. A robust solution might involve
        # signaling the WebSocket handler to wait before closing.
        # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
        pass

# 5) FastAPI App ------------------------------------------------------
app = FastAPI()

@app.on_event("startup")
async def load_models_startup(): # Make startup async if needed for future async loads
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU

    print(f"🚀 Starting up on device: {device}")
    print("⏳ Lade Modelle …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    # Load SNAC first (usually smaller)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {snac.device}.")

    # Load the main model
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None, # Assign to GPU 0 if cuda
        torch_dtype=torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32, # Use bfloat16 if supported
        low_cpu_mem_usage=True, # Good practice for large models
    )
    model.config.pad_token_id = model.config.eos_token_id # Set pad token
    print(f"Model loaded to {model.device}.")

    # Ensure model is in evaluation mode
    model.eval()

    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    # Initialize StoppingCriteria
    # IMPORTANT: Create the list and add the criteria instance
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("✅ Modelle geladen und bereit!", flush=True)

@app.get("/")
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Prompt-building helper ---------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)

    # Construct input_ids tensor
    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device), # Start token
        prompt_ids,                                  # Encoded prompt
        torch.tensor([[NEW_BLOCK]], device=device)   # New block token to trigger audio
    ], dim=1)

    # Create attention mask (all ones)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

# 7) WebSocket endpoint (simplified via the streamer) -------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    print(" клиент подключился") # Client connected
    streamer = None # Initialize for finally block
    main_loop = asyncio.get_running_loop() # Get the current event loop

    try:
        # Receive configuration
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?") # Default text
        voice = req.get("voice", "Jakob") # Default voice

        if not text:
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")

        # Prepare prompt
        ids, attn = build_prompt(text, voice)

        # --- Reset stateful components ---
        masker.reset() # CRITICAL: Reset the mask state for the new request

        # --- Create Streamer Instance ---
        streamer = AudioStreamer(ws, snac, masker, main_loop)

        # --- Run model.generate in a separate thread ---
        # This prevents blocking the main FastAPI event loop
        print("Starting generation...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=1500, # Limit generation length (adjust as needed)
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            do_sample=False, # Use greedy decoding for potentially more stable audio
            # do_sample=True, temperature=0.7, top_p=0.95, # Or use sampling
            use_cache=True,
            streamer=streamer # Pass the custom streamer
            # No need to manage past_key_values manually
        )
        print("Generation finished.")

    except WebSocketDisconnect:
        print("Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        await ws.close(code=1003, reason="Invalid JSON format") # 1003 = Cannot accept data type
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
        # Try to send an error message before closing, if possible
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                 await ws.send_text(error_payload) # Send error as text/json
        except Exception:
            pass # Ignore error during error reporting
        # Close with internal server error code
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011) # 1011 = Internal Server Error
    finally:
        # Ensure streamer's end method is called if it exists
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                 print(f"Error during streamer.end(): {e_end}")

        # Ensure WebSocket is closed
        print("Closing connection.")
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close(code=1000) # 1000 = Normal Closure
            except RuntimeError as e_close:
                 # Can happen if connection is already closing/closed
                 print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                 print(f"Error closing websocket: {e_close_final}")
        print("Connection closed.")

# 8) Dev‑Start --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    # Use reload=True only for development, remove for production
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info") #, reload=True)
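
# 9) Example client (illustration only, not part of the original service) ----
# A minimal sketch of how a client could talk to the /ws/tts endpoint, assuming
# the third-party 'websockets' package is installed and the server is running on
# the default host/port above. The server streams raw 16-bit PCM (24 kHz, mono).
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   chunk = await ws.recv()  # bytes = audio chunk, str = error payload
#                   if isinstance(chunk, bytes):
#                       pcm += chunk
#           except websockets.exceptions.ConnectionClosed:
#               pass
#           # Play e.g. with: ffplay -f s16le -ar 24000 -ac 1 out.pcm
#           with open("out.pcm", "wb") as f:
#               f.write(pcm)
#
#   asyncio.run(main())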