File size: 9,013 Bytes
0b5b901
dbb4a9f
 
 
 
 
 
4189fe1
9bf14d0
dbb4a9f
 
 
 
 
 
 
 
 
 
 
 
0316ec3
e3958ab
479f253
 
dbb4a9f
641d199
479f253
2008a3f
dbb4a9f
e3958ab
 
dbb4a9f
 
641d199
 
 
 
dbb4a9f
 
 
 
e3958ab
641d199
 
 
 
 
dbb4a9f
479f253
dbb4a9f
641d199
 
e3958ab
dbb4a9f
641d199
 
 
dbb4a9f
e3958ab
dbb4a9f
 
641d199
dbb4a9f
 
a0cc672
dbb4a9f
 
 
a0cc672
e3958ab
dbb4a9f
 
 
0dfc310
dbb4a9f
 
 
 
 
 
 
 
641d199
dbb4a9f
641d199
dbb4a9f
 
 
 
 
641d199
 
 
 
 
 
dbb4a9f
 
 
 
641d199
 
dbb4a9f
 
 
 
 
 
 
641d199
 
 
dbb4a9f
 
 
 
 
641d199
dbb4a9f
641d199
 
 
 
 
 
 
 
 
dbb4a9f
 
 
641d199
 
 
 
 
 
 
 
dbb4a9f
 
641d199
 
 
 
 
 
 
 
dbb4a9f
 
 
 
 
 
 
 
 
 
 
 
 
 
641d199
dbb4a9f
 
 
 
 
 
 
 
 
 
 
 
 
641d199
 
 
 
 
 
dbb4a9f
 
 
 
 
 
 
641d199
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'

# --- Globals (populated in load_models) ---
# Model-related handles; each stays None until load_models() assigns it.
tok = model = snac = masker = stopping_criteria = None
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 0) Login + Device ---------------------------------------------------
# Authenticate against the Hugging Face Hub only when a token is configured.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    # TODO: consider handling login failures explicitly if that becomes necessary.
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False) # Uncomment if needed for PyTorch‑2.2‑Bug

# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Special control tokens.
START_TOKEN = 128259
NEW_BLOCK = 128257  # Token indicating start of audio generation
EOS_TOKEN = 128258  # End Of Speech token

# Audio tokens form one contiguous id range starting at AUDIO_BASE.
AUDIO_BASE = 128266
# Each audio frame is 7 tokens; each token position has 4096 possible codes,
# giving 7 * 4096 = 28672 audio token ids.
AUDIO_SPAN = 7 * 4096
# CPU copy of all audio token ids; a device copy is created later in load_models.
AUDIO_IDS_CPU = torch.arange(AUDIO_SPAN) + AUDIO_BASE

# 2) Logit‑Mask -------------------------------------------------------
class AudioMask(LogitsProcessor):
    """
    Logits processor that restricts generation to audio output.

    While ``sent_blocks`` is not positive, only NEW_BLOCK and the audio token ids
    may be sampled. Once it is positive, the EOS token is allowed as well.
    ``sent_blocks`` is never advanced here: the streaming side holds a reference
    to this object and updates it, and ``reset()`` clears it between requests.
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        dev = audio_ids.device
        new_block = torch.tensor([new_block_token_id], device=dev)
        # Ids allowed before any block has been sent: NEW_BLOCK + every audio token.
        self.allow_initial = torch.cat([new_block, audio_ids], dim=0)
        self.eos = torch.tensor([eos_token_id], device=dev)
        # The same set plus EOS, precomputed so __call__ only has to choose one.
        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
        # Audio blocks sent so far for the current request.
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with every id outside the currently allowed set pushed to -inf."""
        if self.sent_blocks > 0:
            allowed = self.allow_with_eos
        else:
            allowed = self.allow_initial

        # Additive mask: 0 keeps a logit unchanged, -inf removes it from sampling.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, allowed] = 0
        return scores + penalty

    def reset(self):
        """Clear the per-request block counter before a new generation starts."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS ---------------------------------------
# generate() needs explicit stopping criteria when using a streamer
class EosStoppingCriteria(StoppingCriteria):
    """
    Stop generation for each sequence whose most recent token is the EOS token.

    Returns a per-row boolean tensor rather than a single Python bool. The old
    ``if tensor_of_rows:`` check raised "Boolean value of Tensor ... is ambiguous"
    for batch sizes > 1. A BoolTensor is also the return type documented for
    ``StoppingCriteria`` in current ``transformers``.
    """

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """
        Args:
            input_ids: Token ids so far, indexed as (batch, seq_len).
            scores: Unused; part of the StoppingCriteria interface.

        Returns:
            Bool tensor of shape (batch,), True where the last token is EOS.
        """
        if input_ids.shape[1] == 0:
            # No tokens at all yet, so no row can have finished.
            return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

        is_done = input_ids[:, -1] == self.eos_token_id
        if is_done.any():
            print("StoppingCriteria: EOS detected.")
        return is_done

# 4) Custom AudioStreamer -------------------------------
class AudioStreamer(BaseStreamer):
    """
    Custom streamer to process audio tokens, decode them using SNAC,
    and send audio bytes over a WebSocket.
    """
    # Added target_device parameter
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
        """
        Bind the streamer to one WebSocket connection and its generation state.

        Args:
            ws: Connection that decoded audio bytes are sent over.
            snac_decoder: SNAC model used to decode audio code blocks.
            audio_mask: Mask whose ``sent_blocks`` counter this streamer updates.
            loop: Main-thread event loop, used with run_coroutine_threadsafe.
            target_device: Device on which code tensors for SNAC are created.
        """
        # Collaborators: output socket, decoder and the logits mask.
        self.ws, self.snac, self.masker = ws, snac_decoder, audio_mask
        # Sends are scheduled onto this loop from the generate() worker thread.
        self.loop = loop
        self.device = target_device
        # Audio token values (AUDIO_BASE subtracted) waiting for a full block.
        self.buf: list[int] = []
        # Pending send tasks.
        self.tasks = set()

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decode one frame of 7 audio token values into 16-bit PCM bytes via SNAC.

        The 7 values (AUDIO_BASE already subtracted) are split across the three
        SNAC codebooks as l1 = [0], l2 = [1, 4], l3 = [2, 3, 5, 6], i.e. 1, 2 and 4
        codes per frame. This mapping is model-specific.

        Args:
            block7: The 7 code values of one frame. Each value must lie in
                [0, 4096). Any per-position offsets must already have been
                removed by the caller (NOTE(review): confirm against ``put``).

        Returns:
            Native-endian int16 PCM bytes, or b"" when the block is malformed,
            out of range, or decoding fails.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""

        # Validate on the host before touching the device. An out-of-range index in
        # SNAC's codebook lookup triggers a CUDA device-side assert, which leaves the
        # CUDA context unusable for the rest of the process.
        codebook_size = 4096  # codes per book, as in AUDIO_SPAN = 4096 * 7
        if any(not 0 <= code < codebook_size for code in block7):
            print(f"Streamer Warning: audio code out of range [0, {codebook_size}) in block {block7}. Skipping.")
            return b""

        # Token positions -> SNAC codebooks (1, 2 and 4 codes per frame).
        l1 = [block7[0]]
        l2 = [block7[1], block7[4]]
        l3 = [block7[2], block7[3], block7[5], block7[6]]

        try:
            # SNAC.decode takes a list of (1, n) code tensors, one per codebook.
            codes = [
                torch.tensor(layer, dtype=torch.long, device=self.device).unsqueeze(0)
                for layer in (l1, l2, l3)
            ]
        except Exception as e:
            print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
            return b""

        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e:
            print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
            return b""

        # The x32767 scaling assumes float samples in [-1, 1]. Clamp first so that
        # overshoots saturate instead of wrapping around in the int16 cast.
        audio_np = audio.detach().squeeze().clamp(-1.0, 1.0).cpu().numpy()
        return (audio_np * 32767).astype("int16").tobytes()

    async def _send_audio_bytes(self, data: bytes):
        """
        Send one chunk of audio bytes over the WebSocket.

        Empty chunks are skipped. A disconnect or any other send error is
        printed and swallowed rather than re-raised.
        """
        if data:
            try:
                await self.ws.send_bytes(data)
            except WebSocketDisconnect:
                print("Streamer: WebSocket disconnected during send.")
            except Exception as e:
                print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (Tensor) from generate() (runs in worker thread).
        Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
        """
        # Ensure value is on CPU and flatten to a list of ints
        if value.numel() == 0:
            return
        # Handle potential shape issues, ensure it's iterable
        try:
            new_token_ids = value.view(-1).tolist()
        except Exception as e:
            print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}")
            return

        for t in new_token_ids:
            if t == EOS_TOKEN:
                # print("Streamer: EOS token encountered.")
                # EOS is handled by StoppingCriteria, no action needed here except maybe logging.
                break # Stop processing this batch if EOS is found

            if t == NEW_BLOCK: